diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 01b06da82..9fe1ba7ad 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,5 +30,10 @@ jobs: path: ~/.stack key: ${{ runner.os }}-${{ hashFiles('stack.yaml') }} + - name: Log SQLite default threading mode + run: | + sqlite3 test.db "pragma COMPILE_OPTIONS;" | grep THREADSAFE + rm test.db + - name: Build and run tests run: stack build --test diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index d5592d70e..a10e19838 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -18,9 +18,12 @@ module Simplex.Messaging.Agent.Store.SQLite ) where +import Control.Monad (when) import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO)) import Control.Monad.IO.Unlift (MonadUnliftIO) +import Data.List (find) import Data.Maybe (fromMaybe) +import Data.Text (isPrefixOf) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) import Database.SQLite.Simple as DB @@ -36,6 +39,7 @@ import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Protocol (MsgBody) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (liftIOEither) +import System.Exit (ExitCode (ExitFailure), exitWith) import Text.Read (readMaybe) import qualified UnliftIO.Exception as E @@ -49,6 +53,15 @@ data SQLiteStore = SQLiteStore createSQLiteStore :: MonadUnliftIO m => String -> m SQLiteStore createSQLiteStore dbFilename = do store <- connectSQLiteStore dbFilename + compileOptions <- liftIO (DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]) + let threadsafeOption = find (isPrefixOf "THREADSAFE=") (concat compileOptions) + liftIO $ case threadsafeOption of + Just "THREADSAFE=0" -> do + putStrLn "SQLite compiled with not threadsafe code, continue (y/n):" + s <- getLine + when (s /= "y") (exitWith $ ExitFailure 2) + Nothing -> putStrLn "Warning: SQLite THREADSAFE compile option not found" + _ -> return () liftIO . createSchema $ dbConn store return store diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 1f168191f..0840da3fd 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -7,6 +7,7 @@ module AgentTests.SQLiteTests (storeTests) where import Control.Monad.Except (ExceptT, runExceptT) import qualified Crypto.PubKey.RSA as R +import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Time import Data.Word (Word32) @@ -47,6 +48,7 @@ action `throwsError` e = runExceptT action `shouldReturn` Left e -- TODO add null port tests storeTests :: Spec storeTests = withStore do + describe "compiled as threadsafe" testCompiledThreadsafe describe "foreign keys enabled" testForeignKeysEnabled describe "store methods" do describe "createRcvConn" testCreateRcvConn @@ -71,6 +73,12 @@ storeTests = withStore do describe "SndQueue exists" testCreateSndMsg describe "SndQueue doesn't exist" testCreateSndMsgNoQueue +testCompiledThreadsafe :: SpecWith SQLiteStore +testCompiledThreadsafe = do + it "should throw error if compiled sqlite library is not threadsafe" $ \store -> do + compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] + compileOptions `shouldNotContain` [["THREADSAFE=0"]] + testForeignKeysEnabled :: SpecWith SQLiteStore testForeignKeysEnabled = do it "should throw error if foreign keys are enabled" $ \store -> do diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index ac55ead12..23b65d25b 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -135,7 +135,7 @@ withSmpAgent = withSmpAgentOn (agentTestPort, testDB) testSMPAgentClientOn :: MonadUnliftIO m => ServiceName -> (Handle -> m a) -> m a testSMPAgentClientOn port' client = do - threadDelay 250_000 -- TODO hack: thread delay for SMP agent to start + threadDelay 500_000 -- TODO hack: thread delay for SMP agent to start runTCPClient agentTestHost port' $ \h -> do line <- liftIO $ getLn h if line == "Welcome to SMP v0.2.0 agent" diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 99f6004f2..0cde0e740 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -25,7 +25,7 @@ testPort = "5000" testSMPClient :: MonadUnliftIO m => (Handle -> m a) -> m a testSMPClient client = do - threadDelay 50_000 -- TODO hack: thread delay for SMP server to start + threadDelay 250_000 -- TODO hack: thread delay for SMP server to start runTCPClient testHost testPort $ \h -> do line <- liftIO $ getLn h if line == "Welcome to SMP v0.2.0"