Files
simplexmq/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs
T
2026-03-03 21:16:46 +00:00

96 lines
4.0 KiB
Haskell

-- | Helpers for registering custom SQLite scalar functions and aggregates
-- through static function pointers.
module Simplex.Messaging.Agent.Store.SQLite.Util
  ( SQLiteFunc,
    SQLiteFuncFinal,
    createStaticFunction,
    createStaticAggregate,
    mkSQLiteFunc,
  )
where
import Control.Exception (SomeException, catch, mask_)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.IORef
import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..))
import Database.SQLite3.Bindings
import Foreign.C.String
import Foreign.Ptr
import Foreign.StablePtr
import Foreign.Storable
-- | Triple of C callback pointers, in the order: scalar function,
-- aggregate step, aggregate finalizer.
data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal)
-- | Raw callback shape used for scalar functions and aggregate steps:
-- the function context, the argument count and the argument array.
type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO ()
-- | Raw callback shape used for aggregate finalizers: only the function
-- context is passed.
type SQLiteFuncFinal = Ptr CContext -> IO ()
-- | Adapt a handler working on 'FuncContext' and 'FuncArgs' to the raw
-- 'SQLiteFunc' callback shape. The handler runs under 'catchAsResultError',
-- so an exception it raises is reported as the SQL function's error result.
mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc
mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt handler
  where
    -- the wrapped context and arguments are built from the raw pointers as-is
    handler = f (FuncContext cxt) (FuncArgs nArgs cvals)
{-# INLINE mkSQLiteFunc #-}
-- | Register a scalar SQL function backed by a static function pointer.
--
-- Based on createFunction from Database.SQLite3.Direct, but uses static function pointer to avoid dynamic wrapper that triggers DCL.
--
-- Arguments: the database handle, the SQL name of the function, the number
-- of SQL arguments, whether to set the SQLITE_DETERMINISTIC flag, and the
-- callback pointer. Returns @Right ()@ when SQLite reports success and the
-- decoded error otherwise.
createStaticFunction :: Database -> ByteString -> CArgCount -> Bool -> FunPtr SQLiteFunc -> IO (Either Error ())
createStaticFunction (Database db) name nArgs isDet funPtr =
  -- mask_ is kept from the original so registration still runs with
  -- asynchronous exceptions masked.
  mask_ $
    B.useAsCString name $ \namePtr ->
      -- No user data is attached (nullPtr). A StablePtr passed here could never
      -- be released, because no destroy callback is registered (the last
      -- argument is nullFunPtr), so one would leak on every registration.
      toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags nullPtr funPtr nullFunPtr nullFunPtr nullFunPtr
  where
    flags = if isDet then c_SQLITE_DETERMINISTIC else 0
-- | Build the raw step callback of an aggregate from an initial state and a
-- step function that maps the current state to the next one.
--
-- The aggregate state is stored in the buffer returned by
-- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that
-- contains the actual aggregate state. The StablePtr is created on the first
-- step and released by 'mkSQLiteAggFinal'.
--
-- Exceptions raised here or by the step function are reported as the SQL
-- function's error result via 'catchAsResultError'.
mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a) -> SQLiteFunc
mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do
  aggCtx <- getAggregateContext cxt
  aggStRef <-
    if aggCtx == nullPtr
      -- sqlite3_aggregate_context returns NULL when it cannot allocate the
      -- buffer; reading through it would crash, so report an error instead.
      then ioError $ userError "sqlite3_aggregate_context: out of memory"
      else do
        aggStPtr <- peek aggCtx
        if castStablePtrToPtr aggStPtr /= nullPtr
          then deRefStablePtr aggStPtr
          else do
            -- first step: the buffer is still zeroed, so create the state
            newRef <- newIORef initSt
            newStablePtr newRef >>= poke aggCtx
            pure newRef
  aggSt <- readIORef aggStRef
  aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt
  writeIORef aggStRef aggSt'
-- | Build the raw finalizer callback of an aggregate from the initial state
-- and a function that produces the result from the final state.
--
-- When no state was stored by a step callback, the finalizer runs on the
-- initial state. Otherwise it runs on the stored state and then frees the
-- StablePtr that held it. Exceptions raised by the finalizer are reported as
-- the SQL function's error result via 'catchAsResultError'.
mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO ()) -> SQLiteFuncFinal
mkSQLiteAggFinal initSt xFinal cxt = do
  aggCtx <- getAggregateContext cxt
  if aggCtx == nullPtr
    -- sqlite3_aggregate_context returns NULL when it cannot allocate the
    -- buffer. That can only happen when no step stored a state, so there is
    -- nothing to read or free: finalize on the initial state.
    then finalizeWith initSt
    else do
      aggStPtr <- peek aggCtx
      if castStablePtrToPtr aggStPtr == nullPtr
        then finalizeWith initSt
        else do
          catchAsResultError cxt $ do
            aggStRef <- deRefStablePtr aggStPtr
            aggSt <- readIORef aggStRef
            xFinal (FuncContext cxt) aggSt
          -- runs whether or not the finalizer failed: the error was already
          -- converted into an SQL result above
          freeStablePtr aggStPtr
  where
    finalizeWith st = catchAsResultError cxt $ xFinal (FuncContext cxt) st
-- | Ask SQLite for the aggregate context buffer of this invocation, sized to
-- hold exactly one 'StablePtr'.
getAggregateContext :: Ptr CContext -> IO (Ptr a)
getAggregateContext cxt =
  c_sqlite3_aggregate_context cxt . fromIntegral $ sizeOf (undefined :: StablePtr ())
-- | Register an SQL aggregate backed by static step and finalizer pointers.
--
-- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL.
--
-- Arguments: the database handle, the SQL name of the aggregate, the number
-- of SQL arguments, and the step and finalizer callback pointers. Returns
-- @Right ()@ when SQLite reports success and the decoded error otherwise.
createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error ())
createStaticAggregate (Database db) name nArgs stepPtr finalPtr =
  -- mask_ is kept from the original so registration still runs with
  -- asynchronous exceptions masked.
  mask_ $
    B.useAsCString name $ \namePtr ->
      -- No user data is attached (nullPtr). A StablePtr passed here could never
      -- be released, because no destroy callback is registered (the last
      -- argument is nullFunPtr), so one would leak on every registration.
      toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 nullPtr nullFunPtr stepPtr finalPtr nullFunPtr
-- | Turn a 'CError' status into an 'Either': the code 0 (SQLITE_OK) yields
-- the supplied success value, any other code yields the decoded 'Error'.
toResult :: a -> CError -> Either Error a
toResult value code = case code of
  CError 0 -> Right value
  _ -> Left (decodeError code)
-- | Run an action inside an SQL function callback. If it throws, the shown
-- exception is passed to c_sqlite3_result_error for the given context
-- instead of propagating.
catchAsResultError :: Ptr CContext -> IO () -> IO ()
catchAsResultError ctx action = action `catch` reportError
  where
    reportError :: SomeException -> IO ()
    reportError exn =
      withCAStringLen (show exn) $ \(msgPtr, msgLen) ->
        c_sqlite3_result_error ctx msgPtr (fromIntegral msgLen)