mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 18:35:59 +00:00
3017 lines
132 KiB
Haskell
3017 lines
132 KiB
Haskell
{-# LANGUAGE ConstraintKinds #-}
|
|
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE DeriveAnyClass #-}
|
|
{-# LANGUAGE DuplicateRecordFields #-}
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE FlexibleInstances #-}
|
|
{-# LANGUAGE GADTs #-}
|
|
{-# LANGUAGE InstanceSigs #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
{-# LANGUAGE NumericUnderscores #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE PatternSynonyms #-}
|
|
{-# LANGUAGE QuasiQuotes #-}
|
|
{-# LANGUAGE RecordWildCards #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# LANGUAGE TemplateHaskell #-}
|
|
{-# LANGUAGE TupleSections #-}
|
|
{-# LANGUAGE TypeOperators #-}
|
|
{-# LANGUAGE UndecidableInstances #-}
|
|
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
|
|
|
module Simplex.Messaging.Agent.Store.SQLite
|
|
( SQLiteStore (..),
|
|
MigrationConfirmation (..),
|
|
MigrationError (..),
|
|
UpMigration (..),
|
|
createSQLiteStore,
|
|
connectSQLiteStore,
|
|
closeSQLiteStore,
|
|
openSQLiteStore,
|
|
reopenSQLiteStore,
|
|
sqlString,
|
|
keyString,
|
|
storeKey,
|
|
execSQL,
|
|
upMigration, -- used in tests
|
|
|
|
-- * Users
|
|
createUserRecord,
|
|
deleteUserRecord,
|
|
setUserDeleted,
|
|
deleteUserWithoutConns,
|
|
deleteUsersWithoutConns,
|
|
checkUser,
|
|
|
|
-- * Queues and connections
|
|
createNewConn,
|
|
updateNewConnRcv,
|
|
updateNewConnSnd,
|
|
createSndConn,
|
|
getConn,
|
|
getDeletedConn,
|
|
getConns,
|
|
getDeletedConns,
|
|
getConnData,
|
|
setConnDeleted,
|
|
setConnAgentVersion,
|
|
getDeletedConnIds,
|
|
getDeletedWaitingDeliveryConnIds,
|
|
setConnRatchetSync,
|
|
addProcessedRatchetKeyHash,
|
|
checkRatchetKeyHashExists,
|
|
deleteRatchetKeyHashesExpired,
|
|
getRcvConn,
|
|
getRcvQueueById,
|
|
getSndQueueById,
|
|
deleteConn,
|
|
upgradeRcvConnToDuplex,
|
|
upgradeSndConnToDuplex,
|
|
addConnRcvQueue,
|
|
addConnSndQueue,
|
|
setRcvQueueStatus,
|
|
setRcvSwitchStatus,
|
|
setRcvQueueDeleted,
|
|
setRcvQueueConfirmedE2E,
|
|
setSndQueueStatus,
|
|
setSndSwitchStatus,
|
|
setRcvQueuePrimary,
|
|
setSndQueuePrimary,
|
|
deleteConnRcvQueue,
|
|
incRcvDeleteErrors,
|
|
deleteConnSndQueue,
|
|
getPrimaryRcvQueue,
|
|
getRcvQueue,
|
|
getDeletedRcvQueue,
|
|
setRcvQueueNtfCreds,
|
|
-- Confirmations
|
|
createConfirmation,
|
|
acceptConfirmation,
|
|
getAcceptedConfirmation,
|
|
removeConfirmations,
|
|
-- Invitations - sent via Contact connections
|
|
setConnectionVersion,
|
|
createInvitation,
|
|
getInvitation,
|
|
acceptInvitation,
|
|
unacceptInvitation,
|
|
deleteInvitation,
|
|
-- Messages
|
|
updateRcvIds,
|
|
createRcvMsg,
|
|
updateSndIds,
|
|
createSndMsg,
|
|
createSndMsgDelivery,
|
|
getSndMsgViaRcpt,
|
|
updateSndMsgRcpt,
|
|
getPendingQueueMsg,
|
|
updatePendingMsgRIState,
|
|
deletePendingMsgs,
|
|
getExpiredSndMessages,
|
|
setMsgUserAck,
|
|
getRcvMsg,
|
|
getLastMsg,
|
|
checkRcvMsgHashExists,
|
|
deleteMsg,
|
|
deleteDeliveredSndMsg,
|
|
deleteSndMsgDelivery,
|
|
deleteRcvMsgHashesExpired,
|
|
deleteSndMsgsExpired,
|
|
-- Double ratchet persistence
|
|
createRatchetX3dhKeys,
|
|
getRatchetX3dhKeys,
|
|
setRatchetX3dhKeys,
|
|
createRatchet,
|
|
deleteRatchet,
|
|
getRatchet,
|
|
getSkippedMsgKeys,
|
|
updateRatchet,
|
|
-- Async commands
|
|
createCommand,
|
|
getPendingCommandServers,
|
|
getPendingServerCommand,
|
|
deleteCommand,
|
|
-- Notification device token persistence
|
|
createNtfToken,
|
|
getSavedNtfToken,
|
|
updateNtfTokenRegistration,
|
|
updateDeviceToken,
|
|
updateNtfMode,
|
|
updateNtfToken,
|
|
removeNtfToken,
|
|
-- Notification subscription persistence
|
|
getNtfSubscription,
|
|
createNtfSubscription,
|
|
supervisorUpdateNtfSub,
|
|
supervisorUpdateNtfAction,
|
|
updateNtfSubscription,
|
|
setNullNtfSubscriptionAction,
|
|
deleteNtfSubscription,
|
|
getNextNtfSubNTFAction,
|
|
markNtfSubActionNtfFailed_, -- exported for tests
|
|
getNextNtfSubSMPAction,
|
|
markNtfSubActionSMPFailed_, -- exported for tests
|
|
getActiveNtfToken,
|
|
getNtfRcvQueue,
|
|
setConnectionNtfs,
|
|
|
|
-- * File transfer
|
|
|
|
-- Rcv files
|
|
createRcvFile,
|
|
createRcvFileRedirect,
|
|
getRcvFile,
|
|
getRcvFileByEntityId,
|
|
getRcvFileRedirects,
|
|
updateRcvChunkReplicaDelay,
|
|
updateRcvFileChunkReceived,
|
|
updateRcvFileStatus,
|
|
updateRcvFileError,
|
|
updateRcvFileComplete,
|
|
updateRcvFileRedirect,
|
|
updateRcvFileNoTmpPath,
|
|
updateRcvFileDeleted,
|
|
deleteRcvFile',
|
|
getNextRcvChunkToDownload,
|
|
getNextRcvFileToDecrypt,
|
|
getPendingRcvFilesServers,
|
|
getCleanupRcvFilesTmpPaths,
|
|
getCleanupRcvFilesDeleted,
|
|
getRcvFilesExpired,
|
|
-- Snd files
|
|
createSndFile,
|
|
getSndFile,
|
|
getSndFileByEntityId,
|
|
getNextSndFileToPrepare,
|
|
updateSndFileError,
|
|
updateSndFileStatus,
|
|
updateSndFileEncrypted,
|
|
updateSndFileComplete,
|
|
updateSndFileNoPrefixPath,
|
|
updateSndFileDeleted,
|
|
deleteSndFile',
|
|
getSndFileDeleted,
|
|
createSndFileReplica,
|
|
createSndFileReplica_, -- exported for tests
|
|
getNextSndChunkToUpload,
|
|
updateSndChunkReplicaDelay,
|
|
addSndChunkReplicaRecipients,
|
|
updateSndChunkReplicaStatus,
|
|
getPendingSndFilesServers,
|
|
getCleanupSndFilesPrefixPaths,
|
|
getCleanupSndFilesDeleted,
|
|
getSndFilesExpired,
|
|
createDeletedSndChunkReplica,
|
|
getNextDeletedSndChunkReplica,
|
|
updateDeletedSndChunkReplicaDelay,
|
|
deleteDeletedSndChunkReplica,
|
|
getPendingDelFilesServers,
|
|
deleteDeletedSndChunkReplicasExpired,
|
|
|
|
-- * utilities
|
|
withConnection,
|
|
withTransaction,
|
|
withTransactionCtx,
|
|
firstRow,
|
|
firstRow',
|
|
maybeFirstRow,
|
|
)
|
|
where
|
|
|
|
import Control.Monad
|
|
import Control.Monad.Except
|
|
import Control.Monad.IO.Class
|
|
import Crypto.Random (ChaChaDRG)
|
|
import qualified Data.Aeson.TH as J
|
|
import qualified Data.Attoparsec.ByteString.Char8 as A
|
|
import Data.Bifunctor (first, second)
|
|
import Data.ByteArray (ScrubbedBytes)
|
|
import qualified Data.ByteArray as BA
|
|
import Data.ByteString (ByteString)
|
|
import qualified Data.ByteString.Base64.URL as U
|
|
import qualified Data.ByteString.Char8 as B
|
|
import Data.Char (toLower)
|
|
import Data.Functor (($>))
|
|
import Data.IORef
|
|
import Data.Int (Int64)
|
|
import Data.List (foldl', intercalate, sortBy)
|
|
import Data.List.NonEmpty (NonEmpty (..))
|
|
import qualified Data.List.NonEmpty as L
|
|
import qualified Data.Map.Strict as M
|
|
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe)
|
|
import Data.Ord (Down (..))
|
|
import Data.Text (Text)
|
|
import qualified Data.Text as T
|
|
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
|
import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime)
|
|
import Data.Word (Word32)
|
|
import Database.SQLite.Simple (FromRow (..), NamedParam (..), Only (..), Query (..), SQLError, ToRow (..), field, (:.) (..))
|
|
import qualified Database.SQLite.Simple as SQL
|
|
import Database.SQLite.Simple.FromField
|
|
import Database.SQLite.Simple.QQ (sql)
|
|
import Database.SQLite.Simple.ToField (ToField (..))
|
|
import qualified Database.SQLite3 as SQLite3
|
|
import Network.Socket (ServiceName)
|
|
import Simplex.FileTransfer.Client (XFTPChunkSpec (..))
|
|
import Simplex.FileTransfer.Description
|
|
import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..))
|
|
import Simplex.FileTransfer.Types
|
|
import Simplex.Messaging.Agent.Protocol
|
|
import Simplex.Messaging.Agent.RetryInterval (RI2State (..))
|
|
import Simplex.Messaging.Agent.Store
|
|
import Simplex.Messaging.Agent.Store.SQLite.Common
|
|
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
|
|
import Simplex.Messaging.Agent.Store.SQLite.Migrations (DownMigration (..), MTRError, Migration (..), MigrationsToRun (..), mtrErrorDescription)
|
|
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
|
import qualified Simplex.Messaging.Crypto as C
|
|
import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..))
|
|
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys)
|
|
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
|
import Simplex.Messaging.Encoding
|
|
import Simplex.Messaging.Encoding.String
|
|
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
|
|
import Simplex.Messaging.Notifications.Types
|
|
import Simplex.Messaging.Parsers (blobFieldParser, defaultJSON, dropPrefix, fromTextField_, sumTypeJSON)
|
|
import Simplex.Messaging.Protocol
|
|
import qualified Simplex.Messaging.Protocol as SMP
|
|
import Simplex.Messaging.Transport.Client (TransportHost)
|
|
import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, ($>>=), (<$$>))
|
|
import Simplex.Messaging.Version.Internal
|
|
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
|
|
import System.Exit (exitFailure)
|
|
import System.FilePath (takeDirectory)
|
|
import System.IO (hFlush, stdout)
|
|
import UnliftIO.Exception (bracketOnError, onException)
|
|
import qualified UnliftIO.Exception as E
|
|
import UnliftIO.STM
|
|
|
|
-- * SQLite Store implementation
|
|
|
|
-- | Reasons why the database schema could not be brought in line with the app.
data MigrationError
  = -- | The app has migrations that are not yet applied to the database.
    MEUpgrade {upMigrations :: [UpMigration]}
  | -- | The database has migrations (listed by name) that the app does not have.
    MEDowngrade {downMigrations :: [String]}
  | -- | The list of migrations to run could not be determined.
    MigrationError {mtrError :: MTRError}
  deriving (Eq, Show)
|
|
|
|
-- | Human-readable text for a 'MigrationError'; for upgrade/downgrade it is
-- phrased as a confirmation prompt that lists the migrations involved.
migrationErrorDescription :: MigrationError -> String
migrationErrorDescription migrationErr = case migrationErr of
  MEUpgrade ups -> upgradePrompt <> commaSeparated (map upName ups)
  MEDowngrade downs -> downgradePrompt <> commaSeparated downs
  MigrationError e -> mtrErrorDescription e
  where
    commaSeparated = intercalate ", "
    upgradePrompt = "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: "
    downgradePrompt = "Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: "
|
|
|
|
-- | Summary of a migration that can be applied to the database.
data UpMigration = UpMigration
  { -- | migration name
    upName :: String,
    -- | whether the migration also has a down (rollback) part
    withDown :: Bool
  }
  deriving (Eq, Show)
|
|
|
|
-- | Summarises a 'Migration' as its name plus a flag telling whether a down
-- migration is present.
upMigration :: Migration -> UpMigration
upMigration Migration {name = migrationName, down = downPart} =
  UpMigration {upName = migrationName, withDown = isJust downPart}
|
|
|
|
-- | How schema migrations of an existing database are confirmed.
data MigrationConfirmation
  = -- | run up migrations without asking; a required downgrade is an error
    MCYesUp
  | -- | run both up and down migrations without asking
    MCYesUpDown
  | -- | ask for confirmation on the console
    MCConsole
  | -- | never migrate an existing database, return the error instead
    MCError
  deriving (Eq, Show)
|
|
|
|
-- | Textual encoding of 'MigrationConfirmation'; the parser consumes all
-- remaining input and accepts exactly the strings produced by 'strEncode'.
instance StrEncoding MigrationConfirmation where
  strEncode MCYesUp = "yesUp"
  strEncode MCYesUpDown = "yesUpDown"
  strEncode MCConsole = "console"
  strEncode MCError = "error"
  strP = do
    s <- A.takeByteString
    case s of
      "yesUp" -> pure MCYesUp
      "yesUpDown" -> pure MCYesUpDown
      "console" -> pure MCConsole
      "error" -> pure MCError
      _ -> fail "invalid MigrationConfirmation"
|
|
|
|
-- | Creates the database directory if needed, connects to the database file
-- and migrates its schema.
--
-- The store is closed again if the migration fails, whether it returns an
-- error or throws an exception.
createSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError SQLiteStore)
createSQLiteStore dbFilePath dbKey keepKey migrations confirmMigrations = do
  createDirectoryIfMissing True (takeDirectory dbFilePath)
  st <- connectSQLiteStore dbFilePath dbKey keepKey
  migrated <- migrateSchema st migrations confirmMigrations `onException` closeSQLiteStore st
  either (\e -> Left e <$ closeSQLiteStore st) (\() -> pure (Right st)) migrated
|
|
|
|
-- | Brings the database schema in line with the given migrations.
--
-- * If the migrations to run cannot be determined, the error is returned
--   (with 'MCConsole' the user is asked to confirm first).
-- * A new database is upgraded without confirmation and without a backup.
-- * An existing database is upgraded or downgraded depending on the
--   confirmation mode; the database file is copied to @<file>.bak@ before the
--   migrations run.
--
-- With 'MCConsole' the process exits if the user declines (see 'confirmOrExit').
migrateSchema :: SQLiteStore -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError ())
migrateSchema st migrations confirmMigrations = do
  Migrations.initialize st
  toRun <- Migrations.get st migrations
  case toRun of
    Left mtrErr -> do
      when (confirmMigrations == MCConsole) $
        confirmOrExit ("Database state error: " <> mtrErrorDescription mtrErr)
      pure $ Left (MigrationError mtrErr)
    Right MTRNone -> pure (Right ())
    Right msUp@(MTRUp ups)
      | dbNew st -> Right () <$ Migrations.run st msUp
      | otherwise ->
          let upErr = MEUpgrade (map upMigration ups)
           in case confirmMigrations of
                MCYesUp -> backupAndRun msUp
                MCYesUpDown -> backupAndRun msUp
                MCConsole -> askUser upErr >> backupAndRun msUp
                MCError -> pure (Left upErr)
    Right msDown@(MTRDown downs) ->
      let downErr = MEDowngrade (map downName downs)
       in case confirmMigrations of
            MCYesUpDown -> backupAndRun msDown
            MCConsole -> askUser downErr >> backupAndRun msDown
            MCYesUp -> pure (Left downErr)
            MCError -> pure (Left downErr)
  where
    -- prompts with the description of the pending migrations
    askUser = confirmOrExit . migrationErrorDescription
    -- copies the database file to a ".bak" sibling, then runs the migrations
    backupAndRun ms = do
      let dbFile = dbFilePath st
      copyFile dbFile (dbFile <> ".bak")
      Migrations.run st ms
      pure (Right ())
|
|
|
|
-- | Prints the message and a @Continue (y/N)@ prompt, then reads one line from
-- stdin; exits the process with failure unless the answer is @y@
-- (case-insensitive).
confirmOrExit :: String -> IO ()
confirmOrExit s = do
  putStrLn s
  putStr "Continue (y/N): "
  hFlush stdout
  answer <- getLine
  unless (map toLower answer == "y") exitFailure
|
|
|
|
-- | Opens a connection to the database file and wraps it in a 'SQLiteStore'.
--
-- 'dbNew' records whether the file was absent before connecting; the value
-- kept in 'dbKey' is whatever 'storeKey' returns for the key and @keepKey@.
connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> IO SQLiteStore
connectSQLiteStore dbFilePath key keepKey = do
  fileExists <- doesFileExist dbFilePath
  dbConn <- dbBusyLoop $ connectDB dbFilePath key
  atomically $ do
    connVar <- newTMVar dbConn
    keyVar <- newTVar $! storeKey key keepKey
    closedVar <- newTVar False
    pure
      SQLiteStore
        { dbFilePath,
          dbKey = keyVar,
          dbConnection = connVar,
          dbNew = not fileExists,
          dbClosed = closedVar
        }
|
|
|
|
-- | Opens the database file and configures the connection: sets the
-- encryption key (only when it is not empty) and the connection pragmas.
-- The connection is closed if configuring it throws.
connectDB :: FilePath -> ScrubbedBytes -> IO DB.Connection
connectDB path key = do
  db <- DB.open path
  configure db `onException` DB.close db
  -- _printPragmas db path
  pure db
  where
    configure db = do
      let runSQL = SQLite3.exec (SQL.connectionHandle (DB.conn db))
      -- the key is passed as an escaped SQL string literal, see keyString
      unless (BA.null key) $ runSQL ("PRAGMA key = " <> keyString key <> ";")
      runSQL $
        fromQuery
          [sql|
            PRAGMA busy_timeout = 100;
            PRAGMA foreign_keys = ON;
            -- PRAGMA trusted_schema = OFF;
            PRAGMA secure_delete = ON;
            PRAGMA auto_vacuum = FULL;
          |]
|
|
|
|
-- | Closes the database connection and marks the store as closed; only prints
-- a message if the store is already marked as closed.
closeSQLiteStore :: SQLiteStore -> IO ()
closeSQLiteStore st@SQLiteStore {dbClosed} = do
  alreadyClosed <- readTVarIO dbClosed
  if alreadyClosed
    then putStrLn "closeSQLiteStore: already closed"
    else withConnection st $ \c -> DB.close c >> atomically (writeTVar dbClosed True)
|
|
|
|
-- | Re-opens a closed store with the given key; only prints a message if the
-- store is not marked as closed.
openSQLiteStore :: SQLiteStore -> ScrubbedBytes -> Bool -> IO ()
openSQLiteStore st@SQLiteStore {dbClosed} key keepKey = do
  closed <- readTVarIO dbClosed
  if closed
    then openSQLiteStore_ st key keepKey
    else putStrLn "openSQLiteStore: already opened"
|
|
|
|
-- | Replaces the store connection with a newly opened one.
--
-- The old connection value is taken out of 'dbConnection' while reconnecting
-- and put back if reconnecting fails. On success the new handle is stored
-- (keeping the @slow@ field of the old value), the store is marked as open and
-- the stored key is updated.
openSQLiteStore_ :: SQLiteStore -> ScrubbedBytes -> Bool -> IO ()
openSQLiteStore_ SQLiteStore {dbConnection, dbFilePath, dbKey, dbClosed} key keepKey =
  bracketOnError takeOld restoreOld reconnect
  where
    takeOld = atomically (takeTMVar dbConnection)
    restoreOld = atomically . tryPutTMVar dbConnection
    reconnect DB.Connection {slow} = do
      DB.Connection {conn} <- connectDB dbFilePath key
      atomically $ do
        putTMVar dbConnection DB.Connection {conn, slow}
        writeTVar dbClosed False
        writeTVar dbKey $! storeKey key keepKey
|
|
|
|
-- | Re-opens a closed store using the key kept in the store; fails if no key
-- was kept. Only prints a message if the store is not marked as closed.
reopenSQLiteStore :: SQLiteStore -> IO ()
reopenSQLiteStore st@SQLiteStore {dbKey, dbClosed} = do
  closed <- readTVarIO dbClosed
  if closed
    then reopenWithStoredKey
    else putStrLn "reopenSQLiteStore: already opened"
  where
    reopenWithStoredKey = do
      key_ <- readTVarIO dbKey
      maybe (fail "reopenSQLiteStore: no key") (\key -> openSQLiteStore_ st key True) key_
|
|
|
|
-- | Renders the database key as a quoted SQL string literal (see 'sqlString').
keyString :: ScrubbedBytes -> Text
keyString key = sqlString (safeDecodeUtf8 (BA.convert key))
|
|
|
|
-- | Wraps the text in single quotes, doubling every single quote inside it,
-- so that it can be used as an SQL string literal.
sqlString :: Text -> Text
sqlString s = T.concat ["'", T.replace "'" "''" s, "'"]
|
|
|
|
-- _printPragmas :: DB.Connection -> FilePath -> IO ()
|
|
-- _printPragmas db path = do
|
|
-- foreign_keys <- DB.query_ db "PRAGMA foreign_keys;" :: IO [[Int]]
|
|
-- print $ path <> " foreign_keys: " <> show foreign_keys
|
|
-- -- when run via sqlite-simple query for trusted_schema seems to return empty list
|
|
-- trusted_schema <- DB.query_ db "PRAGMA trusted_schema;" :: IO [[Int]]
|
|
-- print $ path <> " trusted_schema: " <> show trusted_schema
|
|
-- secure_delete <- DB.query_ db "PRAGMA secure_delete;" :: IO [[Int]]
|
|
-- print $ path <> " secure_delete: " <> show secure_delete
|
|
-- auto_vacuum <- DB.query_ db "PRAGMA auto_vacuum;" :: IO [[Int]]
|
|
-- print $ path <> " auto_vacuum: " <> show auto_vacuum
|
|
|
|
-- | Runs arbitrary SQL and returns the result as text lines in the order
-- produced: a header line followed by one line per row (see 'addSQLResultRow').
execSQL :: DB.Connection -> Text -> IO [Text]
execSQL db query = do
  rowsRef <- newIORef []
  let handle = SQL.connectionHandle (DB.conn db)
  SQLite3.execWithCallback handle query (addSQLResultRow rowsRef)
  -- rows are accumulated in reverse order
  reverse <$> readIORef rowsRef
|
|
|
|
-- | Row callback for 'execSQL': prepends the row, rendered as @|@-separated
-- values (NULL as empty text), to the accumulated list. For the first row the
-- column names line is stored as well, so that it ends up first once the list
-- is reversed.
addSQLResultRow :: IORef [Text] -> SQLite3.ColumnIndex -> [Text] -> [Maybe Text] -> IO ()
addSQLResultRow rs _count names values = modifyIORef' rs prepend
  where
    prepend acc
      | null acc = [row, header]
      | otherwise = row : acc
    header = T.intercalate "|" names
    row = T.intercalate "|" (map (fromMaybe "") values)
|
|
|
|
-- | Runs the action, converting a thrown 'SQLError' to a store error: the
-- given error for constraint violations, an internal error otherwise.
checkConstraint :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a)
checkConstraint err action =
  E.catch action $ \e -> pure (Left (handleSQLError err e))
|
|
|
|
-- | Maps a constraint violation to the given store error and any other SQL
-- error to 'SEInternal' carrying the shown error.
handleSQLError :: StoreError -> SQLError -> StoreError
handleSQLError err e = case SQL.sqlError e of
  SQL.ErrorConstraint -> err
  _ -> SEInternal (bshow e)
|
|
|
|
-- | Inserts a user row with default values and returns the id of the inserted row.
createUserRecord :: DB.Connection -> IO UserId
createUserRecord db =
  DB.execute_ db "INSERT INTO users DEFAULT VALUES" >> insertedRowId db
|
|
|
|
-- | Succeeds if a user with this id exists and is not marked as deleted,
-- otherwise returns 'SEUserNotFound'.
checkUser :: DB.Connection -> UserId -> IO (Either StoreError ())
checkUser db userId = firstRow found SEUserNotFound activeUser
  where
    found :: Only Int64 -> ()
    found _ = ()
    activeUser = DB.query db "SELECT user_id FROM users WHERE user_id = ? AND deleted = ?" (userId, False)
|
|
|
|
-- | Deletes the user row, provided the user passes 'checkUser'.
deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ())
deleteUserRecord db userId =
  checkUser db userId >>= \case
    Left e -> pure (Left e)
    Right () -> Right <$> DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId)
|
|
|
|
-- | Marks the user as deleted, provided the user passes 'checkUser', and
-- returns the ids of all connections of this user.
setUserDeleted :: DB.Connection -> UserId -> IO (Either StoreError [ConnId])
setUserDeleted db userId =
  checkUser db userId >>= \case
    Left e -> pure (Left e)
    Right () -> do
      DB.execute db "UPDATE users SET deleted = ? WHERE user_id = ?" (True, userId)
      connIds <- DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" (Only userId)
      pure . Right $ map fromOnly connIds
|
|
|
|
-- | Deletes the user row if the user is marked as deleted and has no
-- connections left; returns whether the row was deleted.
deleteUserWithoutConns :: DB.Connection -> UserId -> IO Bool
deleteUserWithoutConns db userId = do
  found <- maybeFirstRow (fromOnly :: Only Int64 -> Int64) deletedUserWithoutConns
  if isJust found
    then True <$ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId)
    else pure False
  where
    deletedUserWithoutConns =
      DB.query
        db
        [sql|
          SELECT user_id FROM users u
          WHERE u.user_id = ?
            AND u.deleted = ?
            AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id)
        |]
        (userId, True)
|
|
|
|
-- | Deletes all users that are marked as deleted and have no connections
-- left; returns the ids of the deleted users.
deleteUsersWithoutConns :: DB.Connection -> IO [Int64]
deleteUsersWithoutConns db = do
  rows <-
    DB.query
      db
      [sql|
        SELECT user_id FROM users u
        WHERE u.deleted = ?
          AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id)
      |]
      (Only True)
  let userIds = map fromOnly rows
  mapM_ (\uId -> DB.execute db "DELETE FROM users WHERE user_id = ?" (Only uId)) userIds
  pure userIds
|
|
|
|
-- | Runs the creating action with the connection id from 'ConnData', or with
-- an id obtained via 'createWithRandomId'' when that id is empty. A
-- constraint violation is reported as 'SEConnDuplicate'.
createConn_ ::
  TVar ChaChaDRG ->
  ConnData ->
  (ConnId -> IO a) ->
  IO (Either StoreError (ConnId, a))
createConn_ gVar cData create =
  checkConstraint SEConnDuplicate $
    if givenId == ""
      then createWithRandomId' gVar create
      else Right . (givenId,) <$> create givenId
  where
    ConnData {connId = givenId} = cData
|
|
|
|
-- | Creates a connection record without queues and returns its id.
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
createNewConn db gVar cData cMode =
  fmap fst <$> createConn_ gVar cData insertConn
  where
    insertConn connId = createConnRecord db connId cData cMode
|
|
|
|
-- | Adds a receive queue to a connection that is still new, or already a
-- receive connection (to allow retries); any other connection type is
-- rejected with 'SEBadConnType'.
updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue)
updateNewConnRcv db connId rq = getConn db connId $>>= addQueue
  where
    addQueue :: SomeConn -> IO (Either StoreError RcvQueue)
    addQueue (SomeConn _ NewConnection {}) = insertQueue
    addQueue (SomeConn _ RcvConnection {}) = insertQueue -- to allow retries
    addQueue (SomeConn c _) = pure . Left . SEBadConnType $ connType c
    insertQueue :: IO (Either StoreError RcvQueue)
    insertQueue = Right <$> addConnRcvQueue_ db connId rq
|
|
|
|
-- | Adds a send queue to a connection that is still new, or already a send
-- connection (to allow retries); any other connection type is rejected with
-- 'SEBadConnType'.
updateNewConnSnd :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
updateNewConnSnd db connId sq = getConn db connId $>>= addQueue
  where
    addQueue :: SomeConn -> IO (Either StoreError SndQueue)
    addQueue (SomeConn _ NewConnection {}) = insertQueue
    addQueue (SomeConn _ SndConnection {}) = insertQueue -- to allow retries
    addQueue (SomeConn c _) = pure . Left . SEBadConnType $ connType c
    insertQueue :: IO (Either StoreError SndQueue)
    insertQueue = Right <$> addConnSndQueue_ db connId sq
|
|
|
|
-- | Creates a connection record together with its send queue; returns
-- 'SESndQueueExists' if a non-new send queue with the same server and sender
-- id is already stored.
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue))
createSndConn db gVar cData q@SndQueue {server} = do
  -- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_
  confirmedExists <- checkConfirmedSndQueueExists_ db q
  if confirmedExists
    then pure $ Left SESndQueueExists
    else createConn_ gVar cData insertConnWithQueue
  where
    insertConnWithQueue connId = do
      keyHash_ <- createServer_ db server
      createConnRecord db connId cData SCMInvitation
      insertSndQueue_ db connId q keyHash_
|
|
|
|
-- | Inserts the connection row; @duplex_handshake@ is always stored as true.
createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO ()
createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqEncryption} cMode =
  DB.execute db insertConn (userId, connId, cMode, connAgentVersion, enableNtfs, pqEncryption, True)
  where
    insertConn =
      [sql|
        INSERT INTO connections
          (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_encryption, duplex_handshake) VALUES (?,?,?,?,?,?,?)
      |]
|
|
|
|
-- | Checks whether a send queue with the same server host, port and sender id
-- is stored with a status other than 'New'.
checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool
checkConfirmedSndQueueExists_ db SndQueue {server = ProtocolServer {host, port}, sndId} = do
  found <-
    maybeFirstRow fromOnly $
      DB.query
        db
        "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
        (host, port, sndId, New)
  pure $ fromMaybe False found
|
|
|
|
-- | Finds a receive queue that is not marked as deleted by server and
-- recipient id, and loads the connection it belongs to.
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
getRcvConn db ProtocolServer {host, port} rcvId =
  findQueue $>>= \rq@RcvQueue {connId} -> fmap (rq,) <$> getConn db connId
  where
    findQueue =
      firstRow toRcvQueue SEConnNotFound $
        DB.query db (rcvQueueQuery <> " WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (host, port, rcvId)
|
|
|
|
-- | Deletes connection, optionally checking for pending snd message deliveries; returns connection id if it was deleted.
--
-- Without a timeout the connection is deleted unconditionally. With a timeout
-- it is deleted when it has no non-failed deliveries left, or when its
-- @deleted_at_wait_delivery@ timestamp is older than the timeout.
deleteConn :: DB.Connection -> Maybe NominalDiffTime -> ConnId -> IO (Maybe ConnId)
deleteConn db waitDeliveryTimeout_ connId =
  maybe deleteNow deleteIfDelivered waitDeliveryTimeout_
  where
    deleteNow = Just connId <$ DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId)
    deleteIfDelivered timeout = do
      noPending <- noPendingDeliveries
      if noPending
        then deleteNow
        else do
          expired <- waitTimeoutExpired timeout
          if expired then deleteNow else pure Nothing
    noPendingDeliveries = do
      r :: Maybe Int64 <-
        maybeFirstRow fromOnly $
          DB.query db "SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND failed = 0 LIMIT 1" (Only connId)
      pure $ isNothing r
    waitTimeoutExpired timeout = do
      cutoffTs <- addUTCTime (negate timeout) <$> getCurrentTime
      r :: Maybe Int64 <-
        maybeFirstRow fromOnly $
          DB.query db "SELECT 1 FROM connections WHERE conn_id = ? AND deleted_at_wait_delivery < ? LIMIT 1" (connId, cutoffTs)
      pure $ isJust r
|
|
|
|
-- | Adds a send queue to a receive connection; any other connection type is
-- rejected with 'SEBadConnType', and errors of 'getConn' are returned unchanged.
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
upgradeRcvConnToDuplex db connId sq =
  getConn db connId >>= \case
    Left e -> pure (Left e)
    Right (SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq
    Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
|
|
|
-- | Adds a receive queue to a send connection; any other connection type is
-- rejected with 'SEBadConnType'.
--
-- Errors of 'getConn' are returned unchanged (as in 'upgradeRcvConnToDuplex'),
-- rather than all being reported as 'SEConnNotFound'.
upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue)
upgradeSndConnToDuplex db connId rq =
  getConn db connId >>= \case
    Left e -> pure (Left e)
    Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq
    Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
|
|
|
-- | Adds one more receive queue to a duplex connection; any other connection
-- type is rejected with 'SEBadConnType'.
--
-- Errors of 'getConn' are returned unchanged, rather than all being reported
-- as 'SEConnNotFound'.
addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue)
addConnRcvQueue db connId rq =
  getConn db connId >>= \case
    Left e -> pure (Left e)
    Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq
    Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
|
|
|
-- | Stores the queue's server via 'createServer_' and inserts the receive
-- queue for the connection, without checking the connection type.
addConnRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> IO RcvQueue
addConnRcvQueue_ db connId rq@RcvQueue {server} =
  createServer_ db server >>= insertRcvQueue_ db connId rq
|
|
|
|
-- | Adds one more send queue to a duplex connection; any other connection
-- type is rejected with 'SEBadConnType'.
--
-- Errors of 'getConn' are returned unchanged, rather than all being reported
-- as 'SEConnNotFound'.
addConnSndQueue :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
addConnSndQueue db connId sq =
  getConn db connId >>= \case
    Left e -> pure (Left e)
    Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq
    Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
|
|
|
-- | Stores the queue's server via 'createServer_' and inserts the send queue
-- for the connection, without checking the connection type.
addConnSndQueue_ :: DB.Connection -> ConnId -> NewSndQueue -> IO SndQueue
addConnSndQueue_ db connId sq@SndQueue {server} =
  createServer_ db server >>= insertSndQueue_ db connId sq
|
|
|
|
-- | Updates the status of the receive queue identified by server host, port
-- and recipient id.
setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status =
  -- ? return error if queue does not exist?
  DB.executeNamed db updateStatus params
  where
    updateStatus =
      [sql|
        UPDATE rcv_queues
        SET status = :status
        WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
      |]
    params = [":status" := status, ":host" := host, ":port" := port, ":rcv_id" := rcvId]
|
|
|
|
-- | Stores the switch status of the receive queue and returns the queue value
-- with the status updated.
setRcvSwitchStatus :: DB.Connection -> RcvQueue -> Maybe RcvSwitchStatus -> IO RcvQueue
setRcvSwitchStatus db rq@RcvQueue {rcvId, server = ProtocolServer {host, port}} rcvSwchStatus =
  rq {rcvSwchStatus} <$ DB.execute db updateSwitchStatus (rcvSwchStatus, host, port, rcvId)
  where
    updateSwitchStatus =
      [sql|
        UPDATE rcv_queues
        SET switch_status = ?
        WHERE host = ? AND port = ? AND rcv_id = ?
      |]
|
|
|
|
-- | Marks the receive queue as deleted; the row itself is kept.
setRcvQueueDeleted :: DB.Connection -> RcvQueue -> IO ()
setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} =
  DB.execute db markDeleted (host, port, rcvId)
  where
    markDeleted =
      [sql|
        UPDATE rcv_queues
        SET deleted = 1
        WHERE host = ? AND port = ? AND rcv_id = ?
      |]
|
|
|
|
-- | Stores the e2e DH secret and the SMP client version for the receive queue
-- and sets its status to 'Confirmed'.
setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO ()
setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion =
  DB.executeNamed db confirmQueue params
  where
    confirmQueue =
      [sql|
        UPDATE rcv_queues
        SET e2e_dh_secret = :e2e_dh_secret,
            status = :status,
            smp_client_version = :smp_client_version
        WHERE host = :host AND port = :port AND rcv_id = :rcv_id
      |]
    params =
      [ ":status" := Confirmed,
        ":e2e_dh_secret" := e2eDhSecret,
        ":smp_client_version" := smpClientVersion,
        ":host" := host,
        ":port" := port,
        ":rcv_id" := rcvId
      ]
|
|
|
|
-- | Updates the status of the send queue identified by server host, port and
-- sender id.
setSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO ()
setSndQueueStatus db SndQueue {sndId, server = ProtocolServer {host, port}} status =
  -- ? return error if queue does not exist?
  DB.executeNamed db updateStatus params
  where
    updateStatus =
      [sql|
        UPDATE snd_queues
        SET status = :status
        WHERE host = :host AND port = :port AND snd_id = :snd_id;
      |]
    params = [":status" := status, ":host" := host, ":port" := port, ":snd_id" := sndId]
|
|
|
|
-- | Stores the switch status of the send queue and returns the queue value
-- with the status updated.
setSndSwitchStatus :: DB.Connection -> SndQueue -> Maybe SndSwitchStatus -> IO SndQueue
setSndSwitchStatus db sq@SndQueue {sndId, server = ProtocolServer {host, port}} sndSwchStatus =
  sq {sndSwchStatus} <$ DB.execute db updateSwitchStatus (sndSwchStatus, host, port, sndId)
  where
    updateSwitchStatus =
      [sql|
        UPDATE snd_queues
        SET switch_status = ?
        WHERE host = ? AND port = ? AND snd_id = ?
      |]
|
|
|
|
-- | Makes the given queue the only primary receive queue of the connection
-- and clears its @replace_rcv_queue_id@.
setRcvQueuePrimary :: DB.Connection -> ConnId -> RcvQueue -> IO ()
setRcvQueuePrimary db connId RcvQueue {dbQueueId} = clearPrimary >> markPrimary
  where
    -- first reset the flag on all queues of the connection
    clearPrimary = DB.execute db "UPDATE rcv_queues SET rcv_primary = ? WHERE conn_id = ?" (False, connId)
    markPrimary =
      DB.execute
        db
        "UPDATE rcv_queues SET rcv_primary = ?, replace_rcv_queue_id = ? WHERE conn_id = ? AND rcv_queue_id = ?"
        (True, Nothing :: Maybe Int64, connId, dbQueueId)
|
|
|
|
-- | Makes the given queue the only primary send queue of the connection and
-- clears its @replace_snd_queue_id@.
setSndQueuePrimary :: DB.Connection -> ConnId -> SndQueue -> IO ()
setSndQueuePrimary db connId SndQueue {dbQueueId} = clearPrimary >> markPrimary
  where
    -- first reset the flag on all queues of the connection
    clearPrimary = DB.execute db "UPDATE snd_queues SET snd_primary = ? WHERE conn_id = ?" (False, connId)
    markPrimary =
      DB.execute
        db
        "UPDATE snd_queues SET snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?"
        (True, Nothing :: Maybe Int64, connId, dbQueueId)
|
|
|
|
-- | Increments the @delete_errors@ counter of the receive queue.
incRcvDeleteErrors :: DB.Connection -> RcvQueue -> IO ()
incRcvDeleteErrors db RcvQueue {connId, dbQueueId} = DB.execute db incErrors (connId, dbQueueId)
  where
    incErrors = "UPDATE rcv_queues SET delete_errors = delete_errors + 1 WHERE conn_id = ? AND rcv_queue_id = ?"
|
|
|
|
-- | Deletes the row of the receive queue.
deleteConnRcvQueue :: DB.Connection -> RcvQueue -> IO ()
deleteConnRcvQueue db RcvQueue {connId, dbQueueId} = DB.execute db deleteQueue (connId, dbQueueId)
  where
    deleteQueue = "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?"
|
|
|
|
-- | Deletes the row of the send queue and then the message deliveries
-- recorded for it.
deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO ()
deleteConnSndQueue db connId SndQueue {dbQueueId} =
  forM_ deletions $ \q -> DB.execute db q (connId, dbQueueId)
  where
    -- executed in this order
    deletions =
      [ "DELETE FROM snd_queues WHERE conn_id = ? AND snd_queue_id = ?",
        "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?"
      ]
|
|
|
|
-- | Returns the first of the connection's receive queues as returned by
-- 'getRcvQueuesByConnId_' (presumably the primary one comes first — TODO
-- confirm), or 'SEConnNotFound' if there are none.
getPrimaryRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue)
getPrimaryRcvQueue db connId = do
  rqs_ <- getRcvQueuesByConnId_ db connId
  pure $ case rqs_ of
    Just (rq :| _) -> Right rq
    Nothing -> Left SEConnNotFound
|
|
|
|
-- | Finds a receive queue of the connection that is not marked as deleted, by
-- server host, port and recipient id; returns 'SEConnNotFound' if there is none.
getRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue)
getRcvQueue db connId (SMPServer host port _) rcvId =
  firstRow toRcvQueue SEConnNotFound $
    -- the leading space keeps the SQL valid even if rcvQueueQuery does not end with whitespace
    DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (connId, host, port, rcvId)
|
|
|
|
-- | Finds a receive queue of the connection that is marked as deleted, by
-- server host, port and recipient id; returns 'SEConnNotFound' if there is none.
getDeletedRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue)
getDeletedRcvQueue db connId (SMPServer host port _) rcvId =
  firstRow toRcvQueue SEConnNotFound $
    -- the leading space keeps the SQL valid even if rcvQueueQuery does not end with whitespace
    DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 1") (connId, host, port, rcvId)
|
|
|
|
-- | Sets the notification credentials on all receive queues of the
-- connection, or clears them (stores NULLs) when 'Nothing' is passed.
setRcvQueueNtfCreds :: DB.Connection -> ConnId -> Maybe ClientNtfCreds -> IO ()
setRcvQueueNtfCreds db connId clientNtfCreds =
  DB.execute db updateCreds (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_, connId)
  where
    updateCreds =
      [sql|
        UPDATE rcv_queues
        SET ntf_public_key = ?, ntf_private_key = ?, ntf_id = ?, rcv_ntf_dh_secret = ?
        WHERE conn_id = ?
      |]
    (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) =
      maybe (Nothing, Nothing, Nothing, Nothing) justCreds clientNtfCreds
    justCreds ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} =
      (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret)
|
|
|
|
-- | Row of confirmation columns that 'smpConfirmation' converts to 'SMPConfirmation':
-- sender key, e2e public key, sender connection info, optional reply queues and optional SMP client version.
type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC)
|
|
|
|
-- | Builds 'SMPConfirmation' from a database row.
-- Missing reply queues become an empty list, a missing client version becomes 'initialSMPClientVersion'.
smpConfirmation :: SMPConfirmationRow -> SMPConfirmation
smpConfirmation (senderKey, e2ePubKey, connInfo, replyQueues_, clientVersion_) =
  SMPConfirmation {senderKey, e2ePubKey, connInfo, smpReplyQueues, smpClientVersion}
  where
    smpReplyQueues = fromMaybe [] replyQueues_
    smpClientVersion = fromMaybe initialSMPClientVersion clientVersion_
|
|
|
|
-- | Inserts a not yet accepted confirmation (@accepted = 0@) under an ID produced by 'createWithRandomId'.
createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId)
createConfirmation db gVar NewConfirmation {connId, senderConf, ratchetState} =
  createWithRandomId gVar insertConfirmation
  where
    SMPConfirmation {senderKey, e2ePubKey, connInfo, smpReplyQueues, smpClientVersion} = senderConf
    insertConfirmation confirmationId =
      DB.execute
        db
        [sql|
          INSERT INTO conn_confirmations
          (confirmation_id, conn_id, sender_key, e2e_snd_pub_key, ratchet_state, sender_conn_info, smp_reply_queues, smp_client_version, accepted) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0);
        |]
        (confirmationId, connId, senderKey, e2ePubKey, ratchetState, connInfo, smpReplyQueues, smpClientVersion)
|
|
|
|
-- | Marks the confirmation as accepted, storing own connection info, and then reads it back.
-- Returns 'SEConfirmationNotFound' when no row has this confirmation ID.
acceptConfirmation :: DB.Connection -> ConfirmationId -> ConnInfo -> IO (Either StoreError AcceptedConfirmation)
acceptConfirmation db confirmationId ownConnInfo =
  markAccepted >> firstRow confirmation SEConfirmationNotFound acceptedRows
  where
    markAccepted =
      DB.executeNamed
        db
        [sql|
          UPDATE conn_confirmations
          SET accepted = 1,
              own_conn_info = :own_conn_info
          WHERE confirmation_id = :confirmation_id;
        |]
        [ ":own_conn_info" := ownConnInfo,
          ":confirmation_id" := confirmationId
        ]
    acceptedRows =
      DB.query
        db
        [sql|
          SELECT conn_id, ratchet_state, sender_key, e2e_snd_pub_key, sender_conn_info, smp_reply_queues, smp_client_version
          FROM conn_confirmations
          WHERE confirmation_id = ?;
        |]
        (Only confirmationId)
    confirmation ((connId, ratchetState) :. confRow) =
      let senderConf = smpConfirmation confRow
       in AcceptedConfirmation {confirmationId, connId, senderConf, ratchetState, ownConnInfo}
|
|
|
|
-- | Loads the accepted confirmation (@accepted = 1@) of the connection, or 'SEConfirmationNotFound'.
getAcceptedConfirmation :: DB.Connection -> ConnId -> IO (Either StoreError AcceptedConfirmation)
getAcceptedConfirmation db connId =
  firstRow confirmation SEConfirmationNotFound acceptedRows
  where
    acceptedRows =
      DB.query
        db
        [sql|
          SELECT confirmation_id, ratchet_state, own_conn_info, sender_key, e2e_snd_pub_key, sender_conn_info, smp_reply_queues, smp_client_version
          FROM conn_confirmations
          WHERE conn_id = ? AND accepted = 1;
        |]
        (Only connId)
    confirmation ((confirmationId, ratchetState, ownConnInfo) :. confRow) =
      let senderConf = smpConfirmation confRow
       in AcceptedConfirmation {confirmationId, connId, senderConf, ratchetState, ownConnInfo}
|
|
|
|
-- | Deletes all confirmations stored for the connection.
removeConfirmations :: DB.Connection -> ConnId -> IO ()
removeConfirmations db connId = DB.executeNamed db deleteConfirmations [":conn_id" := connId]
  where
    deleteConfirmations =
      [sql|
        DELETE FROM conn_confirmations
        WHERE conn_id = :conn_id;
      |]
|
|
|
|
-- | Stores the SMP agent version in the connection record.
setConnectionVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO ()
setConnectionVersion db connId aVersion = DB.execute db updateVersion (aVersion, connId)
  where
    updateVersion = "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?"
|
|
|
|
-- | Inserts a not yet accepted invitation (@accepted = 0@) under an ID produced by 'createWithRandomId'.
createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId)
createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} =
  createWithRandomId gVar insertInvitation
  where
    insertInvitation invitationId =
      DB.execute
        db
        [sql|
          INSERT INTO conn_invitations
          (invitation_id, contact_conn_id, cr_invitation, recipient_conn_info, accepted) VALUES (?, ?, ?, ?, 0);
        |]
        (invitationId, contactConnId, connReq, recipientConnInfo)
|
|
|
|
-- | Loads an invitation by ID; the query only matches invitations that are not accepted (@accepted = 0@).
getInvitation :: DB.Connection -> InvitationId -> IO (Either StoreError Invitation)
getInvitation db invitationId =
  firstRow invitation SEInvitationNotFound invitationRows
  where
    invitationRows =
      DB.query
        db
        [sql|
          SELECT contact_conn_id, cr_invitation, recipient_conn_info, own_conn_info, accepted
          FROM conn_invitations
          WHERE invitation_id = ?
            AND accepted = 0
        |]
        (Only invitationId)
    invitation (contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted) =
      Invitation {invitationId, contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted}
|
|
|
|
-- | Marks the invitation as accepted and stores own connection info with it.
acceptInvitation :: DB.Connection -> InvitationId -> ConnInfo -> IO ()
acceptInvitation db invitationId ownConnInfo = DB.executeNamed db markAccepted params
  where
    markAccepted =
      [sql|
        UPDATE conn_invitations
        SET accepted = 1,
            own_conn_info = :own_conn_info
        WHERE invitation_id = :invitation_id
      |]
    params =
      [ ":own_conn_info" := ownConnInfo,
        ":invitation_id" := invitationId
      ]
|
|
|
|
-- | Reverts 'acceptInvitation': resets the accepted flag and clears own connection info.
unacceptInvitation :: DB.Connection -> InvitationId -> IO ()
unacceptInvitation db invitationId = DB.execute db resetAccepted (Only invitationId)
  where
    resetAccepted = "UPDATE conn_invitations SET accepted = 0, own_conn_info = NULL WHERE invitation_id = ?"
|
|
|
|
-- | Deletes the invitation of a contact connection.
-- Errors from 'getConn' are returned as is; a connection that is not a contact connection gives 'SEConnNotFound'.
deleteInvitation :: DB.Connection -> ConnId -> InvitationId -> IO (Either StoreError ())
deleteInvitation db contactConnId invId = runExceptT $ do
  conn <- ExceptT $ getConn db contactConnId
  case conn of
    SomeConn SCContact _ ->
      liftIO $ DB.execute db "DELETE FROM conn_invitations WHERE contact_conn_id = ? AND invitation_id = ?" (contactConnId, invId)
    _ -> throwError SEConnNotFound
|
|
|
|
-- | Increments and stores the last internal and internal receive IDs of the connection.
-- Returns the new IDs together with the previous external sender ID and previous receive hash.
updateRcvIds :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
updateRcvIds db connId =
  retrieveLastIdsAndHashRcv_ db connId >>= \(lastId, lastRcvId, prevExtSndId, prevRcvHash) -> do
    let nextId = InternalId (unId lastId + 1)
        nextRcvId = InternalRcvId (unRcvId lastRcvId + 1)
    updateLastIdsRcv_ db connId nextId nextRcvId
    pure (nextId, nextRcvId, prevExtSndId, prevRcvHash)
|
|
|
|
-- | Stores a received message: base record, receive details for the queue, then the connection's receive hash.
createRcvMsg :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
createRcvMsg db connId rq rcvMsgData =
  insertRcvMsgBase_ db connId rcvMsgData
    *> insertRcvMsgDetails_ db connId rq rcvMsgData
    *> updateHashRcv_ db connId rcvMsgData
|
|
|
|
-- | Increments and stores the last internal and internal send IDs of the connection.
-- Returns the new IDs together with the previous send hash.
updateSndIds :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
updateSndIds db connId =
  retrieveLastIdsAndHashSnd_ db connId >>= \(lastId, lastSndId, prevSndHash) -> do
    let nextId = InternalId (unId lastId + 1)
        nextSndId = InternalSndId (unSndId lastSndId + 1)
    updateLastIdsSnd_ db connId nextId nextSndId
    pure (nextId, nextSndId, prevSndHash)
|
|
|
|
-- | Stores a sent message: base record, send details, then the connection's send hash.
createSndMsg :: DB.Connection -> ConnId -> SndMsgData -> IO ()
createSndMsg db connId sndMsgData =
  insertSndMsgBase_ db connId sndMsgData
    *> insertSndMsgDetails_ db connId sndMsgData
    *> updateHashSnd_ db connId sndMsgData
|
|
|
|
-- | Records that the message has to be delivered via the given send queue.
createSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO ()
createSndMsgDelivery db connId SndQueue {dbQueueId = queueId} msgId =
  DB.execute db insertDelivery (connId, queueId, msgId)
  where
    insertDelivery = "INSERT INTO snd_message_deliveries (conn_id, snd_queue_id, internal_id) VALUES (?, ?, ?)"
|
|
|
|
-- | Loads a sent message by its internal send ID, including the receipt if both receipt columns are set.
getSndMsgViaRcpt :: DB.Connection -> ConnId -> InternalSndId -> IO (Either StoreError SndMsg)
getSndMsgViaRcpt db connId sndMsgId =
  firstRow toSndMsg SEMsgNotFound sndMsgRows
  where
    sndMsgRows =
      DB.query
        db
        [sql|
          SELECT s.internal_id, m.msg_type, s.internal_hash, s.rcpt_internal_id, s.rcpt_status
          FROM snd_messages s
          JOIN messages m ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id
          WHERE s.conn_id = ? AND s.internal_snd_id = ?
        |]
        (connId, sndMsgId)
    toSndMsg :: (InternalId, AgentMessageType, MsgHash, Maybe AgentMsgId, Maybe MsgReceiptStatus) -> SndMsg
    toSndMsg (internalId, msgType, internalHash, rcptInternalId_, rcptStatus_) =
      SndMsg
        { internalId,
          internalSndId = sndMsgId,
          msgType,
          internalHash,
          msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_
        }
|
|
|
|
-- | Stores the receipt (receipt message ID and status) on the sent message.
updateSndMsgRcpt :: DB.Connection -> ConnId -> InternalSndId -> MsgReceipt -> IO ()
updateSndMsgRcpt db connId sndMsgId MsgReceipt {agentMsgId = rcptMsgId, msgRcptStatus = rcptStatus} =
  DB.execute db updateRcpt (rcptMsgId, rcptStatus, connId, sndMsgId)
  where
    updateRcpt = "UPDATE snd_messages SET rcpt_internal_id = ?, rcpt_status = ? WHERE conn_id = ? AND internal_snd_id = ?"
|
|
|
|
-- | Gets the next message to deliver via the send queue, using 'getWorkItem'.
--
-- The next message is the not failed delivery of this queue with the smallest internal ID.
-- The result also includes the first receive queue of the connection, if it has any.
-- The callback passed to 'getWorkItem' for failures marks deliveries of the message as failed.
getPendingQueueMsg :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError (Maybe (Maybe RcvQueue, PendingMsgData)))
getPendingQueueMsg db connId SndQueue {dbQueueId = queueId} =
  getWorkItem "message" nextMsgId loadMsg markMsgFailed
  where
    nextMsgId :: IO (Maybe InternalId)
    nextMsgId = maybeFirstRow fromOnly nextMsgIdRows
    nextMsgIdRows =
      DB.query
        db
        [sql|
          SELECT internal_id
          FROM snd_message_deliveries d
          WHERE conn_id = ? AND snd_queue_id = ? AND failed = 0
          ORDER BY internal_id ASC
          LIMIT 1
        |]
        (connId, queueId)
    loadMsg :: InternalId -> IO (Either StoreError (Maybe RcvQueue, PendingMsgData))
    loadMsg msgId =
      firstRow pendingMsgData notFound msgRows >>= \case
        Left e -> pure $ Left e
        Right msg -> do
          -- receive queues are only queried once the message data is loaded
          rq_ <- fmap L.head <$> getRcvQueuesByConnId_ db connId
          pure $ Right (rq_, msg)
      where
        msgRows =
          DB.query
            db
            [sql|
              SELECT m.msg_type, m.msg_flags, m.msg_body, m.pq_encryption, m.internal_ts, s.retry_int_slow, s.retry_int_fast
              FROM messages m
              JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id
              WHERE m.conn_id = ? AND m.internal_id = ?
            |]
            (connId, msgId)
        notFound = SEInternal $ "msg delivery " <> bshow msgId <> " returned []"
        pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, CR.PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData
        pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) =
          PendingMsgData {msgId, msgType, msgFlags, msgBody, pqEncryption, msgRetryState, internalTs}
          where
            msgFlags = fromMaybe SMP.noMsgFlags msgFlags_
            -- retry state is only present when both intervals are stored
            msgRetryState = RI2State <$> riSlow_ <*> riFast_
    markMsgFailed msgId =
      DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId)
|
|
|
|
-- | Generic retrieval of the next work item.
--
-- Runs @getId@ to find the ID of the next item and, when there is one, loads it with @getItem@.
-- When loading fails (with a store error or with an error caught by 'catchAllErrors'),
-- the item is passed to @markFailed@ and the error is returned.
-- Exceptions thrown by @getId@ and @markFailed@ are returned as 'SEWorkItemError'.
getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a))
getWorkItem itemName getId getItem markFailed = runExceptT $ do
  itemId_ <- handleErr "getId" getId
  mapM tryGetItem itemId_
  where
    tryGetItem itemId =
      catchAllErrors (SEInternal . bshow) (ExceptT $ getItem itemId) $ \e -> do
        handleErr ("markFailed ID " <> bshow itemId) $ markFailed itemId
        throwError e
    -- Errors caught by this function will suspend worker as if there is no more work.
    handleErr :: ByteString -> IO a -> ExceptT StoreError IO a
    handleErr opName action = ExceptT $ first workItemError <$> E.try action
      where
        workItemError :: E.SomeException -> StoreError
        workItemError e = SEWorkItemError $ itemName <> " " <> opName <> " error: " <> bshow e
|
|
|
|
-- | Stores the slow and fast retry intervals of a pending sent message.
updatePendingMsgRIState :: DB.Connection -> ConnId -> InternalId -> RI2State -> IO ()
updatePendingMsgRIState db connId msgId RI2State {slowInterval = slow, fastInterval = fast} =
  DB.execute db updateIntervals (slow, fast, connId, msgId)
  where
    updateIntervals = "UPDATE snd_messages SET retry_int_slow = ?, retry_int_fast = ? WHERE conn_id = ? AND internal_id = ?"
|
|
|
|
-- | Deletes all message deliveries recorded for the send queue.
deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO ()
deletePendingMsgs db connId SndQueue {dbQueueId = queueId} =
  DB.execute db deleteDeliveries (connId, queueId)
  where
    deleteDeliveries = "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?"
|
|
|
|
-- | Returns IDs of not failed deliveries of the send queue for messages up to the latest
-- sent message of the connection created before @expireTs@, in ascending order.
-- The result is empty when no sent message is older than @expireTs@.
getExpiredSndMessages :: DB.Connection -> ConnId -> SndQueue -> UTCTime -> IO [InternalId]
getExpiredSndMessages db connId SndQueue {dbQueueId = queueId} expireTs =
  lastExpiredMsgId >>= maybe (pure []) expiredDeliveries
  where
    lastExpiredMsgId :: IO (Maybe InternalId)
    lastExpiredMsgId =
      firstMaxId
        <$> DB.query
          db
          [sql|
            SELECT MAX(internal_id)
            FROM messages
            WHERE conn_id = ? AND internal_snd_id IS NOT NULL AND internal_ts < ?
          |]
          (connId, expireTs)
    -- type is Maybe InternalId because MAX always returns one row, possibly with NULL value
    firstMaxId :: [Only (Maybe InternalId)] -> Maybe InternalId
    firstMaxId = \case
      Only msgId_ : _ -> msgId_
      [] -> Nothing
    expiredDeliveries :: InternalId -> IO [InternalId]
    expiredDeliveries maxMsgId =
      map fromOnly
        <$> DB.query
          db
          [sql|
            SELECT internal_id
            FROM snd_message_deliveries
            WHERE conn_id = ? AND snd_queue_id = ? AND failed = 0 AND internal_id <= ?
            ORDER BY internal_id ASC
          |]
          (connId, queueId, maxMsgId)
|
|
|
|
-- | Marks the received message as acknowledged by the user.
-- Returns the receive queue the message was received on and the broker message ID.
-- The message is only updated after both the message and its queue have been found.
setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (RcvQueue, SMP.MsgId))
setMsgUserAck db connId agentMsgId = runExceptT $ do
  (rcvQueueId, brokerMsgId) <- ExceptT findRcvMsg
  rq <- ExceptT $ getRcvQueueById db connId rcvQueueId
  liftIO markAcknowledged
  pure (rq, brokerMsgId)
  where
    findRcvMsg =
      firstRow id SEMsgNotFound $
        DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
    markAcknowledged =
      DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId)
|
|
|
|
-- | Loads a received message of the connection by its internal ID, or 'SEMsgNotFound'.
getRcvMsg :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError RcvMsg)
getRcvMsg db connId agentMsgId = firstRow toRcvMsg SEMsgNotFound rcvMsgRows
  where
    rcvMsgRows =
      DB.query
        db
        [sql|
          SELECT
            r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
            m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
          FROM rcv_messages r
          JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
          LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id
          WHERE r.conn_id = ? AND r.internal_id = ?
        |]
        (connId, agentMsgId)
|
|
|
|
-- | Loads the received message with the given broker ID, but only if it is the message
-- referenced by @last_internal_msg_id@ of the connection.
getLastMsg :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Maybe RcvMsg)
getLastMsg db connId msgId = maybeFirstRow toRcvMsg lastMsgRows
  where
    lastMsgRows =
      DB.query
        db
        [sql|
          SELECT
            r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
            m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
          FROM rcv_messages r
          JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
          JOIN connections c ON r.conn_id = c.conn_id AND c.last_internal_msg_id = r.internal_id
          LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id
          WHERE r.conn_id = ? AND r.broker_id = ?
        |]
        (connId, msgId)
|
|
|
|
-- | Converts a row selected by 'getRcvMsg' / 'getLastMsg' to 'RcvMsg'.
-- The receipt is only present when both the receipt message ID and the receipt status are present.
toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, CR.PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg
toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) =
  RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck}
  where
    msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption}
    msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_
|
|
|
|
-- | Checks whether the hash of an encrypted received message is already stored for the connection.
checkRcvMsgHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool
checkRcvMsgHashExists db connId hash =
  fromMaybe False <$> maybeFirstRow fromOnly hashRows
  where
    hashRows =
      DB.query db "SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" (connId, hash)
|
|
|
|
-- | Deletes the message record of the connection.
deleteMsg :: DB.Connection -> ConnId -> InternalId -> IO ()
deleteMsg db connId msgId = DB.execute db deleteMessage (connId, msgId)
  where
    deleteMessage = "DELETE FROM messages WHERE conn_id = ? AND internal_id = ?;"
|
|
|
|
-- | Keeps the message record but replaces its body with an empty blob.
deleteMsgContent :: DB.Connection -> ConnId -> InternalId -> IO ()
deleteMsgContent db connId msgId = DB.execute db clearBody (connId, msgId)
  where
    clearBody = "UPDATE messages SET msg_body = x'' WHERE conn_id = ? AND internal_id = ?;"
|
|
|
|
-- | Deletes the sent message if it has no pending (not failed) deliveries left.
deleteDeliveredSndMsg :: DB.Connection -> ConnId -> InternalId -> IO ()
deleteDeliveredSndMsg db connId msgId =
  countPendingSndDeliveries_ db connId msgId >>= \pending ->
    when (pending == 0) $ deleteMsg db connId msgId
|
|
|
|
-- | Deletes the delivery of the message via the send queue.
--
-- When the message has no pending deliveries left, the message itself is processed:
-- it is deleted if it already has a receipt with status 'MROk' or if @keepForReceipt@ is False,
-- otherwise only its content is deleted.
deleteSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> Bool -> IO ()
deleteSndMsgDelivery db connId SndQueue {dbQueueId = queueId} msgId keepForReceipt = do
  DB.execute
    db
    "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? AND internal_id = ?"
    (connId, queueId, msgId)
  pending <- countPendingSndDeliveries_ db connId msgId
  when (pending == 0) $ do
    receipt_ <- maybeFirstRow id $ DB.query db "SELECT rcpt_internal_id, rcpt_status FROM snd_messages WHERE conn_id = ? AND internal_id = ?" (connId, msgId)
    let remove = case receipt_ of
          Just (Just (_ :: Int64), Just MROk) -> deleteMsg
          _
            | keepForReceipt -> deleteMsgContent
            | otherwise -> deleteMsg
    remove db connId msgId
|
|
|
|
-- | Counts not failed deliveries of the message across the send queues of the connection.
--
-- The result row is read with a total function: @count(*)@ without GROUP BY produces one row,
-- so the default of 0 is only a safeguard that replaces the previous refutable pattern match
-- (which would fail via 'MonadFail' on an empty result).
countPendingSndDeliveries_ :: DB.Connection -> ConnId -> InternalId -> IO Int
countPendingSndDeliveries_ db connId msgId =
  fromMaybe 0 <$> maybeFirstRow fromOnly pendingCount
  where
    pendingCount =
      DB.query db "SELECT count(*) FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0" (connId, msgId)
|
|
|
|
-- | Deletes stored hashes of received messages created more than @ttl@ before the current time.
deleteRcvMsgHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteRcvMsgHashesExpired db ttl =
  getCurrentTime >>= deleteOlderThan . addUTCTime (-ttl)
  where
    deleteOlderThan cutoffTs =
      DB.execute db "DELETE FROM encrypted_rcv_message_hashes WHERE created_at < ?" (Only cutoffTs)
|
|
|
|
-- | Deletes sent messages (those with a non-NULL @internal_snd_id@) older than @ttl@.
deleteSndMsgsExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteSndMsgsExpired db ttl =
  getCurrentTime >>= deleteOlderThan . addUTCTime (-ttl)
  where
    deleteOlderThan cutoffTs =
      DB.execute db "DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL" (Only cutoffTs)
|
|
|
|
-- | Inserts a ratchet record for the connection with the X3DH private keys and the optional PQ KEM private parameters.
createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO ()
createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
  DB.execute db insertKeys (connId, x3dhPrivKey1, x3dhPrivKey2, pqPrivKem)
  where
    insertKeys = "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem) VALUES (?, ?, ?, ?)"
|
|
|
|
-- | Loads the X3DH private keys of the connection.
-- Returns 'SEX3dhKeysNotFound' when there is no ratchet record or when either private key is NULL.
getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams))
getRatchetX3dhKeys db connId =
  firstRow' keys SEX3dhKeysNotFound keyRows
  where
    keyRows = DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem FROM ratchets WHERE conn_id = ?" (Only connId)
    keys (Just k1, Just k2, pKem) = Right (k1, k2, pKem)
    keys _ = Left SEX3dhKeysNotFound
|
|
|
|
-- used to remember new keys when starting ratchet re-synchronization
-- TODO remove the columns for public keys in v5.7.
-- Currently, the keys are not used but still stored to support app downgrade to the previous version.
setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO ()
setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
  DB.execute db updateKeys (x3dhPrivKey1, x3dhPrivKey2, pubKey1, pubKey2, pqPrivKem, connId)
  where
    -- public keys are derived from the private keys passed in
    pubKey1 = C.publicKey x3dhPrivKey1
    pubKey2 = C.publicKey x3dhPrivKey2
    updateKeys =
      [sql|
        UPDATE ratchets
        SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ?
        WHERE conn_id = ?
      |]
|
|
|
|
-- TODO remove the columns for public keys in v5.7.
-- | Stores the ratchet state of the connection.
-- If the ratchet record already exists, its state is replaced and the X3DH / PQ KEM key columns are set to NULL.
createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO ()
createRatchet db connId rc = DB.executeNamed db upsertRatchet params
  where
    params = [":conn_id" := connId, ":ratchet_state" := rc]
    upsertRatchet =
      [sql|
        INSERT INTO ratchets (conn_id, ratchet_state)
        VALUES (:conn_id, :ratchet_state)
        ON CONFLICT (conn_id) DO UPDATE SET
          ratchet_state = :ratchet_state,
          x3dh_priv_key_1 = NULL,
          x3dh_priv_key_2 = NULL,
          x3dh_pub_key_1 = NULL,
          x3dh_pub_key_2 = NULL,
          pq_priv_kem = NULL
      |]
|
|
|
|
-- | Deletes the ratchet record of the connection.
deleteRatchet :: DB.Connection -> ConnId -> IO ()
deleteRatchet db connId = DB.execute db deleteRatchetRow (Only connId)
  where
    deleteRatchetRow = "DELETE FROM ratchets WHERE conn_id = ?"
|
|
|
|
-- | Loads the ratchet state of the connection.
-- Returns 'SERatchetNotFound' when there is no ratchet record or its state is NULL.
getRatchet :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448)
getRatchet db connId =
  firstRow' ratchet SERatchetNotFound ratchetRows
  where
    ratchetRows = DB.query db "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" (Only connId)
    ratchet = \case
      Only (Just rc) -> Right rc
      Only Nothing -> Left SERatchetNotFound
|
|
|
|
-- | Loads skipped message keys of the connection as a map from header key to a map from message number to message key.
getSkippedMsgKeys :: DB.Connection -> ConnId -> IO SkippedMsgKeys
getSkippedMsgKeys db connId =
  foldl' addSkippedKey M.empty
    <$> DB.query db "SELECT header_key, msg_n, msg_key FROM skipped_messages WHERE conn_id = ?" (Only connId)
  where
    -- the union is left-biased, so the key from the current row replaces an earlier key with the same number
    addSkippedKey smks (hk, msgN, mk) = M.insertWith M.union hk (M.singleton msgN mk) smks
|
|
|
|
-- | Stores the new ratchet state and then applies the change to the skipped message keys:
-- no change, removal of one key, or insertion of all keys from the passed map.
updateRatchet :: DB.Connection -> ConnId -> RatchetX448 -> SkippedMsgDiff -> IO ()
updateRatchet db connId rc skipped =
  DB.execute db "UPDATE ratchets SET ratchet_state = ? WHERE conn_id = ?" (rc, connId)
    >> applySkipped skipped
  where
    applySkipped = \case
      SMDNoChange -> pure ()
      SMDRemove hk msgN ->
        DB.execute db "DELETE FROM skipped_messages WHERE conn_id = ? AND header_key = ? AND msg_n = ?" (connId, hk, msgN)
      SMDAdd smks ->
        mapM_ insertKey [(hk, msgN, mk) | (hk, mks) <- M.assocs smks, (msgN, mk) <- M.assocs mks]
    insertKey (hk, msgN, mk) =
      DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
|
|
|
|
-- | Stores an agent command for asynchronous processing.
-- When a server is passed, its host, port and the key hash returned by 'getServerKeyHash_' are stored with the command;
-- an error from 'getServerKeyHash_' is returned without inserting the command.
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO (Either StoreError ())
createCommand db corrId connId srv_ cmd = runExceptT $ do
  (host_, port_, serverKeyHash_) <- serverFields
  createdAt <- liftIO getCurrentTime
  liftIO $ DB.execute db insertCommand (host_, port_, corrId, connId, agentCommandTag cmd, cmd, serverKeyHash_, createdAt)
  where
    insertCommand = "INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command, server_key_hash, created_at) VALUES (?,?,?,?,?,?,?,?)"
    serverFields :: ExceptT StoreError IO (Maybe (NonEmpty TransportHost), Maybe ServiceName, Maybe C.KeyHash)
    serverFields = case srv_ of
      Nothing -> pure (Nothing, Nothing, Nothing)
      Just srv@(SMPServer host port _) -> do
        keyHash_ <- ExceptT $ getServerKeyHash_ db srv
        pure (Just host, Just port, keyHash_)
|
|
|
|
-- | Returns the result of SQLite @last_insert_rowid()@ for this database connection.
--
-- The row is extracted with an explicit pattern match instead of the partial 'head':
-- should the query ever return no rows, an 'IOError' is raised from this IO action,
-- rather than an exception hidden in the lazily evaluated result.
insertedRowId :: DB.Connection -> IO Int64
insertedRowId db =
  DB.query_ db "SELECT last_insert_rowid()" >>= \case
    Only rowId : _ -> pure rowId
    [] -> ioError $ userError "insertedRowId: last_insert_rowid() returned no rows"
|
|
|
|
-- | Returns the distinct servers of the commands stored for the connection.
-- The element is 'Nothing' when host, port or key hash is NULL in the selected row.
getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer]
getPendingCommandServers db connId =
  -- TODO review whether this can break if, e.g., the server has another key hash.
  map smpServer <$> serverRows
  where
    serverRows =
      DB.query
        db
        [sql|
          SELECT DISTINCT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash)
          FROM commands c
          LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
          WHERE conn_id = ?
        |]
        (Only connId)
    smpServer (host_, port_, keyHash_) = do
      host <- host_
      port <- port_
      keyHash <- keyHash_
      pure $ SMPServer host port keyHash
|
|
|
|
-- | Gets the next pending command for the server (or the next command without server), using 'getWorkItem'.
--
-- The next command is the not failed command with the earliest @created_at@, then the smallest @command_id@.
-- The callback passed to 'getWorkItem' for failures sets the failed flag of the command.
getPendingServerCommand :: DB.Connection -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand))
getPendingServerCommand db srv_ = getWorkItem "command" nextCmdId loadCommand markCommandFailed
  where
    nextCmdId :: IO (Maybe Int64)
    nextCmdId = maybeFirstRow fromOnly $ cmdIdRows srv_
    cmdIdRows :: Maybe SMPServer -> IO [Only Int64]
    cmdIdRows Nothing =
      DB.query_
        db
        [sql|
          SELECT command_id FROM commands
          WHERE host IS NULL AND port IS NULL AND failed = 0
          ORDER BY created_at ASC, command_id ASC
          LIMIT 1
        |]
    cmdIdRows (Just (SMPServer host port _)) =
      DB.query
        db
        [sql|
          SELECT command_id FROM commands
          WHERE host = ? AND port = ? AND failed = 0
          ORDER BY created_at ASC, command_id ASC
          LIMIT 1
        |]
        (host, port)
    loadCommand :: Int64 -> IO (Either StoreError PendingCommand)
    loadCommand cmdId = firstRow pendingCommand notFound commandRows
      where
        commandRows =
          DB.query
            db
            [sql|
              SELECT c.corr_id, cs.user_id, c.conn_id, c.command
              FROM commands c
              JOIN connections cs USING (conn_id)
              WHERE c.command_id = ?
            |]
            (Only cmdId)
        notFound = SEInternal $ "command " <> bshow cmdId <> " returned []"
        pendingCommand (corrId, userId, connId, command) =
          PendingCommand {cmdId, corrId, userId, connId, command}
    markCommandFailed cmdId =
      DB.execute db "UPDATE commands SET failed = 1 WHERE command_id = ?" (Only cmdId)
|
|
|
|
-- | Deletes the stored command by its ID.
deleteCommand :: DB.Connection -> AsyncCmdId -> IO ()
deleteCommand db cmdId = DB.execute db deleteCommandRow (Only cmdId)
  where
    deleteCommandRow = "DELETE FROM commands WHERE command_id = ?"
|
|
|
|
-- | Stores the notification token; the token's server is passed to 'upsertNtfServer_' first.
createNtfToken :: DB.Connection -> NtfToken -> IO ()
createNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} =
  upsertNtfServer_ db srv *> DB.execute db insertToken (tokenFields :. statusFields)
  where
    insertToken =
      [sql|
        INSERT INTO ntf_tokens
        (provider, device_token, ntf_host, ntf_port, tkn_id, tkn_pub_key, tkn_priv_key, tkn_pub_dh_key, tkn_priv_dh_key, tkn_dh_secret, tkn_status, tkn_action, ntf_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
      |]
    tokenFields = (provider, token, host, port, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret)
    statusFields = (ntfTknStatus, ntfTknAction, ntfMode)
|
|
|
|
-- | Loads the first stored notification token together with its server, if there is any.
-- A NULL notifications mode is read as 'NMPeriodic'.
getSavedNtfToken :: DB.Connection -> IO (Maybe NtfToken)
getSavedNtfToken db = maybeFirstRow ntfToken tokenRows
  where
    tokenRows =
      DB.query_
        db
        [sql|
          SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash,
            t.provider, t.device_token, t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret,
            t.tkn_status, t.tkn_action, t.ntf_mode
          FROM ntf_tokens t
          JOIN ntf_servers s USING (ntf_host, ntf_port)
        |]
    ntfToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) =
      NtfToken
        { deviceToken = DeviceToken provider dt,
          ntfServer = NtfServer host port keyHash,
          ntfTokenId,
          ntfPubKey,
          ntfPrivKey,
          ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey),
          ntfDhSecret,
          ntfTknStatus,
          ntfTknAction,
          ntfMode = fromMaybe NMPeriodic ntfMode_
        }
|
|
|
|
-- | Stores the token ID and DH secret, sets the status to 'NTRegistered' and clears the token action.
-- The token is identified by provider, device token and notification server host and port.
updateNtfTokenRegistration :: DB.Connection -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> IO ()
updateNtfTokenRegistration db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret =
  getCurrentTime >>= \updatedAt ->
    DB.execute
      db
      [sql|
        UPDATE ntf_tokens
        SET tkn_id = ?, tkn_dh_secret = ?, tkn_status = ?, tkn_action = ?, updated_at = ?
        WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
      |]
      (tknId, ntfDhSecret, NTRegistered, noAction, updatedAt, provider, token, host, port)
  where
    noAction :: Maybe NtfTknAction
    noAction = Nothing
|
|
|
|
-- | Replaces the provider and device token of the stored token,
-- sets the status to 'NTRegistered' and clears the token action.
updateDeviceToken :: DB.Connection -> NtfToken -> DeviceToken -> IO ()
updateDeviceToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} (DeviceToken toProvider toToken) =
  getCurrentTime >>= \updatedAt ->
    DB.execute
      db
      [sql|
        UPDATE ntf_tokens
        SET provider = ?, device_token = ?, tkn_status = ?, tkn_action = ?, updated_at = ?
        WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
      |]
      (toProvider, toToken, NTRegistered, noAction, updatedAt, provider, token, host, port)
  where
    noAction :: Maybe NtfTknAction
    noAction = Nothing
|
|
|
|
-- | Stores the notifications mode of the token.
updateNtfMode :: DB.Connection -> NtfToken -> NotificationsMode -> IO ()
updateNtfMode db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} ntfMode =
  getCurrentTime >>= \updatedAt ->
    DB.execute
      db
      [sql|
        UPDATE ntf_tokens
        SET ntf_mode = ?, updated_at = ?
        WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
      |]
      (ntfMode, updatedAt, provider, token, host, port)
|
|
|
|
-- | Stores the status and the pending action of the token.
updateNtfToken :: DB.Connection -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> IO ()
updateNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknStatus tknAction =
  getCurrentTime >>= \updatedAt ->
    DB.execute
      db
      [sql|
        UPDATE ntf_tokens
        SET tkn_status = ?, tkn_action = ?, updated_at = ?
        WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
      |]
      (tknStatus, tknAction, updatedAt, provider, token, host, port)
|
|
|
|
-- | Deletes the token identified by provider, device token and notification server host and port.
removeNtfToken :: DB.Connection -> NtfToken -> IO ()
removeNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} =
  DB.execute db deleteToken (provider, token, host, port)
  where
    deleteToken =
      [sql|
        DELETE FROM ntf_tokens
        WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
      |]
|
|
|
|
-- | Loads the notification subscription of the connection with its pending action, if any.
-- The action is only returned when the action timestamp is set and exactly one of
-- the notification server action and the SMP server action is set.
getNtfSubscription :: DB.Connection -> ConnId -> IO (Maybe (NtfSubscription, Maybe (NtfSubAction, NtfActionTs)))
getNtfSubscription db connId = maybeFirstRow ntfSubscription subscriptionRows
  where
    subscriptionRows =
      DB.query
        db
        [sql|
          SELECT s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash,
            nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts
          FROM ntf_subscriptions nsb
          JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port
          JOIN ntf_servers ns USING (ntf_host, ntf_port)
          WHERE nsb.conn_id = ?
        |]
        (Only connId)
    ntfSubscription (smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_) =
      (NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action)
      where
        smpServer = SMPServer smpHost smpPort smpKeyHash
        ntfServer = NtfServer ntfHost ntfPort ntfKeyHash
        action = do
          actionTs <- actionTs_
          subAction <- case (ntfAction_, smpAction_) of
            (Just ntfAction, Nothing) -> Just $ NtfSubNTFAction ntfAction
            (Nothing, Just smpAction) -> Just $ NtfSubSMPAction smpAction
            _ -> Nothing
          pure (subAction, actionTs)
|
|
|
|
-- | Inserts the notification subscription with its pending action and the current time as the action timestamp.
-- The SMP server key hash stored with the subscription is the one returned by 'getServerKeyHash_',
-- whose error is returned without inserting the subscription.
createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ())
createNtfSubscription db ntfSubscription action = runExceptT $ do
  smpServerKeyHash_ <- ExceptT $ getServerKeyHash_ db smpServer
  actionTs <- liftIO getCurrentTime
  liftIO $ DB.execute db insertSubscription (subFields :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, smpServerKeyHash_))
  where
    NtfSubscription {connId, smpServer = smpServer@(SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription
    (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action
    subFields = (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId)
    insertSubscription =
      [sql|
        INSERT INTO ntf_subscriptions
          (conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id,
            ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts, smp_server_key_hash)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
      |]
|
|
|
|
-- | Supervisor-side update of a connection's notification subscription.
--
-- Overwrites the notifier ID, notification server, subscription ID, status and
-- pending action, stamps both the action timestamp and @updated_at@ with the
-- current time, and sets @updated_by_supervisor@.
supervisorUpdateNtfSub :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO ()
supervisorUpdateNtfSub db NtfSubscription {connId, ntfQueueId, ntfServer = NtfServer ntfHost ntfPort _, ntfSubId, ntfSubStatus} action =
  getCurrentTime >>= \now ->
    DB.execute db updateQuery $
      (ntfQueueId, ntfHost, ntfPort, ntfSubId)
        :. (ntfSubStatus, ntfAction, smpAction, now, True, now, connId)
  where
    (ntfAction, smpAction) = ntfSubAndSMPAction action
    updateQuery =
      [sql|
        UPDATE ntf_subscriptions
        SET smp_ntf_id = ?, ntf_host = ?, ntf_port = ?, ntf_sub_id = ?, ntf_sub_status = ?, ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ?
        WHERE conn_id = ?
      |]
|
|
|
|
-- | Supervisor-side update of only the pending action of a connection's
-- notification subscription.
--
-- The action timestamp and @updated_at@ are both set to the current time and
-- @updated_by_supervisor@ is set.
supervisorUpdateNtfAction :: DB.Connection -> ConnId -> NtfSubAction -> IO ()
supervisorUpdateNtfAction db connId action =
  getCurrentTime >>= \now ->
    DB.execute db updateQuery (ntfAction, smpAction, now, True, now, connId)
  where
    (ntfAction, smpAction) = ntfSubAndSMPAction action
    updateQuery =
      [sql|
        UPDATE ntf_subscriptions
        SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ?
        WHERE conn_id = ?
      |]
|
|
|
|
-- | Updates the notification subscription of a connection; does nothing when
-- the connection has no subscription row.
--
-- If the row has @updated_by_supervisor@ set, only the notifier ID, the
-- subscription ID and the status are written, so the stored server, action and
-- action timestamp are kept. Otherwise all of these fields are written.
-- In both cases @updated_by_supervisor@ is reset and @updated_at@ is set to
-- the current time.
updateNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> NtfActionTs -> IO ()
updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = NtfServer ntfHost ntfPort _, ntfSubId, ntfSubStatus} action actionTs =
  supervisorFlag >>= mapM_ applyUpdate
  where
    supervisorFlag :: IO (Maybe Bool)
    supervisorFlag =
      maybeFirstRow fromOnly $
        DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId)
    applyUpdate :: Bool -> IO ()
    applyUpdate updatedBySupervisor = do
      updatedAt <- getCurrentTime
      if updatedBySupervisor
        then DB.execute db statusOnlyUpdate (ntfQueueId, ntfSubId, ntfSubStatus, False, updatedAt, connId)
        else
          DB.execute db completeUpdate $
            (ntfQueueId, ntfHost, ntfPort, ntfSubId)
              :. (ntfSubStatus, ntfAction, smpAction, actionTs, False, updatedAt, connId)
    (ntfAction, smpAction) = ntfSubAndSMPAction action
    statusOnlyUpdate =
      [sql|
        UPDATE ntf_subscriptions
        SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ?
        WHERE conn_id = ?
      |]
    completeUpdate =
      [sql|
        UPDATE ntf_subscriptions
        SET smp_ntf_id = ?, ntf_host = ?, ntf_port = ?, ntf_sub_id = ?, ntf_sub_status = ?, ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ?
        WHERE conn_id = ?
      |]
|
|
|
|
-- | Clears the pending action (both action columns and the action timestamp)
-- of a connection's notification subscription.
--
-- Nothing is changed when the subscription row does not exist or when it has
-- @updated_by_supervisor@ set.
setNullNtfSubscriptionAction :: DB.Connection -> ConnId -> IO ()
setNullNtfSubscriptionAction db connId =
  supervisorFlag >>= \case
    Just False -> clearAction
    _ -> pure ()
  where
    supervisorFlag :: IO (Maybe Bool)
    supervisorFlag =
      maybeFirstRow fromOnly $
        DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId)
    clearAction :: IO ()
    clearAction = do
      updatedAt <- getCurrentTime
      DB.execute
        db
        [sql|
          UPDATE ntf_subscriptions
          SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ?
          WHERE conn_id = ?
        |]
        (Nothing :: Maybe NtfSubNTFAction, Nothing :: Maybe NtfSubSMPAction, Nothing :: Maybe UTCTime, False, updatedAt, connId)
|
|
|
|
-- | Deletes the notification subscription of a connection; does nothing when
-- there is no subscription row.
--
-- If the row has @updated_by_supervisor@ set it is kept: its notifier ID and
-- subscription ID are set to NULL, the status becomes 'NASDeleted' and the
-- flag is reset. Otherwise the row is removed.
deleteNtfSubscription :: DB.Connection -> ConnId -> IO ()
deleteNtfSubscription db connId =
  supervisorFlag >>= \case
    Nothing -> pure ()
    Just True -> markDeleted
    Just False -> DB.execute db "DELETE FROM ntf_subscriptions WHERE conn_id = ?" (Only connId)
  where
    supervisorFlag :: IO (Maybe Bool)
    supervisorFlag =
      maybeFirstRow fromOnly $
        DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId)
    markDeleted :: IO ()
    markDeleted = do
      updatedAt <- getCurrentTime
      DB.execute
        db
        [sql|
          UPDATE ntf_subscriptions
          SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ?
          WHERE conn_id = ?
        |]
        (Nothing :: Maybe SMP.NotifierId, Nothing :: Maybe NtfSubscriptionId, NASDeleted, False, updatedAt, connId)
|
|
|
|
-- | Picks the next pending NTF action for the given notification server.
--
-- The candidate is the subscription on this server with a non-NULL
-- @ntf_sub_action@ and the smallest action timestamp, among rows that are
-- either not marked @ntf_failed@ or have @updated_by_supervisor@ set.
-- Selection, loading and failure marking are delegated to 'getWorkItem'.
getNextNtfSubNTFAction :: DB.Connection -> NtfServer -> IO (Either StoreError (Maybe (NtfSubscription, NtfSubNTFAction, NtfActionTs)))
getNextNtfSubNTFAction db ntfServer@(NtfServer ntfHost ntfPort _) =
  getWorkItem "ntf NTF" nextConnId loadSubAction (markNtfSubActionNtfFailed_ db)
  where
    nextConnId :: IO (Maybe ConnId)
    nextConnId = maybeFirstRow fromOnly (DB.query db nextConnQuery (ntfHost, ntfPort))
    nextConnQuery =
      [sql|
        SELECT conn_id
        FROM ntf_subscriptions
        WHERE ntf_host = ? AND ntf_port = ? AND ntf_sub_action IS NOT NULL
          AND (ntf_failed = 0 OR updated_by_supervisor = 1)
        ORDER BY ntf_sub_action_ts ASC
        LIMIT 1
      |]
    loadSubAction :: ConnId -> IO (Either StoreError (NtfSubscription, NtfSubNTFAction, NtfActionTs))
    loadSubAction connId = do
      -- the supervisor flag is reset before the subscription is read
      markUpdatedByWorker db connId
      firstRow toSubAction notFound (DB.query db subActionQuery (Only connId))
      where
        notFound = SEInternal $ "ntf subscription " <> bshow connId <> " returned []"
        toSubAction (smpHost, smpPort, smpKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) =
          ( NtfSubscription {connId, smpServer = SMPServer smpHost smpPort smpKeyHash, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus},
            action,
            actionTs
          )
    subActionQuery =
      [sql|
        SELECT s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash),
          ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action
        FROM ntf_subscriptions ns
        JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port
        WHERE ns.conn_id = ?
      |]
|
|
|
|
-- | Sets the @ntf_failed@ flag on the notification subscription of the connection.
markNtfSubActionNtfFailed_ :: DB.Connection -> ConnId -> IO ()
markNtfSubActionNtfFailed_ db =
  DB.execute db "UPDATE ntf_subscriptions SET ntf_failed = 1 where conn_id = ?" . Only
|
|
|
|
-- | Picks the next pending SMP action for the given SMP server.
--
-- The candidate is the subscription on this server with non-NULL
-- @ntf_sub_smp_action@ and action timestamp and the smallest timestamp, among
-- rows that are either not marked @smp_failed@ or have @updated_by_supervisor@
-- set. Selection, loading and failure marking are delegated to 'getWorkItem'.
getNextNtfSubSMPAction :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe (NtfSubscription, NtfSubSMPAction, NtfActionTs)))
getNextNtfSubSMPAction db smpServer@(SMPServer smpHost smpPort _) =
  getWorkItem "ntf SMP" nextConnId loadSubAction (markNtfSubActionSMPFailed_ db)
  where
    nextConnId :: IO (Maybe ConnId)
    nextConnId = maybeFirstRow fromOnly (DB.query db nextConnQuery (smpHost, smpPort))
    nextConnQuery =
      [sql|
        SELECT conn_id
        FROM ntf_subscriptions ns
        WHERE smp_host = ? AND smp_port = ? AND ntf_sub_smp_action IS NOT NULL AND ntf_sub_action_ts IS NOT NULL
          AND (smp_failed = 0 OR updated_by_supervisor = 1)
        ORDER BY ntf_sub_action_ts ASC
        LIMIT 1
      |]
    loadSubAction :: ConnId -> IO (Either StoreError (NtfSubscription, NtfSubSMPAction, NtfActionTs))
    loadSubAction connId = do
      -- the supervisor flag is reset before the subscription is read
      markUpdatedByWorker db connId
      firstRow toSubAction notFound (DB.query db subActionQuery (Only connId))
      where
        notFound = SEInternal $ "ntf subscription " <> bshow connId <> " returned []"
        toSubAction (ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) =
          ( NtfSubscription {connId, smpServer, ntfQueueId, ntfServer = NtfServer ntfHost ntfPort ntfKeyHash, ntfSubId, ntfSubStatus},
            action,
            actionTs
          )
    subActionQuery =
      [sql|
        SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash,
          ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_smp_action
        FROM ntf_subscriptions ns
        JOIN ntf_servers s USING (ntf_host, ntf_port)
        WHERE ns.conn_id = ?
      |]
|
|
|
|
-- | Sets the @smp_failed@ flag on the notification subscription of the connection.
markNtfSubActionSMPFailed_ :: DB.Connection -> ConnId -> IO ()
markNtfSubActionSMPFailed_ db =
  DB.execute db "UPDATE ntf_subscriptions SET smp_failed = 1 where conn_id = ?" . Only
|
|
|
|
-- | Resets the @updated_by_supervisor@ flag on the notification subscription of the connection.
markUpdatedByWorker :: DB.Connection -> ConnId -> IO ()
markUpdatedByWorker db =
  DB.execute db "UPDATE ntf_subscriptions SET updated_by_supervisor = 0 WHERE conn_id = ?" . Only
|
|
|
|
-- | Returns the first stored notification token whose status is 'NTActive',
-- joined with its notification server, or 'Nothing' if there is none.
--
-- A NULL @ntf_mode@ column is read as 'NMPeriodic'.
getActiveNtfToken :: DB.Connection -> IO (Maybe NtfToken)
getActiveNtfToken db =
  maybeFirstRow toToken (DB.query db activeTokenQuery (Only NTActive))
  where
    activeTokenQuery =
      [sql|
        SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash,
          t.provider, t.device_token, t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret,
          t.tkn_status, t.tkn_action, t.ntf_mode
        FROM ntf_tokens t
        JOIN ntf_servers s USING (ntf_host, ntf_port)
        WHERE t.tkn_status = ?
      |]
    toToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) =
      NtfToken
        { deviceToken = DeviceToken provider dt,
          ntfServer = NtfServer host port keyHash,
          ntfTokenId,
          ntfPubKey,
          ntfPrivKey,
          ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey),
          ntfDhSecret,
          ntfTknStatus,
          ntfTknAction,
          ntfMode = fromMaybe NMPeriodic ntfMode_
        }
|
|
|
|
-- | Finds the non-deleted receive queue with the given SMP server host/port and
-- notifier ID, returning its connection ID and notification DH secret.
--
-- 'SEConnNotFound' is returned both when no queue matches and when the
-- matching queue has no stored DH secret.
getNtfRcvQueue :: DB.Connection -> SMPQueueNtf -> IO (Either StoreError (ConnId, RcvNtfDhSecret))
getNtfRcvQueue db SMPQueueNtf {smpServer = SMPServer host port _, notifierId} =
  firstRow' requireSecret SEConnNotFound (DB.query db queueQuery (host, port, notifierId))
  where
    queueQuery =
      [sql|
        SELECT conn_id, rcv_ntf_dh_secret
        FROM rcv_queues
        WHERE host = ? AND port = ? AND ntf_id = ? AND deleted = 0
      |]
    requireSecret (connId, secret_) = maybe (Left SEConnNotFound) (Right . (connId,)) secret_
|
|
|
|
-- | Stores the @enable_ntfs@ flag of the connection.
setConnectionNtfs :: DB.Connection -> ConnId -> Bool -> IO ()
setConnectionNtfs db connId enableNtfs = DB.execute db updateQuery (enableNtfs, connId)
  where
    updateQuery = "UPDATE connections SET enable_ntfs = ? WHERE conn_id = ?"
|
|
|
|
-- * Auxiliary helpers
|
|
|
|
-- Database field encodings for agent types. Each type is stored through one of
-- the encoders visible in its instance: a dedicated serializer, its string or
-- SMP encoding, or the field instance of the value it wraps.

instance ToField QueueStatus where
  toField status = toField (serializeQueueStatus status)

instance FromField QueueStatus where
  fromField = fromTextField_ queueStatusT

instance ToField (DBQueueId 'QSStored) where
  toField (DBQueueId queueId) = toField queueId

instance FromField (DBQueueId 'QSStored) where
  fromField = fmap DBQueueId . fromField

instance ToField InternalRcvId where
  toField (InternalRcvId rcvId) = toField rcvId

instance FromField InternalRcvId where
  fromField = fmap InternalRcvId . fromField

instance ToField InternalSndId where
  toField (InternalSndId sndId) = toField sndId

instance FromField InternalSndId where
  fromField = fmap InternalSndId . fromField

instance ToField InternalId where
  toField (InternalId msgId) = toField msgId

instance FromField InternalId where
  fromField = fmap InternalId . fromField

instance ToField AgentMessageType where
  toField msgType = toField (smpEncode msgType)

instance FromField AgentMessageType where
  fromField = blobFieldParser smpP

instance ToField MsgIntegrity where
  toField integrity = toField (strEncode integrity)

instance FromField MsgIntegrity where
  fromField = blobFieldParser strP

instance ToField SMPQueueUri where
  toField uri = toField (strEncode uri)

instance FromField SMPQueueUri where
  fromField = blobFieldParser strP

instance ToField AConnectionRequestUri where
  toField uri = toField (strEncode uri)

instance FromField AConnectionRequestUri where
  fromField = blobFieldParser strP

instance ConnectionModeI c => ToField (ConnectionRequestUri c) where
  toField uri = toField (strEncode uri)

instance (E.Typeable c, ConnectionModeI c) => FromField (ConnectionRequestUri c) where
  fromField = blobFieldParser strP

instance ToField ConnectionMode where
  toField mode = toField (decodeLatin1 (strEncode mode))

instance FromField ConnectionMode where
  fromField = fromTextField_ connModeT

instance ToField (SConnectionMode c) where
  toField sMode = toField (connMode sMode)

instance FromField AConnectionMode where
  fromField = fromTextField_ (fmap connMode' . connModeT)

instance ToField MsgFlags where
  toField flags = toField (decodeLatin1 (smpEncode flags))

instance FromField MsgFlags where
  fromField = fromTextField_ (eitherToMaybe . smpDecode . encodeUtf8)

instance ToField [SMPQueueInfo] where
  toField queues = toField (smpEncodeList queues)

instance FromField [SMPQueueInfo] where
  fromField = blobFieldParser smpListP

instance ToField (NonEmpty TransportHost) where
  toField hosts = toField (decodeLatin1 (strEncode hosts))

instance FromField (NonEmpty TransportHost) where
  fromField = fromTextField_ (eitherToMaybe . strDecode . encodeUtf8)

instance ToField AgentCommand where
  toField cmd = toField (strEncode cmd)

instance FromField AgentCommand where
  fromField = blobFieldParser strP

instance ToField AgentCommandTag where
  toField tag = toField (strEncode tag)

instance FromField AgentCommandTag where
  fromField = blobFieldParser strP

instance ToField MsgReceiptStatus where
  toField receiptStatus = toField (decodeLatin1 (strEncode receiptStatus))

instance FromField MsgReceiptStatus where
  fromField = fromTextField_ (eitherToMaybe . strDecode . encodeUtf8)

instance ToField (Version v) where
  toField (Version version) = toField version

instance FromField (Version v) where
  fromField = fmap Version . fromField

instance ToField CR.PQEncryption where
  toField (CR.PQEncryption enabled) = toField enabled

instance FromField CR.PQEncryption where
  fromField = fmap CR.PQEncryption . fromField
|
|
|
|
-- | Returns the first element of the list as 'Right',
-- or the supplied error as 'Left' when the list is empty.
listToEither :: e -> [a] -> Either e a
listToEither e xs = case xs of
  x : _ -> Right x
  [] -> Left e
|
|
|
|
-- | Runs the query action and converts its first row with the given function;
-- returns the supplied error when the action produces no rows.
firstRow :: (a -> b) -> e -> IO [a] -> IO (Either e b)
firstRow f e rows = fmap f . listToEither e <$> rows
|
|
|
|
-- | Converts the first row of the result inside the functor with the given
-- function; gives 'Nothing' when there are no rows.
maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b)
maybeFirstRow f = fmap (fmap f . listToMaybe)
|
|
|
|
-- | Like 'firstRow', but the row conversion itself may fail:
-- returns the supplied error when there are no rows, otherwise the result
-- of the conversion applied to the first row.
firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b)
firstRow' f e rows = (\rs -> listToEither e rs >>= f) <$> rows
|
|
|
|
{- ORMOLU_DISABLE -}
|
|
-- SQLite.Simple only has these up to 10 fields, which is insufficient for some of our queries
|
|
-- Row instances for 11- and 12-element tuples: fields are read and written
-- in tuple order, one column per element.

instance (FromField a, FromField b, FromField c, FromField d, FromField e, FromField f,
          FromField g, FromField h, FromField i, FromField j, FromField k) =>
  FromRow (a,b,c,d,e,f,g,h,i,j,k) where
  fromRow = do
    a <- field; b <- field; c <- field; d <- field
    e <- field; f <- field; g <- field; h <- field
    i <- field; j <- field; k <- field
    pure (a,b,c,d,e,f,g,h,i,j,k)

instance (FromField a, FromField b, FromField c, FromField d, FromField e, FromField f,
          FromField g, FromField h, FromField i, FromField j, FromField k, FromField l) =>
  FromRow (a,b,c,d,e,f,g,h,i,j,k,l) where
  fromRow = do
    a <- field; b <- field; c <- field; d <- field
    e <- field; f <- field; g <- field; h <- field
    i <- field; j <- field; k <- field; l <- field
    pure (a,b,c,d,e,f,g,h,i,j,k,l)

instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
          ToField g, ToField h, ToField i, ToField j, ToField k, ToField l) =>
  ToRow (a,b,c,d,e,f,g,h,i,j,k,l) where
  -- built from the two 6-tuple halves, which keeps the fields in tuple order
  toRow (a,b,c,d,e,f,g,h,i,j,k,l) =
    toRow (a,b,c,d,e,f) <> toRow (g,h,i,j,k,l)
|
|
|
|
{- ORMOLU_ENABLE -}
|
|
|
|
-- * Server helper
|
|
|
|
-- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored.
--
-- When the lookup via 'getServerKeyHash_' fails, the server is inserted with
-- the passed key hash and 'Nothing' is returned.
createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash)
createServer_ db newSrv@ProtocolServer {host, port, keyHash} =
  getServerKeyHash_ db newSrv >>= either (const insertNewServer) pure
  where
    insertNewServer =
      Nothing <$ DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash)
|
|
|
|
-- | Returns the passed server key hash if it is different from the stored one, or the error if the server does not exist.
--
-- The server is looked up by host and port; 'Right' 'Nothing' means the
-- stored key hash equals the passed one.
getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash))
getServerKeyHash_ db ProtocolServer {host, port, keyHash} =
  firstRow differingKeyHash SEServerNotFound storedKeyHash
  where
    storedKeyHash = DB.query db "SELECT key_hash FROM servers WHERE host = ? AND port = ?" (host, port)
    differingKeyHash (Only storedHash)
      | keyHash == storedHash = Nothing
      | otherwise = Just keyHash
|
|
|
|
-- | Inserts the notification server, or overwrites the stored row with the
-- passed values when a server with the same host and port already exists.
upsertNtfServer_ :: DB.Connection -> NtfServer -> IO ()
upsertNtfServer_ db ProtocolServer {host, port, keyHash} =
  DB.executeNamed db upsertQuery params
  where
    params = [":host" := host, ":port" := port, ":key_hash" := keyHash]
    upsertQuery =
      [sql|
        INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (:host,:port,:key_hash)
        ON CONFLICT (ntf_host, ntf_port) DO UPDATE SET
          ntf_host=excluded.ntf_host,
          ntf_port=excluded.ntf_port,
          ntf_key_hash=excluded.ntf_key_hash;
      |]
|
|
|
|
-- * createRcvConn helpers
|
|
|
|
-- | Stores a new receive queue for the connection and returns it as a stored
-- queue, with the connection ID and the newly allocated queue ID set.
--
-- The queue ID is one more than the largest @rcv_queue_id@ already stored for
-- the connection (see 'newQueueId_').
insertRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> Maybe C.KeyHash -> IO RcvQueue
insertRcvQueue_ db connId' rq@RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, primary, dbReplaceQueueId, smpClientVersion} serverKeyHash_ = do
  lastQueueIds <- DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')
  let qId = newQueueId_ lastQueueIds
  DB.execute db insertQuery $
    (host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret)
      :. (sndId, status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)
  pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId}
  where
    insertQuery =
      [sql|
        INSERT INTO rcv_queues
          (host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
      |]
|
|
|
|
-- * createSndConn helpers
|
|
|
|
-- | Stores a new send queue for the connection (replacing a conflicting row,
-- as the statement is @INSERT OR REPLACE@) and returns it as a stored queue,
-- with the connection ID and the newly allocated queue ID set.
--
-- The queue ID is one more than the largest @snd_queue_id@ already stored for
-- the connection (see 'newQueueId_').
insertSndQueue_ :: DB.Connection -> ConnId -> NewSndQueue -> Maybe C.KeyHash -> IO SndQueue
insertSndQueue_ db connId' sq@SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, primary, dbReplaceQueueId, smpClientVersion} serverKeyHash_ = do
  lastQueueIds <- DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')
  let qId = newQueueId_ lastQueueIds
  DB.execute db insertQuery $
    (host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret)
      :. (status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)
  pure (sq :: NewSndQueue) {connId = connId', dbQueueId = qId}
  where
    insertQuery =
      [sql|
        INSERT OR REPLACE INTO snd_queues
          (host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
      |]
|
|
|
|
-- | Computes the next queue ID from the rows of a "largest existing ID" query:
-- 1 when there are no rows, otherwise the ID in the first row plus one.
newQueueId_ :: [Only Int64] -> DBQueueId 'QSStored
newQueueId_ rows = DBQueueId $ case rows of
  Only maxId : _ -> maxId + 1
  [] -> 1
|
|
|
|
-- * getConn helpers
|
|
|
|
-- | Loads a connection that is not marked as deleted (see 'getAnyConn').
getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConn db connId = getAnyConn False db connId
|
|
|
|
-- | Loads a connection that is marked as deleted (see 'getAnyConn').
getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getDeletedConn db connId = getAnyConn True db connId
|
|
|
|
-- | Loads a connection whose @deleted@ flag equals the first argument,
-- together with its receive and send queues.
--
-- The connection type is derived from which queues exist and from the
-- connection mode. 'SEConnNotFound' is returned when there is no connection
-- record, when its @deleted@ flag differs from the requested one, or when the
-- combination of queues and mode matches none of the known connection types.
getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getAnyConn deleted' dbConn connId = do
  connData_ <- getConnData dbConn connId
  case connData_ of
    Just (cData@ConnData {deleted}, cMode)
      | deleted == deleted' ->
          toSomeConn cData cMode
            <$> getRcvQueuesByConnId_ dbConn connId
            <*> getSndQueuesByConnId_ dbConn connId
    _ -> pure $ Left SEConnNotFound
  where
    toSomeConn :: ConnData -> ConnectionMode -> Maybe (NonEmpty RcvQueue) -> Maybe (NonEmpty SndQueue) -> Either StoreError SomeConn
    toSomeConn cData cMode rQ sQ = case (rQ, sQ, cMode) of
      (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
      (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
      (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
      (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
      (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
      _ -> Left SEConnNotFound
|
|
|
|
-- | Loads several connections that are not marked as deleted,
-- one result per requested ID (see 'getAnyConns_').
getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
getConns db connIds = getAnyConns_ False db connIds
|
|
|
|
-- | Loads several connections that are marked as deleted,
-- one result per requested ID (see 'getAnyConns_').
getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
getDeletedConns db connIds = getAnyConns_ True db connIds
|
|
|
|
-- | Loads the connections with the given IDs one by one via 'getAnyConn',
-- returning one result per ID in the same order.
--
-- An exception raised while loading one connection is turned into
-- 'SEInternal' for that connection, so the remaining IDs are still processed.
getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
getAnyConns_ deleted' db = traverse getOne
  where
    getOne :: ConnId -> IO (Either StoreError SomeConn)
    getOne connId = E.handle toStoreError $ getAnyConn deleted' db connId
    toStoreError :: E.SomeException -> IO (Either StoreError SomeConn)
    toStoreError e = pure $ Left $ SEInternal $ bshow e
|
|
|
|
-- | Reads the connection record and its connection mode, or 'Nothing' when
-- there is no connection with this ID.
--
-- A NULL @enable_ntfs@ column is read as 'True'.
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData db connId' =
  maybeFirstRow toConnData (DB.query db connQuery (Only connId'))
  where
    connQuery =
      [sql|
        SELECT
          user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
          last_external_snd_msg_id, deleted, ratchet_sync_state, pq_encryption
        FROM connections
        WHERE conn_id = ?
      |]
    toConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqEncryption) =
      (connData, cMode)
      where
        enableNtfs = fromMaybe True enableNtfs_
        connData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId, deleted, ratchetSyncState, pqEncryption}
|
|
|
|
-- | Marks the connection as deleted.
--
-- When the second argument is 'True', @deleted_at_wait_delivery@ is set to the
-- current time and the @deleted@ flag is left as is; otherwise the @deleted@
-- flag is set.
setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO ()
setConnDeleted db waitDelivery connId =
  if waitDelivery
    then do
      currentTs <- getCurrentTime
      DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId)
    else DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
|
|
|
|
-- | Stores the agent version of the connection.
setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO ()
setConnAgentVersion db connId aVersion = DB.execute db updateQuery (aVersion, connId)
  where
    updateQuery = "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?"
|
|
|
|
-- | Returns the IDs of all connections with the @deleted@ flag set.
getDeletedConnIds :: DB.Connection -> IO [ConnId]
getDeletedConnIds db = do
  rows <- DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True)
  pure [connId | Only connId <- rows]
|
|
|
|
-- | Returns the IDs of all connections that have @deleted_at_wait_delivery@ set.
getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId]
getDeletedWaitingDeliveryConnIds db = do
  rows <- DB.query_ db "SELECT conn_id FROM connections WHERE deleted_at_wait_delivery IS NOT NULL"
  pure [connId | Only connId <- rows]
|
|
|
|
-- | Stores the ratchet synchronization state of the connection.
setConnRatchetSync :: DB.Connection -> ConnId -> RatchetSyncState -> IO ()
setConnRatchetSync db connId ratchetSyncState = DB.execute db updateQuery (ratchetSyncState, connId)
  where
    updateQuery = "UPDATE connections SET ratchet_sync_state = ? WHERE conn_id = ?"
|
|
|
|
-- | Records a ratchet key hash as processed for the connection.
addProcessedRatchetKeyHash :: DB.Connection -> ConnId -> ByteString -> IO ()
addProcessedRatchetKeyHash db connId hash = DB.execute db insertQuery (connId, hash)
  where
    insertQuery = "INSERT INTO processed_ratchet_key_hashes (conn_id, hash) VALUES (?,?)"
|
|
|
|
-- | Checks whether the ratchet key hash was already recorded as processed
-- for the connection.
checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool
checkRatchetKeyHashExists db connId hash =
  -- the query yields a single row containing 1 when the hash is stored, and no rows otherwise
  (== Just True) <$> maybeFirstRow fromOnly matchingRows
  where
    matchingRows =
      DB.query
        db
        "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
        (connId, hash)
|
|
|
|
-- | Deletes processed ratchet key hashes created earlier than the given
-- time-to-live before the current time.
deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteRatchetKeyHashesExpired db ttl = do
  now <- getCurrentTime
  let cutoffTs = addUTCTime (negate ttl) now
  DB.execute db "DELETE FROM processed_ratchet_key_hashes WHERE created_at < ?" (Only cutoffTs)
|
|
|
|
-- | Returns all non-deleted receive queues of the connection, the first queue is the primary one.
--
-- 'Nothing' is returned when the connection has no such queues.
getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue))
getRcvQueuesByConnId_ db connId =
  L.nonEmpty . sortBy primaryFirst . map toRcvQueue
    -- the leading space keeps the WHERE clause separated from 'rcvQueueQuery'
    -- however that query text ends, the same way 'getRcvQueueById' appends it
    <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.deleted = 0") (Only connId)
  where
    primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} =
      -- the current primary queue is ordered first, the next primary - second
      compare (Down p) (Down p') <> compare i i'
|
|
|
|
-- | Base SELECT for receive queues, joined with their server and connection.
--
-- It has no WHERE clause: callers append their own condition. The selected
-- columns are in the order expected by 'toRcvQueue'.
rcvQueueQuery :: Query
rcvQueueQuery =
  [sql|
    SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
      q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
      q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.switch_status, q.smp_client_version, q.delete_errors,
      q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
    FROM rcv_queues q
    JOIN servers s ON q.host = s.host AND q.port = s.port
    JOIN connections c ON q.conn_id = c.conn_id
  |]
|
|
|
|
-- | Builds a receive queue from a row selected by 'rcvQueueQuery'.
--
-- A NULL SMP client version is read as 'initialSMPClientVersion'.
-- Notification credentials are present only when all four notification
-- columns are set.
toRcvQueue ::
  (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
    :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int)
    :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
  RcvQueue
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
  RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors}
  where
    server = SMPServer host port keyHash
    smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_
    clientNtfCreds = do
      ntfPublicKey <- ntfPublicKey_
      ntfPrivateKey <- ntfPrivateKey_
      notifierId <- notifierId_
      rcvNtfDhSecret <- rcvNtfDhSecret_
      pure ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
|
|
|
|
-- | Loads the non-deleted receive queue of the connection with the given
-- database queue ID, or 'SEConnNotFound' when there is no such queue.
getRcvQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue)
getRcvQueueById db connId dbRcvId =
  firstRow toRcvQueue SEConnNotFound (DB.query db queueByIdQuery (connId, dbRcvId))
  where
    queueByIdQuery = rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ? AND q.deleted = 0"
|
|
|
|
-- | Returns all send queues of the connection, the first queue is the primary one.
--
-- 'Nothing' is returned when the connection has no send queues.
getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue))
getSndQueuesByConnId_ dbConn connId =
  L.nonEmpty . sortBy primaryFirst . map toSndQueue
    -- the leading space keeps the WHERE clause separated from 'sndQueueQuery'
    -- however that query text ends, the same way 'getSndQueueById' appends it
    <$> DB.query dbConn (sndQueueQuery <> " WHERE q.conn_id = ?") (Only connId)
  where
    primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} =
      -- the current primary queue is ordered first, the next primary - second
      compare (Down p) (Down p') <> compare i i'
|
|
|
|
-- | Base SELECT for send queues, joined with their server and connection.
--
-- It has no WHERE clause: callers append their own condition. The selected
-- columns are in the order expected by 'toSndQueue'.
sndQueueQuery :: Query
sndQueueQuery =
  [sql|
    SELECT
      c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.snd_id,
      q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status,
      q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.switch_status, q.smp_client_version
    FROM snd_queues q
    JOIN servers s ON q.host = s.host AND q.port = s.port
    JOIN connections c ON q.conn_id = c.conn_id
  |]
|
|
|
|
-- | Builds a send queue from a row selected by 'sndQueueQuery'.
toSndQueue ::
  (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId)
    :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus)
    :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) ->
  SndQueue
toSndQueue (serverAndIds :. keysAndStatus :. queueProps) =
  SndQueue {userId, connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion}
  where
    (userId, keyHash, connId, host, port, sndId) = serverAndIds
    (sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) = keysAndStatus
    (dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion) = queueProps
    server = SMPServer host port keyHash
|
|
|
|
-- | Loads the send queue of the connection with the given database queue ID,
-- or 'SEConnNotFound' when there is no such queue.
getSndQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError SndQueue)
getSndQueueById db connId dbSndId =
  firstRow toSndQueue SEConnNotFound (DB.query db queueByIdQuery (connId, dbSndId))
  where
    queueByIdQuery = sndQueueQuery <> " WHERE q.conn_id = ? AND q.snd_queue_id = ?"
|
|
|
|
-- * updateRcvIds helpers
|
|
|
|
-- | Reads the last internal message ID, last internal receive ID, last
-- external send ID and last receive message hash stored for the connection.
--
-- The result is matched against a single-row list, so the 'IO' action fails
-- if the query does not return exactly one row.
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
retrieveLastIdsAndHashRcv_ dbConn connId = do
  [lastIdsAndHash] <- DB.queryNamed dbConn lastIdsQuery [":conn_id" := connId]
  pure lastIdsAndHash
  where
    lastIdsQuery =
      [sql|
        SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
        FROM connections
        WHERE conn_id = :conn_id;
      |]
|
|
|
|
-- | Stores the new last internal message ID and last internal receive
-- message ID of the connection.
updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO ()
updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId =
  DB.executeNamed dbConn updateQuery params
  where
    params =
      [ ":last_internal_msg_id" := newInternalId,
        ":last_internal_rcv_msg_id" := newInternalRcvId,
        ":conn_id" := connId
      ]
    updateQuery =
      [sql|
        UPDATE connections
        SET last_internal_msg_id = :last_internal_msg_id,
            last_internal_rcv_msg_id = :last_internal_rcv_msg_id
        WHERE conn_id = :conn_id;
      |]
|
|
|
|
-- * createRcvMsg helpers
|
|
|
|
-- | Inserts the base @messages@ row of a received message.
--
-- The internal ID and timestamp are taken from the recipient part of the
-- message metadata; @internal_snd_id@ is stored as NULL.
insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, internalRcvId} =
  DB.execute dbConn insertQuery (connId, internalId, internalTs, internalRcvId, noSndId, msgType, msgFlags, msgBody, pqEncryption)
  where
    MsgMeta {recipient = (internalId, internalTs), pqEncryption} = msgMeta
    noSndId = Nothing :: Maybe Int64
    insertQuery =
      [sql|
        INSERT INTO messages
          (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption)
          VALUES (?,?,?,?,?,?,?,?,?);
      |]
|
|
|
|
-- | Inserts the @rcv_messages@ row with the details of a received message
-- and then records the hash of its encrypted form for the connection.
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do
  DB.executeNamed db insertQuery params
  DB.execute db "INSERT INTO encrypted_rcv_message_hashes (conn_id, hash) VALUES (?,?)" (connId, encryptedMsgHash)
  where
    MsgMeta {integrity, recipient = (internalId, _), broker = (brokerId, brokerTs), sndMsgId} = msgMeta
    params =
      [ ":conn_id" := connId,
        ":rcv_queue_id" := dbQueueId,
        ":internal_rcv_id" := internalRcvId,
        ":internal_id" := internalId,
        ":external_snd_id" := sndMsgId,
        ":broker_id" := brokerId,
        ":broker_ts" := brokerTs,
        ":internal_hash" := internalHash,
        ":external_prev_snd_hash" := externalPrevSndHash,
        ":integrity" := integrity
      ]
    insertQuery =
      [sql|
        INSERT INTO rcv_messages
          ( conn_id, rcv_queue_id, internal_rcv_id, internal_id, external_snd_id,
            broker_id, broker_ts,
            internal_hash, external_prev_snd_hash, integrity)
        VALUES
          (:conn_id,:rcv_queue_id,:internal_rcv_id,:internal_id,:external_snd_id,
           :broker_id,:broker_ts,
           :internal_hash,:external_prev_snd_hash,:integrity);
      |]
|
|
|
|
-- | Updates the last external send id and last received message hash on the
-- connection, but only while the connection's @last_internal_rcv_msg_id@
-- still equals this message's internal receive id.
updateHashRcv_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
updateHashRcv_ db connId RcvMsgData {msgMeta, internalHash, internalRcvId} =
  DB.executeNamed
    db
    -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
    [sql|
      UPDATE connections
      SET last_external_snd_msg_id = :last_external_snd_msg_id,
          last_rcv_msg_hash = :last_rcv_msg_hash
      WHERE conn_id = :conn_id
        AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id;
    |]
    [ ":last_external_snd_msg_id" := sndMsgId,
      ":last_rcv_msg_hash" := internalHash,
      ":conn_id" := connId,
      ":last_internal_rcv_msg_id" := internalRcvId
    ]
  where
    MsgMeta {sndMsgId} = msgMeta
|
|
|
|
-- * updateSndIds helpers
|
|
|
|
-- | Reads the last internal message id, last internal send id and last sent
-- message hash stored on the connection row.
--
-- Fails in IO (as a user error) unless exactly one row matches @conn_id@.
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
retrieveLastIdsAndHashSnd_ dbConn connId =
  DB.queryNamed
    dbConn
    [sql|
      SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
      FROM connections
      WHERE conn_id = :conn_id;
    |]
    [":conn_id" := connId]
    >>= \case
      [lastIdsAndHash] -> pure lastIdsAndHash
      -- previously a refutable list pattern in the bind; that failed with an
      -- opaque "pattern match failure" message. The error kind is kept, the
      -- message now says what went wrong.
      rows -> fail $ "retrieveLastIdsAndHashSnd_: expected exactly 1 connection row, got " <> show (length rows)
|
|
|
|
-- | Stores the given internal message id and internal send id as the last ids
-- on the connection row.
updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO ()
updateLastIdsSnd_ db connId internalId internalSndId =
  DB.executeNamed db updateLastIds idParams
  where
    updateLastIds =
      [sql|
        UPDATE connections
        SET last_internal_msg_id = :last_internal_msg_id,
            last_internal_snd_msg_id = :last_internal_snd_msg_id
        WHERE conn_id = :conn_id;
      |]
    idParams =
      [ ":last_internal_msg_id" := internalId,
        ":last_internal_snd_msg_id" := internalSndId,
        ":conn_id" := connId
      ]
|
|
|
|
-- * createSndMsg helpers
|
|
|
|
-- | Inserts the @messages@ row for a sent message.
-- @internal_rcv_id@ is written as NULL.
insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgBase_ dbConn connId SndMsgData {internalId, internalTs, internalSndId, msgType, msgFlags, msgBody, pqEncryption} =
  DB.execute
    dbConn
    [sql|
      INSERT INTO messages
        (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption)
      VALUES
        (?,?,?,?,?,?,?,?,?);
    |]
    (connId, internalId, internalTs, noRcvId, internalSndId, msgType, msgFlags, msgBody, pqEncryption)
  where
    -- sent messages have no internal receive id
    noRcvId = Nothing :: Maybe Int64
|
|
|
|
-- | Inserts the @snd_messages@ row (ids, message hash and previous message
-- hash) for a sent message.
insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgDetails_ db connId SndMsgData {internalSndId, internalId, internalHash, prevMsgHash} =
  DB.executeNamed db insertDetails detailParams
  where
    insertDetails =
      [sql|
        INSERT INTO snd_messages
          ( conn_id, internal_snd_id, internal_id, internal_hash, previous_msg_hash)
        VALUES
          (:conn_id,:internal_snd_id,:internal_id,:internal_hash,:previous_msg_hash);
      |]
    detailParams =
      [ ":conn_id" := connId,
        ":internal_snd_id" := internalSndId,
        ":internal_id" := internalId,
        ":internal_hash" := internalHash,
        ":previous_msg_hash" := prevMsgHash
      ]
|
|
|
|
-- | Updates the last sent message hash on the connection, but only while the
-- connection's @last_internal_snd_msg_id@ still equals this message's
-- internal send id.
updateHashSnd_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
updateHashSnd_ db connId SndMsgData {internalHash, internalSndId} =
  DB.executeNamed
    db
    -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
    [sql|
      UPDATE connections
      SET last_snd_msg_hash = :last_snd_msg_hash
      WHERE conn_id = :conn_id
        AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
    |]
    [ ":last_snd_msg_hash" := internalHash,
      ":conn_id" := connId,
      ":last_internal_snd_msg_id" := internalSndId
    ]
|
|
|
|
-- | Creates a record with a random ID and returns only the generated ID,
-- discarding the result of the @create@ action.
createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString)
createWithRandomId gVar create = fmap fst <$> createWithRandomId' gVar create
|
|
|
|
-- | Runs @create@ with a freshly generated random ID (12 random bytes passed
-- through 'randomId') and returns the ID together with the action's result.
--
-- On an SQLite constraint error the attempt is repeated with a new ID, up to
-- 3 attempts in total; after that 'SEUniqueID' is returned. Any other SQLite
-- error is returned as 'SEInternal'.
createWithRandomId' :: forall a. TVar ChaChaDRG -> (ByteString -> IO a) -> IO (Either StoreError (ByteString, a))
createWithRandomId' gVar create = attempt 3
  where
    attempt :: Int -> IO (Either StoreError (ByteString, a))
    attempt attemptsLeft
      | attemptsLeft == 0 = pure $ Left SEUniqueID
      | otherwise = do
          newId <- randomId gVar 12
          created <- E.try $ create newId
          case created of
            Right r -> pure $ Right (newId, r)
            Left e ->
              if SQL.sqlError e == SQL.ErrorConstraint
                then attempt (attemptsLeft - 1)
                else pure . Left . SEInternal $ bshow e
|
|
|
|
-- | Draws @n@ random bytes from the shared generator and encodes them with
-- @U.encode@.
randomId :: TVar ChaChaDRG -> Int -> IO ByteString
randomId gVar n = U.encode <$> atomically (C.randomBytes n gVar)
|
|
|
|
-- | Splits a subscription action into an optional NTF action and an optional
-- SMP action; exactly one component of the pair is 'Just'.
ntfSubAndSMPAction :: NtfSubAction -> (Maybe NtfSubNTFAction, Maybe NtfSubSMPAction)
ntfSubAndSMPAction = \case
  NtfSubNTFAction ntfAction -> (Just ntfAction, Nothing)
  NtfSubSMPAction smpAction -> (Nothing, Just smpAction)
|
|
|
|
-- | Returns the row id of the XFTP server, inserting a new @xftp_servers@ row
-- when the lookup by host, port and key hash does not succeed.
createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64
createXFTPServer_ db srv@ProtocolServer {host, port, keyHash} =
  getXFTPServerId_ db srv >>= either (const insertServer) pure
  where
    insertServer :: IO Int64
    insertServer =
      DB.execute db "INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash) VALUES (?,?,?)" (host, port, keyHash)
        >> insertedRowId db
|
|
|
|
-- | Looks up the row id of an XFTP server by host, port and key hash;
-- returns 'SEXFTPServerNotFound' when no row matches.
getXFTPServerId_ :: DB.Connection -> XFTPServer -> IO (Either StoreError Int64)
getXFTPServerId_ db ProtocolServer {host, port, keyHash} =
  firstRow fromOnly SEXFTPServerNotFound findServer
  where
    findServer = DB.query db "SELECT xftp_server_id FROM xftp_servers WHERE xftp_host = ? AND xftp_port = ? AND xftp_key_hash = ?" (host, port, keyHash)
|
|
|
|
-- | Creates the @rcv_files@ record for a file description, then inserts all of
-- its chunks and their replicas (replicas are numbered from 1 in list order).
-- Returns the entity id of the new file, or the error from 'insertRcvFile'.
createRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> IO (Either StoreError RcvFileId)
createRcvFile db gVar userId fd@FileDescription {chunks} prefixPath tmpPath file =
  insertRcvFile db gVar userId fd prefixPath tmpPath file Nothing Nothing >>= mapM insertChunks
  where
    insertChunks :: (RcvFileId, DBRcvFileId) -> IO RcvFileId
    insertChunks (rcvFileEntityId, rcvFileId) = do
      forM_ chunks $ \chunk@FileChunk {replicas} -> do
        chunkId <- insertRcvFileChunk db chunk rcvFileId
        forM_ (zip [1 ..] replicas) $ \(replicaNo, replica) ->
          insertRcvFileChunkReplica db replicaNo replica chunkId
      pure rcvFileEntityId
|
|
|
|
-- | Creates two @rcv_files@ records for a redirected download: a placeholder
-- destination file (size and digest taken from the redirect info) and the
-- redirect file that references it. Chunks and replicas are inserted for the
-- redirect file only. Returns the entity id of the destination file.
--
-- Returns 'SEInternal' when the description carries no redirect info.
createRcvFileRedirect :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription FRecipient -> FilePath -> FilePath -> CryptoFile -> FilePath -> CryptoFile -> IO (Either StoreError RcvFileId)
createRcvFileRedirect _ _ _ FileDescription {redirect = Nothing} _ _ _ _ _ = pure $ Left $ SEInternal "createRcvFileRedirect called without redirect"
createRcvFileRedirect db gVar userId redirectFd@FileDescription {chunks = redirectChunks, redirect = Just RedirectFileInfo {size, digest}} prefixPath redirectPath redirectFile dstPath dstFile = runExceptT $ do
  (dstEntityId, dstId) <- ExceptT $ insertRcvFile db gVar userId dummyDst prefixPath dstPath dstFile Nothing Nothing
  (_, redirectId) <- ExceptT $ insertRcvFile db gVar userId redirectFd prefixPath redirectPath redirectFile (Just dstId) (Just dstEntityId)
  liftIO $ mapM_ (insertChunk redirectId) redirectChunks
  pure dstEntityId
  where
    insertChunk redirectId chunk@FileChunk {replicas} = do
      chunkId <- insertRcvFileChunk db chunk redirectId
      forM_ (zip [1 ..] replicas) $ \(replicaNo, replica) ->
        insertRcvFileChunkReplica db replicaNo replica chunkId
    dummyDst =
      FileDescription
        { party = SFRecipient,
          size,
          digest,
          redirect = Nothing,
          -- updated later with updateRcvFileRedirect
          key = C.unsafeSbKey $ B.replicate 32 '#',
          nonce = C.cbNonce "",
          chunkSize = FileSize 0,
          chunks = []
        }
|
|
|
|
-- | Inserts an @rcv_files@ row with status 'RFSReceiving' under a random
-- entity id (see 'createWithRandomId') and returns the entity id together
-- with the inserted row id. Redirect digest and size are stored only when
-- the description has redirect info.
insertRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Maybe DBRcvFileId -> Maybe RcvFileId -> IO (Either StoreError (RcvFileId, DBRcvFileId))
insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSize, redirect} prefixPath tmpPath (CryptoFile savePath cfArgs) redirectId_ redirectEntityId_ =
  createWithRandomId gVar insertFile >>= mapM withRowId
  where
    withRowId rcvFileEntityId = (rcvFileEntityId,) <$> insertedRowId db
    insertFile rcvFileEntityId =
      DB.execute
        db
        "INSERT INTO rcv_files (rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, redirect_id, redirect_entity_id, redirect_digest, redirect_size) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"
        ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, fileKey <$> cfArgs, fileNonce <$> cfArgs, RFSReceiving, redirectId_, redirectEntityId_, redirectDigest_, redirectSize_))
    (redirectDigest_, redirectSize_) = case redirect of
      Nothing -> (Nothing, Nothing)
      Just RedirectFileInfo {digest = d, size = s} -> (Just d, Just s)
|
|
|
|
-- | Inserts an @rcv_file_chunks@ row for the file and returns its row id.
insertRcvFileChunk :: DB.Connection -> FileChunk -> DBRcvFileId -> IO Int64
insertRcvFileChunk db FileChunk {chunkNo, chunkSize, digest} rcvFileId =
  DB.execute db insertChunk (rcvFileId, chunkNo, chunkSize, digest) >> insertedRowId db
  where
    insertChunk = "INSERT INTO rcv_file_chunks (rcv_file_id, chunk_no, chunk_size, digest) VALUES (?,?,?,?)"
|
|
|
|
-- | Resolves (or creates) the replica's XFTP server record and inserts an
-- @rcv_file_chunk_replicas@ row with the given replica number for the chunk.
insertRcvFileChunkReplica :: DB.Connection -> Int -> FileChunkReplica -> Int64 -> IO ()
insertRcvFileChunkReplica db replicaNo FileChunkReplica {server, replicaId, replicaKey} chunkId =
  createXFTPServer_ db server >>= \srvId ->
    DB.execute
      db
      "INSERT INTO rcv_file_chunk_replicas (replica_number, rcv_file_chunk_id, xftp_server_id, replica_id, replica_key) VALUES (?,?,?,?,?)"
      (replicaNo, chunkId, srvId, replicaId, replicaKey)
|
|
|
|
-- | Loads a receive file by its entity id: resolves the database id first,
-- then loads the file with 'getRcvFile'. The first error is returned.
getRcvFileByEntityId :: DB.Connection -> RcvFileId -> IO (Either StoreError RcvFile)
getRcvFileByEntityId db rcvFileEntityId =
  getRcvFileIdByEntityId_ db rcvFileEntityId >>= either (pure . Left) (getRcvFile db)
|
|
|
|
-- | Resolves a receive file entity id to its database row id;
-- returns 'SEFileNotFound' when no row matches.
getRcvFileIdByEntityId_ :: DB.Connection -> RcvFileId -> IO (Either StoreError DBRcvFileId)
getRcvFileIdByEntityId_ db rcvFileEntityId =
  firstRow fromOnly SEFileNotFound findFileId
  where
    findFileId = DB.query db "SELECT rcv_file_id FROM rcv_files WHERE rcv_file_entity_id = ?" (Only rcvFileEntityId)
|
|
|
|
-- | Loads all receive files whose @redirect_id@ is the given file id.
-- Files that fail to load via 'getRcvFile' are silently omitted.
getRcvFileRedirects :: DB.Connection -> DBRcvFileId -> IO [RcvFile]
getRcvFileRedirects db rcvFileId = do
  redirectIds <- map fromOnly <$> DB.query db "SELECT rcv_file_id FROM rcv_files WHERE redirect_id = ?" (Only rcvFileId)
  loaded <- forM redirectIds $ \redirectId ->
    either (const Nothing) Just <$> getRcvFile db redirectId
  pure $ catMaybes loaded
|
|
|
|
-- | Loads a receive file by database id together with its chunks and their
-- replicas. Chunks are only loaded while the file still has a @tmp_path@;
-- otherwise the chunk list is empty. Returns 'SEFileNotFound' when the file
-- row does not exist.
getRcvFile :: DB.Connection -> DBRcvFileId -> IO (Either StoreError RcvFile)
getRcvFile db rcvFileId = getFile >>= mapM withChunks
  where
    withChunks :: RcvFile -> IO RcvFile
    withChunks f@RcvFile {rcvFileEntityId, userId, tmpPath} = do
      chunks <- maybe (pure []) (getChunks rcvFileEntityId userId) tmpPath
      pure (f {chunks} :: RcvFile)
    getFile :: IO (Either StoreError RcvFile)
    getFile =
      firstRow toFile SEFileNotFound $
        DB.query
          db
          [sql|
            SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest
            FROM rcv_files
            WHERE rcv_file_id = ?
          |]
          (Only rcvFileId)
      where
        toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, Bool, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile
        toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, saveKey_, saveNonce_, status, deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) =
          RcvFile {rcvFileId, rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, redirect, prefixPath, tmpPath, saveFile, status, deleted, chunks = []}
          where
            -- file encryption args are present only when both key and nonce are stored
            saveFile = CryptoFile savePath (CFArgs <$> saveKey_ <*> saveNonce_)
            -- redirect is present only when all four redirect columns are set
            redirectInfo = RedirectFileInfo <$> redirectSize_ <*> redirectDigest_
            redirect = RcvFileRedirect <$> redirectDbId <*> redirectEntityId <*> redirectInfo
    getChunks :: RcvFileId -> UserId -> FilePath -> IO [RcvFileChunk]
    getChunks rcvFileEntityId userId fileTmpPath = do
      rows <-
        DB.query
          db
          [sql|
            SELECT rcv_file_chunk_id, chunk_no, chunk_size, digest, tmp_path
            FROM rcv_file_chunks
            WHERE rcv_file_id = ?
          |]
          (Only rcvFileId)
      forM (map toChunk rows) $ \chunk@RcvFileChunk {rcvChunkId} -> do
        replicas' <- getChunkReplicas rcvChunkId
        pure (chunk {replicas = replicas'} :: RcvFileChunk)
      where
        toChunk :: (Int64, Int, FileSize Word32, FileDigest, Maybe FilePath) -> RcvFileChunk
        toChunk (rcvChunkId, chunkNo, chunkSize, digest, chunkTmpPath) =
          RcvFileChunk {rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath, replicas = []}
    getChunkReplicas :: Int64 -> IO [RcvFileChunkReplica]
    getChunkReplicas chunkId =
      map toReplica
        <$> DB.query
          db
          [sql|
            SELECT
              r.rcv_file_chunk_replica_id, r.replica_id, r.replica_key, r.received, r.delay, r.retries,
              s.xftp_host, s.xftp_port, s.xftp_key_hash
            FROM rcv_file_chunk_replicas r
            JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
            WHERE r.rcv_file_chunk_id = ?
          |]
          (Only chunkId)
      where
        toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, Bool, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> RcvFileChunkReplica
        toReplica (rcvChunkReplicaId, replicaId, replicaKey, received, delay, retries, host, port, keyHash) =
          RcvFileChunkReplica {rcvChunkReplicaId, server = XFTPServer host port keyHash, replicaId, replicaKey, received, delay, retries}
|
|
|
|
-- | Sets the delay of a receive chunk replica, increments its retry counter
-- and refreshes @updated_at@.
updateRcvChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO ()
updateRcvChunkReplicaDelay db replicaId delay =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE rcv_file_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE rcv_file_chunk_replica_id = ?" (delay, now, replicaId)
|
|
|
|
-- | Marks the replica as received and stores the chunk's temporary path;
-- both rows get the same @updated_at@ timestamp.
updateRcvFileChunkReceived :: DB.Connection -> Int64 -> Int64 -> FilePath -> IO ()
updateRcvFileChunkReceived db replicaId chunkId chunkTmpPath = do
  now <- getCurrentTime
  markReplicaReceived now
  storeChunkPath now
  where
    markReplicaReceived ts =
      DB.execute db "UPDATE rcv_file_chunk_replicas SET received = 1, updated_at = ? WHERE rcv_file_chunk_replica_id = ?" (ts, replicaId)
    storeChunkPath ts =
      DB.execute db "UPDATE rcv_file_chunks SET tmp_path = ?, updated_at = ? WHERE rcv_file_chunk_id = ?" (chunkTmpPath, ts, chunkId)
|
|
|
|
-- | Sets the status of a receive file and refreshes @updated_at@.
updateRcvFileStatus :: DB.Connection -> DBRcvFileId -> RcvFileStatus -> IO ()
updateRcvFileStatus db rcvFileId status =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE rcv_files SET status = ?, updated_at = ? WHERE rcv_file_id = ?" (status, now, rcvFileId)
|
|
|
|
-- | Records an error on a receive file: stores the error text, sets the
-- status to 'RFSError', clears @tmp_path@ and refreshes @updated_at@.
updateRcvFileError :: DB.Connection -> DBRcvFileId -> String -> IO ()
updateRcvFileError db rcvFileId errStr =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE rcv_files SET tmp_path = NULL, error = ?, status = ?, updated_at = ? WHERE rcv_file_id = ?" (errStr, RFSError, now, rcvFileId)
|
|
|
|
-- | Marks a receive file as 'RFSComplete', clears @tmp_path@ and refreshes
-- @updated_at@.
updateRcvFileComplete :: DB.Connection -> DBRcvFileId -> IO ()
updateRcvFileComplete db rcvFileId =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE rcv_files SET tmp_path = NULL, status = ?, updated_at = ? WHERE rcv_file_id = ?" (RFSComplete, now, rcvFileId)
|
|
|
|
-- | Overwrites key, nonce and chunk size of a receive file with the values
-- from the given description, then inserts the description's chunks and
-- replicas for that file. The visible code always returns @Right ()@.
updateRcvFileRedirect :: DB.Connection -> DBRcvFileId -> FileDescription 'FRecipient -> IO (Either StoreError ())
updateRcvFileRedirect db rcvFileId FileDescription {key, nonce, chunkSize, chunks} = do
  now <- getCurrentTime
  DB.execute db "UPDATE rcv_files SET key = ?, nonce = ?, chunk_size = ?, updated_at = ? WHERE rcv_file_id = ?" (key, nonce, chunkSize, now, rcvFileId)
  mapM_ insertChunk chunks
  pure $ Right ()
  where
    insertChunk chunk@FileChunk {replicas} = do
      chunkId <- insertRcvFileChunk db chunk rcvFileId
      forM_ (zip [1 ..] replicas) $ \(replicaNo, replica) ->
        insertRcvFileChunkReplica db replicaNo replica chunkId
|
|
|
|
-- | Clears @tmp_path@ of a receive file and refreshes @updated_at@.
updateRcvFileNoTmpPath :: DB.Connection -> DBRcvFileId -> IO ()
updateRcvFileNoTmpPath db rcvFileId =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE rcv_files SET tmp_path = NULL, updated_at = ? WHERE rcv_file_id = ?" (now, rcvFileId)
|
|
|
|
-- | Sets the @deleted@ flag on a receive file (the row is kept) and refreshes
-- @updated_at@.
updateRcvFileDeleted :: DB.Connection -> DBRcvFileId -> IO ()
updateRcvFileDeleted db rcvFileId =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE rcv_files SET deleted = 1, updated_at = ? WHERE rcv_file_id = ?" (now, rcvFileId)
|
|
|
|
-- | Deletes the @rcv_files@ row with the given database id.
deleteRcvFile' :: DB.Connection -> DBRcvFileId -> IO ()
deleteRcvFile' db rcvFileId = DB.execute db deleteFile (Only rcvFileId)
  where
    deleteFile = "DELETE FROM rcv_files WHERE rcv_file_id = ?"
|
|
|
|
-- | Picks the next chunk replica to download from the given XFTP server.
--
-- Candidates are replicas with @received = 0@ and @replica_number = 1@ on this
-- server, whose file has status 'RFSReceiving', is not deleted or failed, and
-- was created within @ttl@ of now; the one with the fewest retries (then
-- the oldest) is taken. The selection and loading steps are run through
-- 'getWorkItem', with 'markRcvFileFailed' passed as its failure callback
-- (NOTE(review): exact invocation conditions depend on 'getWorkItem' — confirm).
getNextRcvChunkToDownload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe RcvFileChunk))
getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl =
  getWorkItem "rcv_file_download" getReplicaId getChunkData (markRcvFileFailed db . snd)
  where
    getReplicaId :: IO (Maybe (Int64, DBRcvFileId))
    getReplicaId = do
      now <- getCurrentTime
      let cutoffTs = addUTCTime (negate ttl) now
      maybeFirstRow id $
        DB.query
          db
          [sql|
            SELECT r.rcv_file_chunk_replica_id, f.rcv_file_id
            FROM rcv_file_chunk_replicas r
            JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
            JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id
            JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id
            WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ?
              AND r.received = 0 AND r.replica_number = 1
              AND f.status = ? AND f.deleted = 0 AND f.created_at >= ?
              AND f.failed = 0
            ORDER BY r.retries ASC, r.created_at ASC
            LIMIT 1
          |]
          (host, port, keyHash, RFSReceiving, cutoffTs)
    getChunkData :: (Int64, DBRcvFileId) -> IO (Either StoreError RcvFileChunk)
    getChunkData (replicaRowId, _fileId) =
      firstRow toChunk SEFileNotFound $
        DB.query
          db
          [sql|
            SELECT
              f.rcv_file_id, f.rcv_file_entity_id, f.user_id, c.rcv_file_chunk_id, c.chunk_no, c.chunk_size, c.digest, f.tmp_path, c.tmp_path,
              r.rcv_file_chunk_replica_id, r.replica_id, r.replica_key, r.received, r.delay, r.retries
            FROM rcv_file_chunk_replicas r
            JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
            JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id
            JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id
            WHERE r.rcv_file_chunk_replica_id = ?
          |]
          (Only replicaRowId)
      where
        toChunk :: ((DBRcvFileId, RcvFileId, UserId, Int64, Int, FileSize Word32, FileDigest, FilePath, Maybe FilePath) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, Bool, Maybe Int64, Int)) -> RcvFileChunk
        toChunk ((rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath) :. (rcvChunkReplicaId, replicaId, replicaKey, received, delay, retries)) =
          RcvFileChunk {rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath, replicas = [replica]}
          where
            -- the server is the one this function was called with
            replica = RcvFileChunkReplica {rcvChunkReplicaId, server, replicaId, replicaKey, received, delay, retries}
|
|
|
|
-- | Picks the oldest receive file with status 'RFSReceived' or
-- 'RFSDecrypting' that is not deleted or failed and was created within @ttl@
-- of now, and loads it with 'getRcvFile' via 'getWorkItem'.
getNextRcvFileToDecrypt :: DB.Connection -> NominalDiffTime -> IO (Either StoreError (Maybe RcvFile))
getNextRcvFileToDecrypt db ttl =
  getWorkItem "rcv_file_decrypt" getFileId (getRcvFile db) (markRcvFileFailed db)
  where
    getFileId :: IO (Maybe DBRcvFileId)
    getFileId = do
      now <- getCurrentTime
      let cutoffTs = addUTCTime (negate ttl) now
      maybeFirstRow fromOnly $
        DB.query
          db
          [sql|
            SELECT rcv_file_id
            FROM rcv_files
            WHERE status IN (?,?) AND deleted = 0 AND created_at >= ?
              AND failed = 0
            ORDER BY created_at ASC LIMIT 1
          |]
          (RFSReceived, RFSDecrypting, cutoffTs)
|
|
|
|
-- | Sets the @failed@ flag on a receive file.
markRcvFileFailed :: DB.Connection -> DBRcvFileId -> IO ()
markRcvFileFailed db fileId =
  DB.execute db "UPDATE rcv_files SET failed = 1 WHERE rcv_file_id = ?" (Only fileId)
|
|
|
|
-- | Lists the distinct XFTP servers that still hold not-yet-received first
-- replicas of files with status 'RFSReceiving' that are not deleted and were
-- created within @ttl@ of now.
getPendingRcvFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
getPendingRcvFilesServers db ttl = do
  now <- getCurrentTime
  let cutoffTs = addUTCTime (negate ttl) now
  servers <-
    DB.query
      db
      [sql|
        SELECT DISTINCT
          s.xftp_host, s.xftp_port, s.xftp_key_hash
        FROM rcv_file_chunk_replicas r
        JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
        JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id
        JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id
        WHERE r.received = 0 AND r.replica_number = 1
          AND f.status = ? AND f.deleted = 0 AND f.created_at >= ?
      |]
      (RFSReceiving, cutoffTs)
  pure $ map toXFTPServer servers
|
|
|
|
-- | Builds an 'XFTPServer' from a (host, port, key hash) database row.
toXFTPServer :: (NonEmpty TransportHost, ServiceName, C.KeyHash) -> XFTPServer
toXFTPServer (srvHost, srvPort, srvKeyHash) = XFTPServer srvHost srvPort srvKeyHash
|
|
|
|
-- | Lists receive files with status 'RFSComplete' or 'RFSError' that still
-- have a @tmp_path@, as (database id, entity id, tmp path).
getCleanupRcvFilesTmpPaths :: DB.Connection -> IO [(DBRcvFileId, RcvFileId, FilePath)]
getCleanupRcvFilesTmpPaths db = DB.query db finishedWithTmpPath (RFSComplete, RFSError)
  where
    finishedWithTmpPath =
      [sql|
        SELECT rcv_file_id, rcv_file_entity_id, tmp_path
        FROM rcv_files
        WHERE status IN (?,?) AND tmp_path IS NOT NULL
      |]
|
|
|
|
-- | Lists receive files flagged as deleted, as
-- (database id, entity id, prefix path).
getCleanupRcvFilesDeleted :: DB.Connection -> IO [(DBRcvFileId, RcvFileId, FilePath)]
getCleanupRcvFilesDeleted db = DB.query_ db deletedFiles
  where
    deletedFiles =
      [sql|
        SELECT rcv_file_id, rcv_file_entity_id, prefix_path
        FROM rcv_files
        WHERE deleted = 1
      |]
|
|
|
|
-- | Lists receive files created more than @ttl@ before now, as
-- (database id, entity id, prefix path).
getRcvFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBRcvFileId, RcvFileId, FilePath)]
getRcvFilesExpired db ttl = do
  now <- getCurrentTime
  let cutoffTs = addUTCTime (negate ttl) now
  DB.query
    db
    [sql|
      SELECT rcv_file_id, rcv_file_entity_id, prefix_path
      FROM rcv_files
      WHERE created_at < ?
    |]
    (Only cutoffTs)
|
|
|
|
-- | Inserts a @snd_files@ row with status 'SFSNew' under a random entity id
-- (see 'createWithRandomId') and returns that entity id. Redirect size and
-- digest are stored only when redirect info is given.
createSndFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> CryptoFile -> Int -> FilePath -> C.SbKey -> C.CbNonce -> Maybe RedirectFileInfo -> IO (Either StoreError SndFileId)
createSndFile db gVar userId (CryptoFile path cfArgs) numRecipients prefixPath key nonce redirect_ =
  createWithRandomId gVar insertFile
  where
    insertFile sndFileEntityId =
      DB.execute
        db
        "INSERT INTO snd_files (snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, prefix_path, key, nonce, status, redirect_size, redirect_digest) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)"
        ((sndFileEntityId, userId, path, fileKey <$> cfArgs, fileNonce <$> cfArgs, numRecipients) :. (prefixPath, key, nonce, SFSNew, redirectSize_, redirectDigest_))
    (redirectSize_, redirectDigest_) = case redirect_ of
      Just RedirectFileInfo {size, digest} -> (Just size, Just digest)
      Nothing -> (Nothing, Nothing)
|
|
|
|
-- | Loads a send file by its entity id: resolves the database id first, then
-- loads the file with 'getSndFile'. The first error is returned.
getSndFileByEntityId :: DB.Connection -> SndFileId -> IO (Either StoreError SndFile)
getSndFileByEntityId db sndFileEntityId =
  getSndFileIdByEntityId_ db sndFileEntityId >>= either (pure . Left) (getSndFile db)
|
|
|
|
-- | Resolves a send file entity id to its database row id;
-- returns 'SEFileNotFound' when no row matches.
getSndFileIdByEntityId_ :: DB.Connection -> SndFileId -> IO (Either StoreError DBSndFileId)
getSndFileIdByEntityId_ db sndFileEntityId =
  firstRow fromOnly SEFileNotFound findFileId
  where
    findFileId = DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" (Only sndFileEntityId)
|
|
|
|
-- | Loads a send file by database id together with its chunks, their
-- replicas and each replica's recipients. Chunks are only loaded while the
-- file still has a @prefix_path@; otherwise the chunk list is empty.
-- Returns 'SEFileNotFound' when the file row does not exist.
getSndFile :: DB.Connection -> DBSndFileId -> IO (Either StoreError SndFile)
getSndFile db sndFileId = getFile >>= mapM withChunks
  where
    withChunks :: SndFile -> IO SndFile
    withChunks f@SndFile {sndFileEntityId, userId, numRecipients, prefixPath} = do
      chunks <- maybe (pure []) (getChunks sndFileEntityId userId numRecipients) prefixPath
      pure (f {chunks} :: SndFile)
    getFile :: IO (Either StoreError SndFile)
    getFile =
      firstRow toFile SEFileNotFound $
        DB.query
          db
          [sql|
            SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest
            FROM snd_files
            WHERE snd_file_id = ?
          |]
          (Only sndFileId)
      where
        toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, Bool, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile
        toFile ((sndFileEntityId, userId, srcPath, srcKey_, srcNonce_, numRecipients, digest, prefixPath, key, nonce) :. (status, deleted, redirectSize_, redirectDigest_)) =
          SndFile {sndFileId, sndFileEntityId, userId, srcFile, numRecipients, digest, prefixPath, key, nonce, status, deleted, redirect, chunks = []}
          where
            -- source file encryption args are present only when both key and nonce are stored
            srcFile = CryptoFile srcPath (CFArgs <$> srcKey_ <*> srcNonce_)
            redirect = RedirectFileInfo <$> redirectSize_ <*> redirectDigest_
    getChunks :: SndFileId -> UserId -> Int -> FilePath -> IO [SndFileChunk]
    getChunks sndFileEntityId userId numRecipients filePrefixPath = do
      rows <-
        DB.query
          db
          [sql|
            SELECT snd_file_chunk_id, chunk_no, chunk_offset, chunk_size, digest
            FROM snd_file_chunks
            WHERE snd_file_id = ?
          |]
          (Only sndFileId)
      forM (map toChunk rows) $ \chunk@SndFileChunk {sndChunkId} -> do
        replicas' <- getChunkReplicas sndChunkId
        pure (chunk {replicas = replicas'} :: SndFileChunk)
      where
        toChunk :: (Int64, Int, Int64, Word32, FileDigest) -> SndFileChunk
        toChunk (sndChunkId, chunkNo, chunkOffset, chunkSize, digest) =
          SndFileChunk {sndFileId, sndFileEntityId, userId, numRecipients, sndChunkId, chunkNo, chunkSpec, filePrefixPath, digest, replicas = []}
          where
            chunkSpec = XFTPChunkSpec {filePath = sndFileEncPath filePrefixPath, chunkOffset, chunkSize}
    getChunkReplicas :: Int64 -> IO [SndFileChunkReplica]
    getChunkReplicas chunkId = do
      rows <-
        DB.query
          db
          [sql|
            SELECT
              r.snd_file_chunk_replica_id, r.replica_id, r.replica_key, r.replica_status, r.delay, r.retries,
              s.xftp_host, s.xftp_port, s.xftp_key_hash
            FROM snd_file_chunk_replicas r
            JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
            WHERE r.snd_file_chunk_id = ?
          |]
          (Only chunkId)
      forM (map toReplica rows) $ \replica@SndFileChunkReplica {sndChunkReplicaId} -> do
        rcvIdsKeys <- getChunkReplicaRecipients_ db sndChunkReplicaId
        pure (replica :: SndFileChunkReplica) {rcvIdsKeys}
      where
        toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, SndFileReplicaStatus, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> SndFileChunkReplica
        toReplica (sndChunkReplicaId, replicaId, replicaKey, replicaStatus, delay, retries, host, port, keyHash) =
          SndFileChunkReplica {sndChunkReplicaId, server = XFTPServer host port keyHash, replicaId, replicaKey, replicaStatus, delay, retries, rcvIdsKeys = []}
|
|
|
|
-- | Lists the (recipient replica id, recipient replica key) pairs stored for
-- a send chunk replica.
getChunkReplicaRecipients_ :: DB.Connection -> Int64 -> IO [(ChunkReplicaId, C.APrivateAuthKey)]
getChunkReplicaRecipients_ db replicaId = DB.query db replicaRecipients (Only replicaId)
  where
    replicaRecipients =
      [sql|
        SELECT rcv_replica_id, rcv_replica_key
        FROM snd_file_chunk_replica_recipients
        WHERE snd_file_chunk_replica_id = ?
      |]
|
|
|
|
-- | Picks the oldest send file with status 'SFSNew', 'SFSEncrypting' or
-- 'SFSEncrypted' that is not deleted or failed and was created within @ttl@
-- of now, and loads it with 'getSndFile' via 'getWorkItem'.
getNextSndFileToPrepare :: DB.Connection -> NominalDiffTime -> IO (Either StoreError (Maybe SndFile))
getNextSndFileToPrepare db ttl =
  getWorkItem "snd_file_prepare" getFileId (getSndFile db) (markSndFileFailed db)
  where
    getFileId :: IO (Maybe DBSndFileId)
    getFileId = do
      now <- getCurrentTime
      let cutoffTs = addUTCTime (negate ttl) now
      maybeFirstRow fromOnly $
        DB.query
          db
          [sql|
            SELECT snd_file_id
            FROM snd_files
            WHERE status IN (?,?,?) AND deleted = 0 AND created_at >= ?
              AND failed = 0
            ORDER BY created_at ASC LIMIT 1
          |]
          (SFSNew, SFSEncrypting, SFSEncrypted, cutoffTs)
|
|
|
|
-- | Sets the @failed@ flag on a send file.
markSndFileFailed :: DB.Connection -> DBSndFileId -> IO ()
markSndFileFailed db fileId = DB.execute db setFailed (Only fileId)
  where
    setFailed = "UPDATE snd_files SET failed = 1 WHERE snd_file_id = ?"
|
|
|
|
-- | Records an error on a send file: stores the error text, sets the status
-- to 'SFSError', clears @prefix_path@ and refreshes @updated_at@.
updateSndFileError :: DB.Connection -> DBSndFileId -> String -> IO ()
updateSndFileError db sndFileId errStr =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE snd_files SET prefix_path = NULL, error = ?, status = ?, updated_at = ? WHERE snd_file_id = ?" (errStr, SFSError, now, sndFileId)
|
|
|
|
-- | Sets the status of a send file and refreshes @updated_at@.
updateSndFileStatus :: DB.Connection -> DBSndFileId -> SndFileStatus -> IO ()
updateSndFileStatus db sndFileId status =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE snd_files SET status = ?, updated_at = ? WHERE snd_file_id = ?" (status, now, sndFileId)
|
|
|
|
-- | Marks a send file as 'SFSEncrypted', stores its digest and inserts one
-- @snd_file_chunks@ row per chunk spec; chunks are numbered from 1 in list
-- order.
updateSndFileEncrypted :: DB.Connection -> DBSndFileId -> FileDigest -> [(XFTPChunkSpec, FileDigest)] -> IO ()
updateSndFileEncrypted db sndFileId digest chunkSpecsDigests = do
  now <- getCurrentTime
  DB.execute db "UPDATE snd_files SET status = ?, digest = ?, updated_at = ? WHERE snd_file_id = ?" (SFSEncrypted, digest, now, sndFileId)
  mapM_ insertChunk $ zip [1 ..] chunkSpecsDigests
  where
    insertChunk :: (Int, (XFTPChunkSpec, FileDigest)) -> IO ()
    insertChunk (chunkNo, (XFTPChunkSpec {chunkOffset, chunkSize}, chunkDigest)) =
      DB.execute db "INSERT INTO snd_file_chunks (snd_file_id, chunk_no, chunk_offset, chunk_size, digest) VALUES (?,?,?,?,?)" (sndFileId, chunkNo, chunkOffset, chunkSize, chunkDigest)
|
|
|
|
-- | Marks a send file as 'SFSComplete', clears @prefix_path@ and refreshes
-- @updated_at@.
updateSndFileComplete :: DB.Connection -> DBSndFileId -> IO ()
updateSndFileComplete db sndFileId =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE snd_files SET prefix_path = NULL, status = ?, updated_at = ? WHERE snd_file_id = ?" (SFSComplete, now, sndFileId)
|
|
|
|
-- | Clears @prefix_path@ of a send file and refreshes @updated_at@.
updateSndFileNoPrefixPath :: DB.Connection -> DBSndFileId -> IO ()
updateSndFileNoPrefixPath db sndFileId =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE snd_files SET prefix_path = NULL, updated_at = ? WHERE snd_file_id = ?" (now, sndFileId)
|
|
|
|
-- | Sets the @deleted@ flag on a send file (the row is kept) and refreshes
-- @updated_at@.
updateSndFileDeleted :: DB.Connection -> DBSndFileId -> IO ()
updateSndFileDeleted db sndFileId =
  getCurrentTime >>= \now ->
    DB.execute db "UPDATE snd_files SET deleted = 1, updated_at = ? WHERE snd_file_id = ?" (now, sndFileId)
|
|
|
|
-- | Deletes the @snd_files@ row with the given database id.
deleteSndFile' :: DB.Connection -> DBSndFileId -> IO ()
deleteSndFile' db sndFileId = DB.execute db deleteFile (Only sndFileId)
  where
    deleteFile = "DELETE FROM snd_files WHERE snd_file_id = ?"
|
|
|
|
-- | Reads the @deleted@ flag of a send file. A file whose row does not exist
-- is reported as deleted ('True').
getSndFileDeleted :: DB.Connection -> DBSndFileId -> IO Bool
getSndFileDeleted db sndFileId = do
  deleted_ <- maybeFirstRow fromOnly $ DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId)
  pure $ fromMaybe True deleted_
|
|
|
|
-- | Creates a replica record for the given send chunk; delegates to
-- 'createSndFileReplica_' with the chunk's database id.
createSndFileReplica :: DB.Connection -> SndFileChunk -> NewSndChunkReplica -> IO ()
createSndFileReplica db SndFileChunk {sndChunkId} newReplica =
  createSndFileReplica_ db sndChunkId newReplica
|
|
|
|
-- | Resolves (or creates) the replica's XFTP server record, inserts a
-- @snd_file_chunk_replicas@ row with replica number 1 and status
-- 'SFRSCreated', then inserts one recipient row per (id, key) pair.
createSndFileReplica_ :: DB.Connection -> Int64 -> NewSndChunkReplica -> IO ()
createSndFileReplica_ db sndChunkId NewSndChunkReplica {server, replicaId, replicaKey, rcvIdsKeys} = do
  srvId <- createXFTPServer_ db server
  DB.execute
    db
    [sql|
      INSERT INTO snd_file_chunk_replicas
        (snd_file_chunk_id, replica_number, xftp_server_id, replica_id, replica_key, replica_status)
      VALUES (?,?,?,?,?,?)
    |]
    (sndChunkId, 1 :: Int, srvId, replicaId, replicaKey, SFRSCreated)
  replicaRowId <- insertedRowId db
  mapM_ (insertRecipient replicaRowId) rcvIdsKeys
  where
    insertRecipient replicaRowId (rcvId, rcvKey) =
      DB.execute
        db
        [sql|
          INSERT INTO snd_file_chunk_replica_recipients
            (snd_file_chunk_replica_id, rcv_replica_id, rcv_replica_key)
          VALUES (?,?,?)
        |]
        (replicaRowId, rcvId, rcvKey)
|
|
|
|
-- | Picks the next chunk replica to upload to the given XFTP server.
--
-- Candidates are replicas with status 'SFRSCreated' and @replica_number = 1@
-- on this server, whose file has status 'SFSEncrypted' or 'SFSUploading', is
-- not deleted or failed, and was created within @ttl@ of now; the one
-- with the fewest retries (then the oldest) is taken. The loaded chunk carries
-- that single replica together with its recipients. The selection and loading
-- steps are run through 'getWorkItem', with 'markSndFileFailed' passed as its
-- failure callback (NOTE(review): exact invocation conditions depend on
-- 'getWorkItem' — confirm).
getNextSndChunkToUpload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe SndFileChunk))
getNextSndChunkToUpload db server@ProtocolServer {host, port, keyHash} ttl =
  getWorkItem "snd_file_upload" getReplicaId getChunkData (markSndFileFailed db . snd)
  where
    getReplicaId :: IO (Maybe (Int64, DBSndFileId))
    getReplicaId = do
      now <- getCurrentTime
      let cutoffTs = addUTCTime (negate ttl) now
      maybeFirstRow id $
        DB.query
          db
          [sql|
            SELECT r.snd_file_chunk_replica_id, f.snd_file_id
            FROM snd_file_chunk_replicas r
            JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
            JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id
            JOIN snd_files f ON f.snd_file_id = c.snd_file_id
            WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ?
              AND r.replica_status = ? AND r.replica_number = 1
              AND (f.status = ? OR f.status = ?) AND f.deleted = 0 AND f.created_at >= ?
              AND f.failed = 0
            ORDER BY r.retries ASC, r.created_at ASC
            LIMIT 1
          |]
          (host, port, keyHash, SFRSCreated, SFSEncrypted, SFSUploading, cutoffTs)
    getChunkData :: (Int64, DBSndFileId) -> IO (Either StoreError SndFileChunk)
    getChunkData (replicaRowId, _fileId) = do
      chunk_ <-
        firstRow toChunk SEFileNotFound $
          DB.query
            db
            [sql|
              SELECT
                f.snd_file_id, f.snd_file_entity_id, f.user_id, f.num_recipients, f.prefix_path,
                c.snd_file_chunk_id, c.chunk_no, c.chunk_offset, c.chunk_size, c.digest,
                r.snd_file_chunk_replica_id, r.replica_id, r.replica_key, r.replica_status, r.delay, r.retries
              FROM snd_file_chunk_replicas r
              JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
              JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id
              JOIN snd_files f ON f.snd_file_id = c.snd_file_id
              WHERE r.snd_file_chunk_replica_id = ?
            |]
            (Only replicaRowId)
      forM chunk_ $ \chunk@SndFileChunk {replicas} -> do
        replicas' <- forM replicas withRecipients
        pure (chunk {replicas = replicas'} :: SndFileChunk)
      where
        -- fills in the recipients stored for the replica
        withRecipients :: SndFileChunkReplica -> IO SndFileChunkReplica
        withRecipients replica@SndFileChunkReplica {sndChunkReplicaId} = do
          rcvIdsKeys <- getChunkReplicaRecipients_ db sndChunkReplicaId
          pure (replica :: SndFileChunkReplica) {rcvIdsKeys}
        toChunk :: ((DBSndFileId, SndFileId, UserId, Int, FilePath) :. (Int64, Int, Int64, Word32, FileDigest) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, SndFileReplicaStatus, Maybe Int64, Int)) -> SndFileChunk
        toChunk ((sndFileId, sndFileEntityId, userId, numRecipients, filePrefixPath) :. (sndChunkId, chunkNo, chunkOffset, chunkSize, digest) :. (sndChunkReplicaId, replicaId, replicaKey, replicaStatus, delay, retries)) =
          SndFileChunk {sndFileId, sndFileEntityId, userId, numRecipients, sndChunkId, chunkNo, chunkSpec, digest, filePrefixPath, replicas = [replica]}
          where
            chunkSpec = XFTPChunkSpec {filePath = sndFileEncPath filePrefixPath, chunkOffset, chunkSize}
            -- the server is the one this function was called with
            replica = SndFileChunkReplica {sndChunkReplicaId, server, replicaId, replicaKey, replicaStatus, delay, retries, rcvIdsKeys = []}
|
|
|
|
-- | Records a new retry delay for a snd chunk replica row.
--
-- A single UPDATE sets @delay@, increments @retries@ by one and sets
-- @updated_at@ to the current time for the row whose
-- @snd_file_chunk_replica_id@ equals the given id.
updateSndChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO ()
updateSndChunkReplicaDelay db replicaId delay =
  getCurrentTime >>= \now ->
    DB.execute
      db
      "UPDATE snd_file_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE snd_file_chunk_replica_id = ?"
      (delay, now, replicaId)
|
|
|
|
-- | Stores recipient id/key pairs for a snd chunk replica and returns the
-- replica with its recipients refreshed from the database.
--
-- One row is inserted into @snd_file_chunk_replica_recipients@ per pair, in
-- list order. The returned replica's @rcvIdsKeys@ is whatever
-- 'getChunkReplicaRecipients_' reads back afterwards, not the argument list.
addSndChunkReplicaRecipients :: DB.Connection -> SndFileChunkReplica -> [(ChunkReplicaId, C.APrivateAuthKey)] -> IO SndFileChunkReplica
addSndChunkReplicaRecipients db r@SndFileChunkReplica {sndChunkReplicaId} rcvIdsKeys = do
  mapM_ insertRecipient rcvIdsKeys
  stored <- getChunkReplicaRecipients_ db sndChunkReplicaId
  pure (r :: SndFileChunkReplica) {rcvIdsKeys = stored}
  where
    -- Inserts a single recipient row for this replica.
    insertRecipient :: (ChunkReplicaId, C.APrivateAuthKey) -> IO ()
    insertRecipient (rcvId, rcvKey) =
      DB.execute
        db
        [sql|
          INSERT INTO snd_file_chunk_replica_recipients
            (snd_file_chunk_replica_id, rcv_replica_id, rcv_replica_key)
            VALUES (?,?,?)
        |]
        (sndChunkReplicaId, rcvId, rcvKey)
|
|
|
|
-- | Sets @replica_status@ of a snd chunk replica row to the given status and
-- sets its @updated_at@ to the current time.
updateSndChunkReplicaStatus :: DB.Connection -> Int64 -> SndFileReplicaStatus -> IO ()
updateSndChunkReplicaStatus db replicaId status =
  getCurrentTime >>= \now ->
    DB.execute
      db
      "UPDATE snd_file_chunk_replicas SET replica_status = ?, updated_at = ? WHERE snd_file_chunk_replica_id = ?"
      (status, now, replicaId)
|
|
|
|
-- | Lists the distinct XFTP servers that have snd chunk replicas matching the
-- query below.
--
-- A replica matches when its status is 'SFRSCreated' and its
-- @replica_number@ is 1, and its file has status 'SFSEncrypted' or
-- 'SFSUploading', is not deleted, and was created no earlier than @ttl@
-- before the current time.
--
-- NOTE(review): unlike the chunk-selection query earlier in this file (which
-- adds @AND f.failed = 0@), this query does not exclude failed files -
-- confirm that is intended.
getPendingSndFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
getPendingSndFilesServers db ttl = do
  now <- getCurrentTime
  let cutoffTs = addUTCTime (negate ttl) now
  rows <- DB.query db pendingQuery (SFRSCreated, SFSEncrypted, SFSUploading, cutoffTs)
  pure $ map toXFTPServer rows
  where
    pendingQuery =
      [sql|
        SELECT DISTINCT
          s.xftp_host, s.xftp_port, s.xftp_key_hash
        FROM snd_file_chunk_replicas r
        JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
        JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id
        JOIN snd_files f ON f.snd_file_id = c.snd_file_id
        WHERE r.replica_status = ? AND r.replica_number = 1
          AND (f.status = ? OR f.status = ?) AND f.deleted = 0 AND f.created_at >= ?
      |]
|
|
|
|
-- | Returns (row id, entity id, prefix path) for every snd file whose status
-- is 'SFSComplete' or 'SFSError' and whose @prefix_path@ is not NULL.
getCleanupSndFilesPrefixPaths :: DB.Connection -> IO [(DBSndFileId, SndFileId, FilePath)]
getCleanupSndFilesPrefixPaths db = DB.query db cleanupQuery (SFSComplete, SFSError)
  where
    cleanupQuery =
      [sql|
        SELECT snd_file_id, snd_file_entity_id, prefix_path
        FROM snd_files
        WHERE status IN (?,?) AND prefix_path IS NOT NULL
      |]
|
|
|
|
-- | Returns (row id, entity id, optional prefix path) for every snd file row
-- with @deleted = 1@.
getCleanupSndFilesDeleted :: DB.Connection -> IO [(DBSndFileId, SndFileId, Maybe FilePath)]
getCleanupSndFilesDeleted db = DB.query_ db deletedQuery
  where
    deletedQuery =
      [sql|
        SELECT snd_file_id, snd_file_entity_id, prefix_path
        FROM snd_files
        WHERE deleted = 1
      |]
|
|
|
|
-- | Returns (row id, entity id, optional prefix path) for every snd file
-- created earlier than @ttl@ before the current time.
getSndFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBSndFileId, SndFileId, Maybe FilePath)]
getSndFilesExpired db ttl = do
  now <- getCurrentTime
  let cutoffTs = addUTCTime (negate ttl) now
  DB.query db expiredQuery (Only cutoffTs)
  where
    expiredQuery =
      [sql|
        SELECT snd_file_id, snd_file_entity_id, prefix_path
        FROM snd_files
        WHERE created_at < ?
      |]
|
|
|
|
-- | Inserts a row into @deleted_snd_chunk_replicas@ for the given user,
-- replica and chunk digest.
--
-- The server id stored in the row is the value returned by
-- 'createXFTPServer_' for the replica's server (that helper is defined
-- elsewhere; presumably it creates or looks up the server row - confirm).
createDeletedSndChunkReplica :: DB.Connection -> UserId -> FileChunkReplica -> FileDigest -> IO ()
createDeletedSndChunkReplica db userId FileChunkReplica {server, replicaId, replicaKey} chunkDigest =
  createXFTPServer_ db server >>= \srvId ->
    DB.execute
      db
      "INSERT INTO deleted_snd_chunk_replicas (user_id, xftp_server_id, replica_id, replica_key, chunk_digest) VALUES (?,?,?,?,?)"
      (userId, srvId, replicaId, replicaKey, chunkDigest)
|
|
|
|
-- | Loads a deleted snd chunk replica, joined with its XFTP server, by the
-- row id in @deleted_snd_chunk_replicas@.
--
-- The id is typed 'Int64' (not @DBSndFileId@): it identifies a
-- @deleted_snd_chunk_replicas@ row, not a snd file, and the sibling functions
-- that read, update and delete these rows all take 'Int64'.
--
-- The query result is passed through @firstRow toReplica
-- SEDeletedSndChunkReplicaNotFound@ (firstRow is defined elsewhere;
-- presumably it yields that error when no row matches - confirm).
getDeletedSndChunkReplica :: DB.Connection -> Int64 -> IO (Either StoreError DeletedSndChunkReplica)
getDeletedSndChunkReplica db deletedSndChunkReplicaId =
  firstRow toReplica SEDeletedSndChunkReplicaNotFound $
    DB.query
      db
      [sql|
        SELECT
          r.user_id, r.replica_id, r.replica_key, r.chunk_digest, r.delay, r.retries,
          s.xftp_host, s.xftp_port, s.xftp_key_hash
        FROM deleted_snd_chunk_replicas r
        JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
        WHERE r.deleted_snd_chunk_replica_id = ?
      |]
      (Only deletedSndChunkReplicaId)
  where
    -- Builds the result from one joined row; the record's id field is the
    -- function argument, since the id is not among the selected columns.
    toReplica :: (UserId, ChunkReplicaId, C.APrivateAuthKey, FileDigest, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> DeletedSndChunkReplica
    toReplica (userId, replicaId, replicaKey, chunkDigest, delay, retries, host, port, keyHash) =
      let server = XFTPServer host port keyHash
       in DeletedSndChunkReplica {deletedSndChunkReplicaId, userId, server, replicaId, replicaKey, chunkDigest, delay, retries}
|
|
|
|
-- | Picks the next deleted snd chunk replica to process for an XFTP server.
--
-- The candidate row id is selected among @deleted_snd_chunk_replicas@ rows
-- for the server with the given host, port and key hash that were created no
-- earlier than @ttl@ before now and are not marked failed; rows are ordered
-- by fewest retries, then oldest @created_at@, and the first one is taken.
--
-- @getWorkItem@ (defined elsewhere) is given the id query, the loader
-- 'getDeletedSndChunkReplica' and 'markReplicaFailed'; presumably it marks
-- the row failed when loading it fails - confirm against its definition.
getNextDeletedSndChunkReplica :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe DeletedSndChunkReplica))
getNextDeletedSndChunkReplica db ProtocolServer {host, port, keyHash} ttl =
  getWorkItem "deleted replica" getReplicaId (getDeletedSndChunkReplica db) markReplicaFailed
  where
    getReplicaId :: IO (Maybe Int64)
    getReplicaId = do
      cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
      maybeFirstRow fromOnly $
        DB.query
          db
          -- `failed` is qualified with the table alias like every other
          -- column here, so the query cannot become ambiguous if the joined
          -- table ever gains a column of the same name.
          [sql|
            SELECT r.deleted_snd_chunk_replica_id
            FROM deleted_snd_chunk_replicas r
            JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
            WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ?
              AND r.created_at >= ?
              AND r.failed = 0
            ORDER BY r.retries ASC, r.created_at ASC
            LIMIT 1
          |]
          (host, port, keyHash, cutoffTs)
    -- Sets failed = 1 on the row so the selection query above skips it.
    markReplicaFailed :: Int64 -> IO ()
    markReplicaFailed replicaId =
      DB.execute db "UPDATE deleted_snd_chunk_replicas SET failed = 1 WHERE deleted_snd_chunk_replica_id = ?" (Only replicaId)
|
|
|
|
-- | Records a new retry delay for a deleted snd chunk replica row.
--
-- A single UPDATE sets @delay@, increments @retries@ by one and sets
-- @updated_at@ to the current time for the row whose
-- @deleted_snd_chunk_replica_id@ equals the given id.
updateDeletedSndChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO ()
updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId delay =
  getCurrentTime >>= \now ->
    DB.execute
      db
      "UPDATE deleted_snd_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE deleted_snd_chunk_replica_id = ?"
      (delay, now, deletedSndChunkReplicaId)
|
|
|
|
-- | Deletes the @deleted_snd_chunk_replicas@ row with the given id.
deleteDeletedSndChunkReplica :: DB.Connection -> Int64 -> IO ()
deleteDeletedSndChunkReplica db deletedSndChunkReplicaId =
  DB.execute
    db
    "DELETE FROM deleted_snd_chunk_replicas WHERE deleted_snd_chunk_replica_id = ?"
    (Only deletedSndChunkReplicaId)
|
|
|
|
-- | Lists the distinct XFTP servers referenced by @deleted_snd_chunk_replicas@
-- rows created no earlier than @ttl@ before the current time.
--
-- NOTE(review): this query has no @failed = 0@ condition, whereas the
-- selection query in 'getNextDeletedSndChunkReplica' skips failed rows -
-- confirm that is intended.
getPendingDelFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
getPendingDelFilesServers db ttl = do
  now <- getCurrentTime
  let cutoffTs = addUTCTime (negate ttl) now
  rows <- DB.query db pendingQuery (Only cutoffTs)
  pure $ map toXFTPServer rows
  where
    pendingQuery =
      [sql|
        SELECT DISTINCT
          s.xftp_host, s.xftp_port, s.xftp_key_hash
        FROM deleted_snd_chunk_replicas r
        JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id
        WHERE r.created_at >= ?
      |]
|
|
|
|
-- | Deletes every @deleted_snd_chunk_replicas@ row created earlier than @ttl@
-- before the current time.
deleteDeletedSndChunkReplicasExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteDeletedSndChunkReplicasExpired db ttl =
  getCurrentTime >>= \now ->
    DB.execute
      db
      "DELETE FROM deleted_snd_chunk_replicas WHERE created_at < ?"
      (Only (addUTCTime (negate ttl) now))
|
|
|
|
-- Template Haskell splice generating JSON instances for UpMigration from the
-- defaultJSON options.
$(J.deriveJSON defaultJSON ''UpMigration)

-- Template Haskell splice generating only the ToJSON instance for
-- MigrationError; the options are sumTypeJSON applied to dropPrefix "ME"
-- (both defined elsewhere; presumably this strips the "ME" constructor
-- prefix from the encoded tags - confirm).
$(J.deriveToJSON (sumTypeJSON (dropPrefix "ME")) ''MigrationError)
|