diff --git a/apps/smp-agent/Main.hs b/apps/smp-agent/Main.hs index 58c4cdd40..e3b269712 100644 --- a/apps/smp-agent/Main.hs +++ b/apps/smp-agent/Main.hs @@ -5,6 +5,7 @@ module Main where import Control.Logger.Simple +import Data.ByteArray (ScrubbedBytes) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M import Simplex.Messaging.Agent.Env.SQLite @@ -19,7 +20,7 @@ cfg = defaultAgentConfig agentDbFile :: String agentDbFile = "smp-agent.db" -agentDbKey :: String +agentDbKey :: ScrubbedBytes agentDbKey = "" servers :: InitialAgentServers @@ -38,5 +39,5 @@ main :: IO () main = do putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig) setLogLevel LogInfo -- LogError - Right st <- createAgentStore agentDbFile agentDbKey MCConsole + Right st <- createAgentStore agentDbFile agentDbKey False MCConsole withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers st diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index e8769b749..336d82fbf 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -849,23 +849,21 @@ getNotificationMessage' c nonce encNtfInfo = do (ntfConnId, rcvNtfDhSecret) <- withStore c (`getNtfRcvQueue` smpQueue) ntfMsgMeta <- (eitherToMaybe . smpDecode <$> agentCbDecrypt rcvNtfDhSecret nmsgNonce encNMsgMeta) `catchAgentError` \_ -> pure Nothing maxMsgs <- asks $ ntfMaxMessages . config - (NotificationInfo {ntfConnId, ntfTs, ntfMsgMeta},) <$> getNtfMessages ntfConnId maxMsgs ntfMsgMeta [] + (NotificationInfo {ntfConnId, ntfTs, ntfMsgMeta},) <$> getNtfMessages ntfConnId ntfMsgMeta maxMsgs _ -> throwError $ CMD PROHIBITED where - getNtfMessages ntfConnId maxMs nMeta ms - | length ms < maxMs = - getConnectionMessage' c ntfConnId >>= \case - Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of - Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} - | msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms) - | otherwise -> getMsg (m : ms) - _ - | SMP.notification msgFlags -> pure $ reverse (m : ms) - | otherwise -> getMsg (m : ms) - _ -> pure $ reverse ms - | otherwise = pure $ reverse ms + getNtfMessages ntfConnId nMeta = getMsg where - getMsg = getNtfMessages ntfConnId maxMs nMeta + getMsg 0 = pure [] + getMsg n = + getConnectionMessage' c ntfConnId >>= \case + Just m + | lastMsg m -> pure [m] + | otherwise -> (m :) <$> getMsg (n - 1) + Nothing -> pure [] + lastMsg SMP.SMPMsgMeta {msgId, msgTs, msgFlags} = case nMeta of + Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} -> msgId == msgId' || msgTs > msgTs' + Nothing -> SMP.notification msgFlags -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId @@ -1887,11 +1885,14 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s conn cData@ConnData {userId, connId, duplexHandshake, connAgentVersion, ratchetSyncState = rss} = withConnLock c connId "processSMP" $ case cmd of - SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> - handleNotifyAck $ - decryptSMPMessage v rq msg >>= \case + SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> + handleNotifyAck $ do + msg' <- decryptSMPMessage v rq msg + handleNotifyAck $ case msg' of SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody SMP.ClientRcvMsgQuota {} -> queueDrained >> ack + whenM (atomically $ hasGetLock c rq) $ + notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') where queueDrained = case conn of DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 253be2811..18eb3d642 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -72,6 +72,7 @@ module Simplex.Messaging.Agent.Client logSecret, removeSubscription, hasActiveSubscription, + hasGetLock, agentClientStore, agentDRG, getAgentSubscriptions, @@ -897,8 +898,8 @@ subscribeQueues c qs = do -- only "checked" queues are subscribed (errs <>) <$> sendTSessionBatches "SUB" 90 id (subscribeQueues_ u) c qs' where - checkQueue rq@RcvQueue {rcvId, server} = do - prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c + checkQueue rq = do + prohibited <- atomically $ hasGetLock c rq pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED) else Right rq subscribeQueues_ :: UnliftIO m -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ()) subscribeQueues_ u smp qs' = do @@ -1049,6 +1050,10 @@ sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do ackSMPMessage smp rcvPrivateKey rcvId msgId atomically $ releaseGetLock c rq +hasGetLock :: AgentClient -> RcvQueue -> STM Bool +hasGetLock c RcvQueue {server, rcvId} = + TM.member (server, rcvId) $ getMsgLocks c + releaseGetLock :: AgentClient -> RcvQueue -> STM () releaseGetLock c RcvQueue {server, rcvId} = TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ()) diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 6f2347391..e47de1c46 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -34,6 +34,7 @@ import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random +import Data.ByteArray (ScrubbedBytes) import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) import Data.Map (Map) @@ -163,7 +164,7 @@ defaultAgentConfig = ntfWorkerDelay = 100000, -- microseconds ntfSMPWorkerDelay = 500000, -- microseconds ntfSubCheckInterval = nominalDay, - ntfMaxMessages = 4, + ntfMaxMessages = 3, -- CA certificate private key is not needed for initialization -- ! we do not generate these caCertificateFile = "/etc/opt/simplex-agent/ca.crt", @@ -196,8 +197,8 @@ newSMPAgentEnv config@AgentConfig {initialClientId} store = do multicastSubscribers <- newTMVarIO 0 pure Env {config, store, random, clientCounter, randomServer, ntfSupervisor, xftpAgent, multicastSubscribers} -createAgentStore :: FilePath -> String -> MigrationConfirmation -> IO (Either MigrationError SQLiteStore) -createAgentStore dbFilePath dbKey = createSQLiteStore dbFilePath dbKey Migrations.app +createAgentStore :: FilePath -> ScrubbedBytes -> Bool -> MigrationConfirmation -> IO (Either MigrationError SQLiteStore) +createAgentStore dbFilePath dbKey keepKey = createSQLiteStore dbFilePath dbKey keepKey Migrations.app data NtfSupervisor = NtfSupervisor { ntfTkn :: TVar (Maybe NtfToken), diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index efe75e2f5..c79df9a70 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -193,6 +193,7 @@ import Simplex.Messaging.Protocol MsgId, NMsgMeta, ProtocolServer (..), + SMPMsgMeta, SMPServer, SMPServerWithAuth, SndPublicVerifyKey, @@ -337,6 +338,7 @@ data ACommand (p :: AParty) (e :: AEntity) where SENT :: AgentMsgId -> ACommand Agent AEConn MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent AEConn MSG :: MsgMeta -> MsgFlags -> MsgBody -> ACommand Agent AEConn + MSGNTF :: SMPMsgMeta -> ACommand Agent AEConn ACK :: AgentMsgId -> Maybe MsgReceiptInfo -> ACommand Client AEConn RCVD :: MsgMeta -> NonEmpty MsgReceipt -> ACommand Agent AEConn SWCH :: ACommand Client AEConn @@ -397,6 +399,7 @@ data ACommandTag (p :: AParty) (e :: AEntity) where SENT_ :: ACommandTag Agent AEConn MERR_ :: ACommandTag Agent AEConn MSG_ :: ACommandTag Agent AEConn + MSGNTF_ :: ACommandTag Agent AEConn ACK_ :: ACommandTag Client AEConn RCVD_ :: ACommandTag Agent AEConn SWCH_ :: ACommandTag Client AEConn @@ -450,6 +453,7 @@ aCommandTag = \case SENT _ -> SENT_ MERR {} -> MERR_ MSG {} -> MSG_ + MSGNTF {} -> MSGNTF_ ACK {} -> ACK_ RCVD {} -> RCVD_ SWCH -> SWCH_ @@ -1604,6 +1608,7 @@ instance StrEncoding ACmdTag where "SENT" -> ct SENT_ "MERR" -> ct MERR_ "MSG" -> ct MSG_ + "MSGNTF" -> ct MSGNTF_ "ACK" -> t ACK_ "RCVD" -> ct RCVD_ "SWCH" -> t SWCH_ @@ -1659,6 +1664,7 @@ instance (APartyI p, AEntityI e) => StrEncoding (ACommandTag p e) where SENT_ -> "SENT" MERR_ -> "MERR" MSG_ -> "MSG" + MSGNTF_ -> "MSGNTF" ACK_ -> "ACK" RCVD_ -> "RCVD" SWCH_ -> "SWCH" @@ -1727,6 +1733,7 @@ commandP binaryP = SENT_ -> s (SENT <$> A.decimal) MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP) MSG_ -> s (MSG <$> strP <* A.space <*> smpP <* A.space <*> binaryP) + MSGNTF_ -> s (MSGNTF <$> strP) RCVD_ -> s (RCVD <$> strP <* A.space <*> strP) DEL_RCVQ_ -> s (DEL_RCVQ <$> strP_ <*> strP_ <*> strP) DEL_CONN_ -> pure DEL_CONN @@ -1781,6 +1788,7 @@ serializeCommand = \case SENT mId -> s (SENT_, Str $ bshow mId) MERR mId e -> s (MERR_, Str $ bshow mId, e) MSG msgMeta msgFlags msgBody -> B.unwords [s MSG_, s msgMeta, smpEncode msgFlags, serializeBinary msgBody] + MSGNTF smpMsgMeta -> s (MSGNTF_, smpMsgMeta) ACK mId rcptInfo_ -> s (ACK_, Str $ bshow mId) <> maybe "" (B.cons ' ' . serializeBinary) rcptInfo_ RCVD msgMeta rcpts -> s (RCVD_, msgMeta, rcpts) SWCH -> s SWCH_ diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index e37cd2167..f1388489f 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -31,7 +31,10 @@ module Simplex.Messaging.Agent.Store.SQLite connectSQLiteStore, closeSQLiteStore, openSQLiteStore, + reopenSQLiteStore, sqlString, + keyString, + storeKey, execSQL, upMigration, -- used in tests @@ -221,6 +224,8 @@ import Crypto.Random (ChaChaDRG) import qualified Data.Aeson.TH as J import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (second) +import Data.ByteArray (ScrubbedBytes) +import qualified Data.ByteArray as BA import Data.ByteString (ByteString) import qualified Data.ByteString.Base64.URL as U import Data.Char (toLower) @@ -267,7 +272,7 @@ import Simplex.Messaging.Parsers (blobFieldParser, defaultJSON, dropPrefix, from import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (bshow, eitherToMaybe, groupOn, ifM, ($>>=), (<$$>)) +import Simplex.Messaging.Util (bshow, eitherToMaybe, groupOn, ifM, safeDecodeUtf8, ($>>=), (<$$>)) import Simplex.Messaging.Version import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.Exit (exitFailure) @@ -316,11 +321,11 @@ instance StrEncoding MigrationConfirmation where "error" -> pure MCError _ -> fail "invalid MigrationConfirmation" -createSQLiteStore :: FilePath -> String -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError SQLiteStore) -createSQLiteStore dbFilePath dbKey migrations confirmMigrations = do +createSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError SQLiteStore) +createSQLiteStore dbFilePath dbKey keepKey migrations confirmMigrations = do let dbDir = takeDirectory dbFilePath createDirectoryIfMissing True dbDir - st <- connectSQLiteStore dbFilePath dbKey + st <- connectSQLiteStore dbFilePath dbKey keepKey r <- migrateSchema st migrations confirmMigrations `onException` closeSQLiteStore st case r of Right () -> pure $ Right st @@ -366,17 +371,17 @@ confirmOrExit s = do ok <- getLine when (map toLower ok /= "y") exitFailure -connectSQLiteStore :: FilePath -> String -> IO SQLiteStore -connectSQLiteStore dbFilePath dbKey = do +connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> IO SQLiteStore +connectSQLiteStore dbFilePath key keepKey = do dbNew <- not <$> doesFileExist dbFilePath - dbConn <- dbBusyLoop (connectDB dbFilePath dbKey) + dbConn <- dbBusyLoop (connectDB dbFilePath key) atomically $ do dbConnection <- newTMVar dbConn - dbEncrypted <- newTVar . not $ null dbKey + dbKey <- newTVar $! storeKey key keepKey dbClosed <- newTVar False - pure SQLiteStore {dbFilePath, dbEncrypted, dbConnection, dbNew, dbClosed} + pure SQLiteStore {dbFilePath, dbKey, dbConnection, dbNew, dbClosed} -connectDB :: FilePath -> String -> IO DB.Connection +connectDB :: FilePath -> ScrubbedBytes -> IO DB.Connection connectDB path key = do db <- DB.open path prepare db `onException` DB.close db @@ -385,7 +390,7 @@ connectDB path key = do where prepare db = do let exec = SQLite3.exec $ SQL.connectionHandle $ DB.conn db - unless (null key) . exec $ "PRAGMA key = " <> sqlString key <> ";" + unless (BA.null key) . exec $ "PRAGMA key = " <> keyString key <> ";" exec . fromQuery $ [sql| PRAGMA busy_timeout = 100; @@ -402,22 +407,36 @@ closeSQLiteStore st@SQLiteStore {dbClosed} = DB.close conn atomically $ writeTVar dbClosed True -openSQLiteStore :: SQLiteStore -> String -> IO () -openSQLiteStore SQLiteStore {dbConnection, dbFilePath, dbClosed} key = - ifM (readTVarIO dbClosed) open (putStrLn "closeSQLiteStore: already opened") +openSQLiteStore :: SQLiteStore -> ScrubbedBytes -> Bool -> IO () +openSQLiteStore st@SQLiteStore {dbClosed} key keepKey = + ifM (readTVarIO dbClosed) (openSQLiteStore_ st key keepKey) (putStrLn "openSQLiteStore: already opened") + +openSQLiteStore_ :: SQLiteStore -> ScrubbedBytes -> Bool -> IO () +openSQLiteStore_ SQLiteStore {dbConnection, dbFilePath, dbKey, dbClosed} key keepKey = + bracketOnError + (atomically $ takeTMVar dbConnection) + (atomically . tryPutTMVar dbConnection) + $ \DB.Connection {slow} -> do + DB.Connection {conn} <- connectDB dbFilePath key + atomically $ do + putTMVar dbConnection DB.Connection {conn, slow} + writeTVar dbClosed False + writeTVar dbKey $! storeKey key keepKey + +reopenSQLiteStore :: SQLiteStore -> IO () +reopenSQLiteStore st@SQLiteStore {dbKey, dbClosed} = + ifM (readTVarIO dbClosed) open (putStrLn "reopenSQLiteStore: already opened") where open = - bracketOnError - (atomically $ takeTMVar dbConnection) - (atomically . tryPutTMVar dbConnection) - $ \DB.Connection {slow} -> do - DB.Connection {conn} <- connectDB dbFilePath key - atomically $ do - putTMVar dbConnection DB.Connection {conn, slow} - writeTVar dbClosed False + readTVarIO dbKey >>= \case + Just key -> openSQLiteStore_ st key True + Nothing -> fail "reopenSQLiteStore: no key" -sqlString :: String -> Text -sqlString s = quote <> T.replace quote "''" (T.pack s) <> quote +keyString :: ScrubbedBytes -> Text +keyString = sqlString . safeDecodeUtf8 . BA.convert + +sqlString :: Text -> Text +sqlString s = quote <> T.replace quote "''" s <> quote where quote = "'" diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 0948afd08..9ba6cd08f 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -10,10 +10,13 @@ module Simplex.Messaging.Agent.Store.SQLite.Common withTransaction', withTransactionCtx, dbBusyLoop, + storeKey, ) where import Control.Concurrent (threadDelay) +import Data.ByteArray (ScrubbedBytes) +import qualified Data.ByteArray as BA import Data.Time.Clock (diffUTCTime, getCurrentTime) import Database.SQLite.Simple (SQLError) import qualified Database.SQLite.Simple as SQL @@ -23,9 +26,12 @@ import UnliftIO.Exception (bracket) import qualified UnliftIO.Exception as E import UnliftIO.STM +storeKey :: ScrubbedBytes -> Bool -> Maybe ScrubbedBytes +storeKey key keepKey = if keepKey || BA.null key then Just key else Nothing + data SQLiteStore = SQLiteStore { dbFilePath :: FilePath, - dbEncrypted :: TVar Bool, + dbKey :: TVar (Maybe ScrubbedBytes), dbConnection :: TMVar DB.Connection, dbClosed :: TVar Bool, dbNew :: Bool diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 0563cfeb5..e65385cba 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -460,7 +460,14 @@ data SMPMsgMeta = SMPMsgMeta msgTs :: SystemTime, msgFlags :: MsgFlags } - deriving (Show) + deriving (Eq, Show) + +instance StrEncoding SMPMsgMeta where + strEncode SMPMsgMeta {msgId, msgTs, msgFlags} = + strEncode (msgId, msgTs, msgFlags) + strP = do + (msgId, msgTs, msgFlags) <- strP + pure SMPMsgMeta {msgId, msgTs, msgFlags} rcvMessageMeta :: MsgId -> ClientRcvMsgBody -> SMPMsgMeta rcvMessageMeta msgId = \case diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 8c37579fc..7dec405ba 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -1010,34 +1010,35 @@ testOnlyCreatePull :: IO () testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMOnlyCreate aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate - getMsg alice bobId - Just ("", _, CONF confId _ "bob's connInfo") <- timeout 5_000000 $ get alice + Just ("", _, CONF confId _ "bob's connInfo") <- getMsg alice bobId $ timeout 5_000000 $ get alice allowConnection alice bobId confId "alice's connInfo" liftIO $ threadDelay 1_000000 - getMsg bob aliceId - get bob ##> ("", aliceId, INFO "alice's connInfo") + getMsg bob aliceId $ + get bob ##> ("", aliceId, INFO "alice's connInfo") liftIO $ threadDelay 1_000000 - getMsg alice bobId + getMsg alice bobId $ pure () get alice ##> ("", bobId, CON) - getMsg bob aliceId - get bob ##> ("", aliceId, CON) + getMsg bob aliceId $ + get bob ##> ("", aliceId, CON) -- exchange messages 4 <- sendMessage alice bobId SMP.noMsgFlags "hello" get alice ##> ("", bobId, SENT 4) - getMsg bob aliceId - get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + getMsg bob aliceId $ + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False ackMessage bob aliceId 4 Nothing 5 <- sendMessage bob aliceId SMP.noMsgFlags "hello too" get bob ##> ("", aliceId, SENT 5) - getMsg alice bobId - get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + getMsg alice bobId $ + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False ackMessage alice bobId 5 Nothing where - getMsg :: AgentClient -> ConnId -> ExceptT AgentErrorType IO () - getMsg c cId = do + getMsg :: AgentClient -> ConnId -> ExceptT AgentErrorType IO a -> ExceptT AgentErrorType IO a + getMsg c cId action = do liftIO $ noMessages c "nothing should be delivered before GET" Just _ <- getConnectionMessage c cId - pure () + r <- action + get c =##> \case ("", cId', MSGNTF _) -> cId == cId'; _ -> False + pure r makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnection alice bob = makeConnectionForUsers alice 1 bob 1 @@ -2049,7 +2050,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do getSMPAgentClient' :: AgentConfig -> InitialAgentServers -> FilePath -> IO AgentClient getSMPAgentClient' cfg' initServers dbPath = do - Right st <- liftIO $ createAgentStore dbPath "" MCError + Right st <- liftIO $ createAgentStore dbPath "" False MCError getSMPAgentClient cfg' initServers st testServerMultipleIdentities :: HasCallStack => IO () diff --git a/tests/AgentTests/MigrationTests.hs b/tests/AgentTests/MigrationTests.hs index e3a147644..406bdef60 100644 --- a/tests/AgentTests/MigrationTests.hs +++ b/tests/AgentTests/MigrationTests.hs @@ -178,16 +178,16 @@ testMigration :: testMigration (initMs, initTables) (finalMs, confirmModes, tablesOrError) = forM_ confirmModes $ \confirmMode -> do r <- randomIO :: IO Word32 let dpPath = testDB <> show r - Right st <- createSQLiteStore dpPath "" initMs MCError + Right st <- createSQLiteStore dpPath "" False initMs MCError st `shouldHaveTables` initTables closeSQLiteStore st case tablesOrError of Right tables -> do - Right st' <- createSQLiteStore dpPath "" finalMs confirmMode + Right st' <- createSQLiteStore dpPath "" False finalMs confirmMode st' `shouldHaveTables` tables closeSQLiteStore st' Left e -> do - Left e' <- createSQLiteStore dpPath "" finalMs confirmMode + Left e' <- createSQLiteStore dpPath "" False finalMs confirmMode e `shouldBe` e' removeFile dpPath where diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index cf6e8373b..9c2aa2fdf 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -16,6 +16,7 @@ import Control.Concurrent.STM import Control.Exception (SomeException) import Control.Monad (replicateM_) import Crypto.Random (drgNew) +import Data.ByteArray (ScrubbedBytes) import Data.ByteString.Char8 (ByteString) import Data.List (isInfixOf) import qualified Data.Text as T @@ -49,18 +50,18 @@ withStore2 = before connect2 . after (removeStore . fst) connect2 :: IO (SQLiteStore, SQLiteStore) connect2 = do s1 <- createStore - s2 <- connectSQLiteStore (dbFilePath s1) "" + s2 <- connectSQLiteStore (dbFilePath s1) "" False pure (s1, s2) createStore :: IO SQLiteStore -createStore = createEncryptedStore "" +createStore = createEncryptedStore "" False -createEncryptedStore :: String -> IO SQLiteStore -createEncryptedStore key = do +createEncryptedStore :: ScrubbedBytes -> Bool -> IO SQLiteStore +createEncryptedStore key keepKey = do -- Randomize DB file name to avoid SQLite IO errors supposedly caused by asynchronous -- IO operations on multiple similarly named files; error seems to be environment specific r <- randomIO :: IO Word32 - Right st <- createSQLiteStore (testDB <> show r) key Migrations.app MCError + Right st <- createSQLiteStore (testDB <> show r) key keepKey Migrations.app MCError pure st removeStore :: SQLiteStore -> IO () @@ -112,6 +113,7 @@ storeTests = do describe "open/close store" $ do it "should close and re-open" testCloseReopenStore it "should close and re-open encrypted store" testCloseReopenEncryptedStore + it "should close and re-open encrypted store (keep key)" testReopenEncryptedStoreKeepKey testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore) testConcurrentWrites = @@ -520,28 +522,39 @@ testCloseReopenStore = do closeSQLiteStore st closeSQLiteStore st errorGettingMigrations st - openSQLiteStore st "" - openSQLiteStore st "" + openSQLiteStore st "" False + openSQLiteStore st "" False hasMigrations st closeSQLiteStore st errorGettingMigrations st - openSQLiteStore st "" + reopenSQLiteStore st hasMigrations st testCloseReopenEncryptedStore :: IO () testCloseReopenEncryptedStore = do let key = "test_key" - st <- createEncryptedStore key + st <- createEncryptedStore key False hasMigrations st closeSQLiteStore st closeSQLiteStore st errorGettingMigrations st - openSQLiteStore st key - openSQLiteStore st key + reopenSQLiteStore st `shouldThrow` \(e :: SomeException) -> "reopenSQLiteStore: no key" `isInfixOf` show e + openSQLiteStore st key True + openSQLiteStore st key True hasMigrations st closeSQLiteStore st errorGettingMigrations st - openSQLiteStore st key + reopenSQLiteStore st + hasMigrations st + +testReopenEncryptedStoreKeepKey :: IO () +testReopenEncryptedStoreKeepKey = do + let key = "test_key" + st <- createEncryptedStore key True + hasMigrations st + closeSQLiteStore st + errorGettingMigrations st + reopenSQLiteStore st hasMigrations st getMigrations :: SQLiteStore -> IO Bool diff --git a/tests/AgentTests/SchemaDump.hs b/tests/AgentTests/SchemaDump.hs index 3bea2188b..3c8246366 100644 --- a/tests/AgentTests/SchemaDump.hs +++ b/tests/AgentTests/SchemaDump.hs @@ -33,14 +33,14 @@ testVerifySchemaDump :: IO () testVerifySchemaDump = do savedSchema <- ifM (doesFileExist appSchema) (readFile appSchema) (pure "") savedSchema `deepseq` pure () - void $ createSQLiteStore testDB "" Migrations.app MCConsole + void $ createSQLiteStore testDB "" False Migrations.app MCConsole getSchema testDB appSchema `shouldReturn` savedSchema removeFile testDB testSchemaMigrations :: IO () testSchemaMigrations = do let noDownMigrations = dropWhileEnd (\Migration {down} -> isJust down) Migrations.app - Right st <- createSQLiteStore testDB "" noDownMigrations MCError + Right st <- createSQLiteStore testDB "" False noDownMigrations MCError mapM_ (testDownMigration st) $ drop (length noDownMigrations) Migrations.app closeSQLiteStore st removeFile testDB diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index bc8ad5ecc..a767218c0 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -213,7 +213,7 @@ withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess = initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]} in serverBracket ( \started -> do - Right st <- liftIO $ createAgentStore db' "" MCError + Right st <- liftIO $ createAgentStore db' "" False MCError runSMPAgentBlocking t cfg' initServers' st started ) afterProcess