diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index d80432cd3..45a2389ea 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -31,6 +31,7 @@ module Simplex.Messaging.Agent.Store.SQLite createSQLiteStore, connectSQLiteStore, closeSQLiteStore, + openSQLiteStore, sqlString, execSQL, upMigration, -- used in tests @@ -269,13 +270,13 @@ import Simplex.Messaging.Parsers (blobFieldParser, dropPrefix, fromTextField_, s import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (bshow, eitherToMaybe, groupOn, ($>>=), (<$$>)) +import Simplex.Messaging.Util (bshow, eitherToMaybe, groupOn, ifM, ($>>=), (<$$>)) import Simplex.Messaging.Version import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.Exit (exitFailure) import System.FilePath (takeDirectory) import System.IO (hFlush, stdout) -import UnliftIO.Exception (onException) +import UnliftIO.Exception (onException, bracketOnError) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -379,10 +380,12 @@ confirmOrExit s = do connectSQLiteStore :: FilePath -> String -> IO SQLiteStore connectSQLiteStore dbFilePath dbKey = do dbNew <- not <$> doesFileExist dbFilePath - dbConn <- dbBusyLoop $ connectDB dbFilePath dbKey - dbConnVar <- newTMVarIO dbConn - dbEncrypted <- newTVarIO . not $ null dbKey - pure SQLiteStore {dbFilePath, dbEncrypted, dbConnection = dbConnVar, dbNew} + dbConn <- dbBusyLoop (connectDB dbFilePath dbKey) + atomically $ do + dbConnection <- newTMVar dbConn + dbEncrypted <- newTVar . not $ null dbKey + dbClosed <- newTVar False + pure SQLiteStore {dbFilePath, dbEncrypted, dbConnection, dbNew, dbClosed} connectDB :: FilePath -> String -> IO DB.Connection connectDB path key = do @@ -404,7 +407,25 @@ connectDB path key = do |] closeSQLiteStore :: SQLiteStore -> IO () -closeSQLiteStore st = atomically (takeTMVar $ dbConnection st) >>= DB.close +closeSQLiteStore st@SQLiteStore {dbClosed} = + ifM (readTVarIO dbClosed) (putStrLn "closeSQLiteStore: already closed") $ + withConnection st $ \conn -> do + DB.close conn + atomically $ writeTVar dbClosed True + +openSQLiteStore :: SQLiteStore -> String -> IO () +openSQLiteStore SQLiteStore {dbConnection, dbFilePath, dbClosed} key = + ifM (readTVarIO dbClosed) open (putStrLn "closeSQLiteStore: already opened") + where + open = + bracketOnError + (atomically $ takeTMVar dbConnection) + (atomically . tryPutTMVar dbConnection) + $ \DB.Connection {slow} -> do + DB.Connection {conn} <- connectDB dbFilePath key + atomically $ do + putTMVar dbConnection DB.Connection {conn, slow} + writeTVar dbClosed False sqlString :: String -> Text sqlString s = quote <> T.replace quote "''" (T.pack s) <> quote diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index ef2c688aa..0948afd08 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -27,6 +27,7 @@ data SQLiteStore = SQLiteStore { dbFilePath :: FilePath, dbEncrypted :: TVar Bool, dbConnection :: TMVar DB.Connection, + dbClosed :: TVar Bool, dbNew :: Bool } diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 4e000eab5..f235a3341 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -3,13 +3,11 @@ module Simplex.Messaging.Util where -import Control.Concurrent (threadDelay) import qualified Control.Exception as E import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift import Data.Bifunctor (first) -import qualified Data.ByteString as BW import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Int (Int64) @@ -21,7 +19,6 @@ import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8With) import Data.Time (NominalDiffTime) import GHC.Conc -import Numeric (showHex) import UnliftIO.Async import qualified UnliftIO.Exception as UE diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 9a266699d..a2c8e3929 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -6,6 +6,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -13,9 +14,11 @@ module AgentTests.SQLiteTests (storeTests) where import Control.Concurrent.Async (concurrently_) import Control.Concurrent.STM +import Control.Exception (SomeException) import Control.Monad (replicateM_) import Crypto.Random (drgNew) import Data.ByteString.Char8 (ByteString) +import Data.List (isInfixOf) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Time @@ -51,11 +54,14 @@ withStore2 = before connect2 . after (removeStore . fst) pure (s1, s2) createStore :: IO SQLiteStore -createStore = do +createStore = createEncryptedStore "" + +createEncryptedStore :: String -> IO SQLiteStore +createEncryptedStore key = do -- Randomize DB file name to avoid SQLite IO errors supposedly caused by asynchronous -- IO operations on multiple similarly named files; error seems to be environment specific r <- randomIO :: IO Word32 - Right st <- createSQLiteStore (testDB <> show r) "" Migrations.app MCError + Right st <- createSQLiteStore (testDB <> show r) key Migrations.app MCError pure st removeStore :: SQLiteStore -> IO () @@ -104,6 +110,9 @@ storeTests = do testCreateRcvMsg testCreateSndMsg testCreateRcvAndSndMsgs + describe "open/close store" $ do + it "should close and re-open" testCloseReopenStore + it "should close and re-open encrypted store" testCloseReopenEncryptedStore testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore) testConcurrentWrites = @@ -504,3 +513,43 @@ testCreateRcvAndSndMsgs = testCreateRcvMsg_ db 2 "rcv_hash_2" connId rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" testCreateSndMsg_ db "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" testCreateSndMsg_ db "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" + +testCloseReopenStore :: IO () +testCloseReopenStore = do + st <- createStore + hasMigrations st + closeSQLiteStore st + closeSQLiteStore st + errorGettingMigrations st + openSQLiteStore st "" + openSQLiteStore st "" + hasMigrations st + closeSQLiteStore st + errorGettingMigrations st + openSQLiteStore st "" + hasMigrations st + +testCloseReopenEncryptedStore :: IO () +testCloseReopenEncryptedStore = do + let key = "test_key" + st <- createEncryptedStore key + hasMigrations st + closeSQLiteStore st + closeSQLiteStore st + errorGettingMigrations st + openSQLiteStore st key + openSQLiteStore st key + hasMigrations st + closeSQLiteStore st + errorGettingMigrations st + openSQLiteStore st key + hasMigrations st + +getMigrations :: SQLiteStore -> IO Bool +getMigrations st = not . null <$> withTransaction st (Migrations.getCurrent . DB.conn) + +hasMigrations :: SQLiteStore -> Expectation +hasMigrations st = getMigrations st `shouldReturn` True + +errorGettingMigrations :: SQLiteStore -> Expectation +errorGettingMigrations st = getMigrations st `shouldThrow` \(e :: SomeException) -> "ErrorMisuse" `isInfixOf` show e