mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-11 03:56:59 +00:00
96 lines
4.0 KiB
Haskell
96 lines
4.0 KiB
Haskell
module Simplex.Messaging.Agent.Store.SQLite.Util
|
|
( SQLiteFunc,
|
|
SQLiteFuncFinal,
|
|
createStaticFunction,
|
|
createStaticAggregate,
|
|
mkSQLiteFunc,
|
|
) where
|
|
|
|
import Control.Exception (SomeException, catch, mask_)
|
|
import Data.ByteString (ByteString)
|
|
import qualified Data.ByteString as B
|
|
import Data.IORef
|
|
import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..))
|
|
import Database.SQLite3.Bindings
|
|
import Foreign.C.String
|
|
import Foreign.Ptr
|
|
import Foreign.StablePtr
|
|
import Foreign.Storable
|
|
|
|
data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal)
|
|
|
|
type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO ()
|
|
|
|
type SQLiteFuncFinal = Ptr CContext -> IO ()
|
|
|
|
mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc
|
|
mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals)
|
|
{-# INLINE mkSQLiteFunc #-}
|
|
|
|
-- Based on createFunction from Database.SQLite3.Direct, but uses static function pointer to avoid dynamic wrapper that triggers DCL.
|
|
createStaticFunction :: Database -> ByteString -> CArgCount -> Bool -> FunPtr SQLiteFunc -> IO (Either Error ())
|
|
createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do
|
|
u <- newStablePtr $ CFuncPtrs funPtr nullFunPtr nullFunPtr
|
|
let flags = if isDet then c_SQLITE_DETERMINISTIC else 0
|
|
B.useAsCString name $ \namePtr ->
|
|
toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr
|
|
|
|
mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a) -> SQLiteFunc
|
|
mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do
|
|
-- we store the aggregate state in the buffer returned by
|
|
-- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that
|
|
-- contains the actual aggregate state
|
|
aggCtx <- getAggregateContext cxt
|
|
aggStPtr <- peek aggCtx
|
|
aggStRef <-
|
|
if castStablePtrToPtr aggStPtr /= nullPtr
|
|
then deRefStablePtr aggStPtr
|
|
else do
|
|
aggStRef <- newIORef initSt
|
|
aggStPtr' <- newStablePtr aggStRef
|
|
poke aggCtx aggStPtr'
|
|
return aggStRef
|
|
aggSt <- readIORef aggStRef
|
|
aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt
|
|
writeIORef aggStRef aggSt'
|
|
|
|
mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO ()) -> SQLiteFuncFinal
|
|
mkSQLiteAggFinal initSt xFinal cxt = do
|
|
aggCtx <- getAggregateContext cxt
|
|
aggStPtr <- peek aggCtx
|
|
if castStablePtrToPtr aggStPtr == nullPtr
|
|
then catchAsResultError cxt $ xFinal (FuncContext cxt) initSt
|
|
else do
|
|
catchAsResultError cxt $ do
|
|
aggStRef <- deRefStablePtr aggStPtr
|
|
aggSt <- readIORef aggStRef
|
|
xFinal (FuncContext cxt) aggSt
|
|
freeStablePtr aggStPtr
|
|
|
|
getAggregateContext :: Ptr CContext -> IO (Ptr a)
|
|
getAggregateContext cxt = c_sqlite3_aggregate_context cxt stPtrSize
|
|
where
|
|
stPtrSize = fromIntegral $ sizeOf (undefined :: StablePtr ())
|
|
|
|
-- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL.
|
|
createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error ())
|
|
createStaticAggregate (Database db) name nArgs stepPtr finalPtr = mask_ $ do
|
|
u <- newStablePtr $ CFuncPtrs nullFunPtr stepPtr finalPtr
|
|
B.useAsCString name $ \namePtr ->
|
|
toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 (castStablePtrToPtr u) nullFunPtr stepPtr finalPtr nullFunPtr
|
|
|
|
-- Convert a 'CError' to a 'Either Error', in the common case where
|
|
-- SQLITE_OK signals success and anything else signals an error.
|
|
--
|
|
-- Note that SQLITE_OK == 0.
|
|
toResult :: a -> CError -> Either Error a
|
|
toResult a (CError 0) = Right a
|
|
toResult _ code = Left $ decodeError code
|
|
|
|
-- call c_sqlite3_result_error in the event of an error
|
|
catchAsResultError :: Ptr CContext -> IO () -> IO ()
|
|
catchAsResultError ctx action = catch action $ \exn -> do
|
|
let msg = show (exn :: SomeException)
|
|
withCAStringLen msg $ \(ptr, len) ->
|
|
c_sqlite3_result_error ctx ptr (fromIntegral len)
|