mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-25 12:04:32 +00:00
agent: option to add SQLite functions to DB connection (#1674)
* agent: option to add SQLite functions to DB connection * add module
This commit is contained in:
@@ -42,6 +42,9 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.MVar
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (bracketOnError, onException, throwIO)
|
||||
import Control.Monad
|
||||
import Data.ByteArray (ScrubbedBytes)
|
||||
import qualified Data.ByteArray as BA
|
||||
@@ -58,21 +61,19 @@ import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSc
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Util
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..))
|
||||
import Simplex.Messaging.Util (ifM, safeDecodeUtf8)
|
||||
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
|
||||
import System.FilePath (takeDirectory, takeFileName, (</>))
|
||||
import UnliftIO.Exception (bracketOnError, onException)
|
||||
import UnliftIO.MVar
|
||||
import UnliftIO.STM
|
||||
|
||||
-- * SQLite Store implementation
|
||||
|
||||
createDBStore :: DBOpts -> [Migration] -> MigrationConfig -> IO (Either MigrationError DBStore)
|
||||
createDBStore opts@DBOpts {dbFilePath, dbKey, keepKey, track} migrations migrationConfig = do
|
||||
createDBStore opts@DBOpts {dbFilePath} migrations migrationConfig = do
|
||||
let dbDir = takeDirectory dbFilePath
|
||||
createDirectoryIfMissing True dbDir
|
||||
st <- connectSQLiteStore dbFilePath dbKey keepKey track
|
||||
st <- connectSQLiteStore opts
|
||||
r <- migrateDBSchema st opts Nothing migrations migrationConfig `onException` closeDBStore st
|
||||
case r of
|
||||
Right () -> pure $ Right st
|
||||
@@ -91,27 +92,27 @@ migrateDBSchema st DBOpts {dbFilePath, vacuum} migrationsTable migrations Migrat
|
||||
dbm = DBMigrate {initialize, getCurrent, run, backup}
|
||||
in sharedMigrateSchema dbm (dbNew st) migrations confirm
|
||||
|
||||
connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> DB.TrackQueries -> IO DBStore
|
||||
connectSQLiteStore dbFilePath key keepKey track = do
|
||||
connectSQLiteStore :: DBOpts -> IO DBStore
|
||||
connectSQLiteStore DBOpts {dbFilePath, dbFunctions, dbKey = key, keepKey, track} = do
|
||||
dbNew <- not <$> doesFileExist dbFilePath
|
||||
dbConn <- dbBusyLoop (connectDB dbFilePath key track)
|
||||
dbConn <- dbBusyLoop $ connectDB dbFilePath dbFunctions key track
|
||||
dbConnection <- newMVar dbConn
|
||||
dbKey <- newTVarIO $! storeKey key keepKey
|
||||
dbClosed <- newTVarIO False
|
||||
dbSem <- newTVarIO 0
|
||||
pure DBStore {dbFilePath, dbKey, dbSem, dbConnection, dbNew, dbClosed}
|
||||
pure DBStore {dbFilePath, dbFunctions, dbKey, dbSem, dbConnection, dbNew, dbClosed}
|
||||
|
||||
connectDB :: FilePath -> ScrubbedBytes -> DB.TrackQueries -> IO DB.Connection
|
||||
connectDB path key track = do
|
||||
connectDB :: FilePath -> [SQLiteFuncDef] -> ScrubbedBytes -> DB.TrackQueries -> IO DB.Connection
|
||||
connectDB path functions key track = do
|
||||
db <- DB.open path track
|
||||
prepare db `onException` DB.close db
|
||||
-- _printPragmas db path
|
||||
pure db
|
||||
where
|
||||
prepare db = do
|
||||
let exec = SQLite3.exec $ SQL.connectionHandle $ DB.conn db
|
||||
unless (BA.null key) . exec $ "PRAGMA key = " <> keyString key <> ";"
|
||||
exec . fromQuery $
|
||||
let db' = SQL.connectionHandle $ DB.conn db
|
||||
unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";"
|
||||
SQLite3.exec db' . fromQuery $
|
||||
[sql|
|
||||
PRAGMA busy_timeout = 100;
|
||||
PRAGMA foreign_keys = ON;
|
||||
@@ -119,6 +120,9 @@ connectDB path key track = do
|
||||
PRAGMA secure_delete = ON;
|
||||
PRAGMA auto_vacuum = FULL;
|
||||
|]
|
||||
forM_ functions $ \SQLiteFuncDef {funcName, argCount, deterministic, funcPtr} ->
|
||||
createStaticFunction db' funcName argCount deterministic funcPtr
|
||||
>>= either (throwIO . userError . show) pure
|
||||
|
||||
closeDBStore :: DBStore -> IO ()
|
||||
closeDBStore st@DBStore {dbClosed} =
|
||||
@@ -132,12 +136,12 @@ openSQLiteStore st@DBStore {dbClosed} key keepKey =
|
||||
ifM (readTVarIO dbClosed) (openSQLiteStore_ st key keepKey) (putStrLn "openSQLiteStore: already opened")
|
||||
|
||||
openSQLiteStore_ :: DBStore -> ScrubbedBytes -> Bool -> IO ()
|
||||
openSQLiteStore_ DBStore {dbConnection, dbFilePath, dbKey, dbClosed} key keepKey =
|
||||
openSQLiteStore_ DBStore {dbConnection, dbFilePath, dbFunctions, dbKey, dbClosed} key keepKey =
|
||||
bracketOnError
|
||||
(takeMVar dbConnection)
|
||||
(tryPutMVar dbConnection)
|
||||
$ \DB.Connection {slow, track} -> do
|
||||
DB.Connection {conn} <- connectDB dbFilePath key track
|
||||
DB.Connection {conn} <- connectDB dbFilePath dbFunctions key track
|
||||
atomically $ do
|
||||
writeTVar dbClosed False
|
||||
writeTVar dbKey $! storeKey key keepKey
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Common
|
||||
( DBStore (..),
|
||||
DBOpts (..),
|
||||
SQLiteFuncDef (..),
|
||||
withConnection,
|
||||
withConnection',
|
||||
withTransaction,
|
||||
@@ -20,9 +21,13 @@ import Control.Concurrent (threadDelay)
|
||||
import Control.Concurrent.STM (retry)
|
||||
import Data.ByteArray (ScrubbedBytes)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString (ByteString)
|
||||
import Database.SQLite.Simple (SQLError)
|
||||
import qualified Database.SQLite.Simple as SQL
|
||||
import Database.SQLite3.Bindings
|
||||
import Foreign.Ptr
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Util
|
||||
import Simplex.Messaging.Util (ifM, unlessM)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.MVar
|
||||
@@ -33,6 +38,7 @@ storeKey key keepKey = if keepKey || BA.null key then Just key else Nothing
|
||||
|
||||
data DBStore = DBStore
|
||||
{ dbFilePath :: FilePath,
|
||||
dbFunctions :: [SQLiteFuncDef],
|
||||
dbKey :: TVar (Maybe ScrubbedBytes),
|
||||
dbSem :: TVar Int,
|
||||
dbConnection :: MVar DB.Connection,
|
||||
@@ -42,12 +48,21 @@ data DBStore = DBStore
|
||||
|
||||
data DBOpts = DBOpts
|
||||
{ dbFilePath :: FilePath,
|
||||
dbFunctions :: [SQLiteFuncDef],
|
||||
dbKey :: ScrubbedBytes,
|
||||
keepKey :: Bool,
|
||||
vacuum :: Bool,
|
||||
track :: DB.TrackQueries
|
||||
}
|
||||
|
||||
-- e.g. `SQLiteFuncDef "name" 2 True f`
|
||||
data SQLiteFuncDef = SQLiteFuncDef
|
||||
{ funcName :: ByteString,
|
||||
argCount :: CArgCount,
|
||||
deterministic :: Bool,
|
||||
funcPtr :: FunPtr SQLiteFunc
|
||||
}
|
||||
|
||||
withConnectionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a
|
||||
withConnectionPriority DBStore {dbSem, dbConnection} priority action
|
||||
| priority = E.bracket_ signal release $ withMVar dbConnection action
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Util where
|
||||
|
||||
import Control.Exception (SomeException, catch, mask_)
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..))
|
||||
import Database.SQLite3.Bindings
|
||||
import Foreign.C.String
|
||||
import Foreign.Ptr
|
||||
import Foreign.StablePtr
|
||||
|
||||
data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal)
|
||||
|
||||
type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO ()
|
||||
|
||||
mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc
|
||||
mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals)
|
||||
{-# INLINE mkSQLiteFunc #-}
|
||||
|
||||
-- Based on createFunction from Database.SQLite3.Direct, but uses static function pointer to avoid dynamic wrapper that triggers DCL.
|
||||
createStaticFunction :: Database -> ByteString -> CArgCount -> Bool -> FunPtr SQLiteFunc -> IO (Either Error ())
|
||||
createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do
|
||||
u <- newStablePtr $ CFuncPtrs funPtr nullFunPtr nullFunPtr
|
||||
let flags = if isDet then c_SQLITE_DETERMINISTIC else 0
|
||||
B.useAsCString name $ \namePtr ->
|
||||
toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr
|
||||
|
||||
-- Convert a 'CError' to a 'Either Error', in the common case where
|
||||
-- SQLITE_OK signals success and anything else signals an error.
|
||||
--
|
||||
-- Note that SQLITE_OK == 0.
|
||||
toResult :: a -> CError -> Either Error a
|
||||
toResult a (CError 0) = Right a
|
||||
toResult _ code = Left $ decodeError code
|
||||
|
||||
-- call c_sqlite3_result_error in the event of an error
|
||||
catchAsResultError :: Ptr CContext -> IO () -> IO ()
|
||||
catchAsResultError ctx action = catch action $ \exn -> do
|
||||
let msg = show (exn :: SomeException)
|
||||
withCAStringLen msg $ \(ptr, len) ->
|
||||
c_sqlite3_result_error ctx ptr (fromIntegral len)
|
||||
Reference in New Issue
Block a user