diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index beac334fb..e47d2a15c 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -287,6 +287,7 @@ import System.FilePath (takeDirectory) import System.IO (hFlush, stdout) import UnliftIO.Exception (bracketOnError, onException) import qualified UnliftIO.Exception as E +import UnliftIO.MVar import UnliftIO.STM -- * SQLite Store implementation @@ -382,8 +383,8 @@ connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> IO SQLiteStore connectSQLiteStore dbFilePath key keepKey = do dbNew <- not <$> doesFileExist dbFilePath dbConn <- dbBusyLoop (connectDB dbFilePath key) + dbConnection <- newMVar dbConn atomically $ do - dbConnection <- newTMVar dbConn dbKey <- newTVar $! storeKey key keepKey dbClosed <- newTVar False pure SQLiteStore {dbFilePath, dbKey, dbConnection, dbNew, dbClosed} @@ -421,14 +422,14 @@ openSQLiteStore st@SQLiteStore {dbClosed} key keepKey = openSQLiteStore_ :: SQLiteStore -> ScrubbedBytes -> Bool -> IO () openSQLiteStore_ SQLiteStore {dbConnection, dbFilePath, dbKey, dbClosed} key keepKey = bracketOnError - (atomically $ takeTMVar dbConnection) - (atomically . tryPutTMVar dbConnection) + (takeMVar dbConnection) + (tryPutMVar dbConnection) $ \DB.Connection {slow} -> do DB.Connection {conn} <- connectDB dbFilePath key atomically $ do - putTMVar dbConnection DB.Connection {conn, slow} writeTVar dbClosed False writeTVar dbKey $! storeKey key keepKey + putMVar dbConnection DB.Connection {conn, slow} reopenSQLiteStore :: SQLiteStore -> IO () reopenSQLiteStore st@SQLiteStore {dbKey, dbClosed} = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 18c16cc8b..b9a9bd501 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -22,8 +22,8 @@ import Database.SQLite.Simple (SQLError) import qualified Database.SQLite.Simple as SQL import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import Simplex.Messaging.Util (diffToMilliseconds) -import UnliftIO.Exception (bracket) import qualified UnliftIO.Exception as E +import UnliftIO.MVar import UnliftIO.STM storeKey :: ScrubbedBytes -> Bool -> Maybe ScrubbedBytes @@ -32,16 +32,13 @@ storeKey key keepKey = if keepKey || BA.null key then Just key else Nothing data SQLiteStore = SQLiteStore { dbFilePath :: FilePath, dbKey :: TVar (Maybe ScrubbedBytes), - dbConnection :: TMVar DB.Connection, + dbConnection :: MVar DB.Connection, dbClosed :: TVar Bool, dbNew :: Bool } withConnection :: SQLiteStore -> (DB.Connection -> IO a) -> IO a -withConnection SQLiteStore {dbConnection} = - bracket - (atomically $ takeTMVar dbConnection) - (atomically . putTMVar dbConnection) +withConnection SQLiteStore {dbConnection} = withMVar dbConnection withConnection' :: SQLiteStore -> (SQL.Connection -> IO a) -> IO a withConnection' st action = withConnection st $ action . DB.conn @@ -71,9 +68,9 @@ dbBusyLoop action = loop 500 3000000 loop :: Int -> Int -> IO a loop t tLim = action `E.catch` \(e :: SQLError) -> - let se = SQL.sqlError e in - if tLim > t && (se == SQL.ErrorBusy || se == SQL.ErrorLocked) - then do - threadDelay t - loop (t * 9 `div` 8) (tLim - t) - else E.throwIO e + let se = SQL.sqlError e + in if tLim > t && (se == SQL.ErrorBusy || se == SQL.ErrorLocked) + then do + threadDelay t + loop (t * 9 `div` 8) (tLim - t) + else E.throwIO e diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 4bac4fb83..436dd0eca 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -5,8 +5,8 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -17,6 +17,7 @@ module AgentTests.SQLiteTests (storeTests) where import AgentTests.EqInstances () import Control.Concurrent.Async (concurrently_) +import Control.Concurrent.MVar import Control.Concurrent.STM import Control.Exception (SomeException) import Control.Monad (replicateM_) @@ -45,9 +46,9 @@ import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQSupportOn) import qualified Simplex.Messaging.Crypto.Ratchet as CR -import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Encoding.String (StrEncoding (..)) import Simplex.Messaging.Protocol (SubscriptionMode (..), pattern VersionSMPC) import qualified Simplex.Messaging.Protocol as SMP @@ -88,7 +89,7 @@ removeStore db = do removeFile $ dbFilePath db where close :: SQLiteStore -> IO () - close st = mapM_ DB.close =<< atomically (tryTakeTMVar $ dbConnection st) + close st = mapM_ DB.close =<< tryTakeMVar (dbConnection st) storeTests :: Spec storeTests = do