From 294d7ec8dde9898b66188a346f6d9d17119763da Mon Sep 17 00:00:00 2001 From: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> Date: Thu, 29 Feb 2024 22:08:58 +0400 Subject: [PATCH 01/30] agent: delay connection deletion to finish delivery of pending messages (#1015) * agent: delay connection deletion to finish delivery of pending messages (wip) * fixes, test * notify, test * add tests * comment * add test * timeout * test timeout * up * more tests * rename --------- Co-authored-by: Evgeny Poberezkin --- simplexmq.cabal | 1 + src/Simplex/Messaging/Agent.hs | 66 +++-- src/Simplex/Messaging/Agent/Env/SQLite.hs | 2 + src/Simplex/Messaging/Agent/Store/SQLite.hs | 80 ++++-- .../Agent/Store/SQLite/Migrations.hs | 4 +- .../M20240223_connections_wait_delivery.hs | 18 ++ .../Store/SQLite/Migrations/agent_schema.sql | 3 +- tests/AgentTests/FunctionalAPITests.hs | 250 +++++++++++++++++- tests/AgentTests/SQLiteTests.hs | 12 +- 9 files changed, 371 insertions(+), 65 deletions(-) create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240223_connections_wait_delivery.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index 5a8d91390..f1d7c1bec 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -103,6 +103,7 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index aa9872e9c..4c2db7322 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -141,7 +141,7 @@ import qualified Data.Text as T import Data.Time.Clock import Data.Time.Clock.System (systemToUTCTime) import Data.Word (Word16) -import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFilesInternal, deleteSndFileRemote, deleteSndFilesRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile') +import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFileRemote, deleteSndFilesInternal, deleteSndFilesRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile') import Simplex.FileTransfer.Description (ValidFileDescription) import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Util (removePath) @@ -242,12 +242,12 @@ switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId - switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c -- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response -deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ConnId -> m () -deleteConnectionAsync c = withAgentEnv c . deleteConnectionAsync' c +deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> Bool -> ConnId -> m () +deleteConnectionAsync c waitDelivery = withAgentEnv c . deleteConnectionAsync' c waitDelivery -- | Delete SMP agent connections using batch commands asynchronously, no synchronous response -deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> [ConnId] -> m () -deleteConnectionsAsync c = withAgentEnv c . deleteConnectionsAsync' c +deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> Bool -> [ConnId] -> m () +deleteConnectionsAsync c waitDelivery = withAgentEnv c . deleteConnectionsAsync' c waitDelivery -- | Create SMP agent connection (NEW command) createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) @@ -541,7 +541,7 @@ createUser' c smp xftp = do deleteUser' :: AgentMonad m => AgentClient -> UserId -> Bool -> m () deleteUser' c userId delSMPQueues = do if delSMPQueues - then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c + then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c False else withStore c (`deleteUserRecord` userId) atomically $ TM.delete userId $ smpServers c where @@ -613,21 +613,21 @@ ackMessageAsync' c corrId connId msgId rcptInfo_ = do (RcvQueue {server}, _) <- withStoreCtx "ackMessageAsync': setMsgUserAck" c $ \db -> setMsgUserAck db connId mId enqueueCommand c corrId connId (Just server) . AClientCommand $ APC SAEConn $ ACK msgId rcptInfo_ -deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () -deleteConnectionAsync' c connId = deleteConnectionsAsync' c [connId] +deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> Bool -> ConnId -> m () +deleteConnectionAsync' c waitDelivery connId = deleteConnectionsAsync' c waitDelivery [connId] -deleteConnectionsAsync' :: AgentMonad m => AgentClient -> [ConnId] -> m () +deleteConnectionsAsync' :: AgentMonad m => AgentClient -> Bool -> [ConnId] -> m () deleteConnectionsAsync' = deleteConnectionsAsync_ $ pure () -deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> [ConnId] -> m () -deleteConnectionsAsync_ onSuccess c connIds = case connIds of +deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> Bool -> [ConnId] -> m () +deleteConnectionsAsync_ onSuccess c waitDelivery connIds = case connIds of [] -> onSuccess _ -> do - (_, rqs, connIds') <- prepareDeleteConnections_ getConns c connIds - withStore' c $ forM_ connIds' . setConnDeleted + (_, rqs, connIds') <- prepareDeleteConnections_ getConns c waitDelivery connIds + withStore' c $ \db -> forM_ connIds' $ setConnDeleted db waitDelivery void . forkIO $ withLock (deleteLock c) "deleteConnectionsAsync" $ - deleteConnQueues c True rqs >> onSuccess + deleteConnQueues c waitDelivery True rqs >> onSuccess -- | Add connection to the new receive queue switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ConnectionStats @@ -712,7 +712,7 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv Right _ -> pure connId' Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md - withStore' c (`deleteConn` connId') + void $ withStore' c $ \db -> deleteConn db Nothing connId' throwError e joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do aVRange <- asks $ smpAgentVRange . config @@ -1452,19 +1452,23 @@ disableConn c connId = do -- Unlike deleteConnectionsAsync, this function does not mark connections as deleted in case of deletion failure. deleteConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) -deleteConnections' = deleteConnections_ getConns False +deleteConnections' = deleteConnections_ getConns False False deleteDeletedConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) -deleteDeletedConns = deleteConnections_ getDeletedConns True +deleteDeletedConns = deleteConnections_ getDeletedConns True False + +deleteDeletedWaitingDeliveryConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +deleteDeletedWaitingDeliveryConns = deleteConnections_ getConns True True prepareDeleteConnections_ :: forall m. AgentMonad m => (DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) -> AgentClient -> + Bool -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId]) -prepareDeleteConnections_ getConnections c connIds = do +prepareDeleteConnections_ getConnections c waitDelivery connIds = do conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConnections` connIds) let (errs, cs) = M.mapEither id conns errs' = M.map (Left . storeError) errs @@ -1472,19 +1476,27 @@ prepareDeleteConnections_ getConnections c connIds = do rqs = concat $ M.elems rcvQs connIds' = M.keys rcvQs forM_ connIds' $ disableConn c - withStore' c $ forM_ (M.keys delRs) . deleteConn + -- ! delRs is not used to notify about the result in any of the calling functions, + -- ! it is only used to check results count in deleteConnections_; + -- ! if it was used to notify about the result, it might be necessary to differentiate + -- ! between completed deletions of connections, and deletions delayed due to wait for delivery (see deleteConn) + deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing + rs' <- catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) (M.keys delRs)) + forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN) pure (errs' <> delRs, rqs, connIds') where rcvQueues :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue] rcvQueues (SomeConn _ conn) = case connRcvQueues conn of [] -> Left $ Right () rqs -> Right rqs + notify = atomically . writeTBQueue (subQ c) -deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ())) -deleteConnQueues c ntf rqs = do +deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ())) +deleteConnQueues c waitDelivery ntf rqs = do rs <- connResults <$> (deleteQueueRecs =<< deleteQueues c rqs) let connIds = M.keys $ M.filter isRight rs - rs' <- rights <$> withStoreBatch' c (\db -> map (\cId -> deleteConn db cId $> cId) connIds) + deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing + rs' <- catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) connIds) forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN) pure rs where @@ -1527,13 +1539,14 @@ deleteConnections_ :: AgentMonad m => (DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) -> Bool -> + Bool -> AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) -deleteConnections_ _ _ _ [] = pure M.empty -deleteConnections_ getConnections ntf c connIds = do - (rs, rqs, _) <- prepareDeleteConnections_ getConnections c connIds - rcvRs <- deleteConnQueues c ntf rqs +deleteConnections_ _ _ _ _ [] = pure M.empty +deleteConnections_ getConnections ntf waitDelivery c connIds = do + (rs, rqs, _) <- prepareDeleteConnections_ getConnections c waitDelivery connIds + rcvRs <- deleteConnQueues c waitDelivery ntf rqs let rs' = M.union rs rcvRs notifyResultError rs' pure rs' @@ -1862,6 +1875,7 @@ cleanupManager c@AgentClient {subQ} = do deleteConns = withLock (deleteLock c) "cleanupManager" $ do void $ withStore' c getDeletedConnIds >>= deleteDeletedConns c + void $ withStore' c getDeletedWaitingDeliveryConnIds >>= deleteDeletedWaitingDeliveryConns c withStore' c deleteUsersWithoutConns >>= mapM_ (notify "" . DEL_USER) deleteRcvFilesExpired = do rcvFilesTTL <- asks $ rcvFilesTTL . config diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 71e710473..e603e50b8 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -93,6 +93,7 @@ data AgentConfig = AgentConfig reconnectInterval :: RetryInterval, messageRetryInterval :: RetryInterval2, messageTimeout :: NominalDiffTime, + connDeleteDeliveryTimeout :: NominalDiffTime, helloTimeout :: NominalDiffTime, quotaExceededTimeout :: NominalDiffTime, initialCleanupDelay :: Int64, @@ -161,6 +162,7 @@ defaultAgentConfig = reconnectInterval = defaultReconnectInterval, messageRetryInterval = defaultMessageRetryInterval, messageTimeout = 2 * nominalDay, + connDeleteDeliveryTimeout = 2 * nominalDay, helloTimeout = 2 * nominalDay, quotaExceededTimeout = 7 * nominalDay, initialCleanupDelay = 30 * 1000000, -- 30 seconds diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index c82e91d9f..ab41d096e 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -60,6 +60,7 @@ module Simplex.Messaging.Agent.Store.SQLite setConnDeleted, setConnAgentVersion, getDeletedConnIds, + getDeletedWaitingDeliveryConnIds, setConnRatchetSync, addProcessedRatchetKeyHash, checkRatchetKeyHashExists, @@ -241,7 +242,7 @@ import Data.List (foldl', intercalate, sortBy) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe, isJust, listToMaybe, catMaybes) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe) import Data.Ord (Down (..)) import Data.Text (Text) import qualified Data.Text as T @@ -602,12 +603,32 @@ getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do DB.query db (rcvQueueQuery <> " WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (host, port, rcvId) (rq,) <$> ExceptT (getConn db connId) -deleteConn :: DB.Connection -> ConnId -> IO () -deleteConn db connId = - DB.executeNamed - db - "DELETE FROM connections WHERE conn_id = :conn_id;" - [":conn_id" := connId] +-- | Deletes connection, optionally checking for pending snd message deliveries; returns connection id if it was deleted +deleteConn :: DB.Connection -> Maybe NominalDiffTime -> ConnId -> IO (Maybe ConnId) +deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of + Nothing -> delete + Just timeout -> + ifM + checkNoPendingDeliveries_ + delete + ( ifM + (checkWaitDeliveryTimeout_ timeout) + delete + (pure Nothing) + ) + where + delete = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId) $> Just connId + checkNoPendingDeliveries_ = do + r :: (Maybe Int64) <- + maybeFirstRow fromOnly $ + DB.query db "SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND failed = 0 LIMIT 1" (Only connId) + pure $ isNothing r + checkWaitDeliveryTimeout_ timeout = do + cutoffTs <- addUTCTime (-timeout) <$> getCurrentTime + r :: (Maybe Int64) <- + maybeFirstRow fromOnly $ + DB.query db "SELECT 1 FROM connections WHERE conn_id = ? AND deleted_at_wait_delivery < ? LIMIT 1" (connId, cutoffTs) + pure $ isJust r upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) upgradeRcvConnToDuplex db connId sq = @@ -1912,8 +1933,13 @@ getConnData db connId' = cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState) = (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState}, cMode) -setConnDeleted :: DB.Connection -> ConnId -> IO () -setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) +setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () +setConnDeleted db waitDelivery connId + | waitDelivery = do + currentTs <- getCurrentTime + DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId) + | otherwise = + DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) setConnAgentVersion :: DB.Connection -> ConnId -> Version -> IO () setConnAgentVersion db connId aVersion = @@ -1922,6 +1948,10 @@ setConnAgentVersion db connId aVersion = getDeletedConnIds :: DB.Connection -> IO [ConnId] getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True) +getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId] +getDeletedWaitingDeliveryConnIds db = + map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted_at_wait_delivery IS NOT NULL" + setConnRatchetSync :: DB.Connection -> ConnId -> RatchetSyncState -> IO () setConnRatchetSync db connId ratchetSyncState = DB.execute db "UPDATE connections SET ratchet_sync_state = ? WHERE conn_id = ?" (ratchetSyncState, connId) @@ -2267,17 +2297,18 @@ createRcvFileRedirect db gVar userId redirectFd@FileDescription {chunks = redire forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId pure dstEntityId where - dummyDst = FileDescription - { party = SFRecipient, - size, - digest, - redirect = Nothing, - -- updated later with updateRcvFileRedirect - key = C.unsafeSbKey $ B.replicate 32 '#', - nonce = C.cbNonce "", - chunkSize = FileSize 0, - chunks = [] - } + dummyDst = + FileDescription + { party = SFRecipient, + size, + digest, + redirect = Nothing, + -- updated later with updateRcvFileRedirect + key = C.unsafeSbKey $ B.replicate 32 '#', + nonce = C.cbNonce "", + chunkSize = FileSize 0, + chunks = [] + } insertRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Maybe DBRcvFileId -> Maybe RcvFileId -> IO (Either StoreError (RcvFileId, DBRcvFileId)) insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSize, redirect} prefixPath tmpPath (CryptoFile savePath cfArgs) redirectId_ redirectEntityId_ = runExceptT $ do @@ -2346,10 +2377,11 @@ getRcvFile db rcvFileId = runExceptT $ do toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, saveKey_, saveNonce_, status, deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) = let cfArgs = CFArgs <$> saveKey_ <*> saveNonce_ saveFile = CryptoFile savePath cfArgs - redirect = RcvFileRedirect - <$> redirectDbId - <*> redirectEntityId - <*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_) + redirect = + RcvFileRedirect + <$> redirectDbId + <*> redirectEntityId + <*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_) in RcvFile {rcvFileId, rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, redirect, prefixPath, tmpPath, saveFile, status, deleted, chunks = []} getChunks :: RcvFileId -> UserId -> FilePath -> IO [RcvFileChunk] getChunks rcvFileEntityId userId fileTmpPath = do diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 83c900f72..2ed79afa3 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -69,6 +69,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231222_command_created import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport.Client (TransportHost) @@ -106,7 +107,8 @@ schemaMigrations = ("m20231222_command_created_at", m20231222_command_created_at, Just down_m20231222_command_created_at), ("m20231225_failed_work_items", m20231225_failed_work_items, Just down_m20231225_failed_work_items), ("m20240121_message_delivery_indexes", m20240121_message_delivery_indexes, Just down_m20240121_message_delivery_indexes), - ("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect) + ("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect), + ("m20240223_connections_wait_delivery", m20240223_connections_wait_delivery, Just down_m20240223_connections_wait_delivery) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240223_connections_wait_delivery.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240223_connections_wait_delivery.hs new file mode 100644 index 000000000..e61179768 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240223_connections_wait_delivery.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20240223_connections_wait_delivery :: Query +m20240223_connections_wait_delivery = + [sql| +ALTER TABLE connections ADD COLUMN deleted_at_wait_delivery TEXT; +|] + +down_m20240223_connections_wait_delivery :: Query +down_m20240223_connections_wait_delivery = + [sql| +ALTER TABLE connections DROP COLUMN deleted_at_wait_delivery; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index b9efaa05c..35459042d 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -26,7 +26,8 @@ CREATE TABLE connections( deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL), user_id INTEGER CHECK(user_id NOT NULL) REFERENCES users ON DELETE CASCADE, - ratchet_sync_state TEXT NOT NULL DEFAULT 'ok' + ratchet_sync_state TEXT NOT NULL DEFAULT 'ok', + deleted_at_wait_delivery TEXT ) WITHOUT ROWID; CREATE TABLE rcv_queues( host TEXT NOT NULL, diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index a5e994c68..d4be8304d 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -52,7 +52,7 @@ import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Type.Equality import qualified Database.SQLite.Simple as SQL import SMPAgentClient -import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerV7, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) +import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, withSmpServerV7) import Simplex.Messaging.Agent import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore) @@ -67,7 +67,7 @@ import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolS import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Transport (ATransport (..), basicAuthSMPVersion, authCmdsSMPVersion, currentServerSMPRelayVersion) +import Simplex.Messaging.Transport (ATransport (..), authCmdsSMPVersion, basicAuthSMPVersion, currentServerSMPRelayVersion) import Simplex.Messaging.Version import System.Directory (copyFile, renameFile) import Test.Hspec @@ -147,7 +147,7 @@ agentCfgVPrev = } agentCfgV7 :: AgentConfig -agentCfgV7 = +agentCfgV7 = agentCfg { sndAuthAlg = C.AuthAlg C.SX25519, smpCfg = smpCfgV7, @@ -271,6 +271,16 @@ functionalAPITests t = do withSmpServer t testAcceptContactAsync it "should delete connections using async command when server connection fails" $ testDeleteConnectionAsync t + it "delete waiting for delivery - should delete connection immediately if there are no pending messages" $ + testDeleteConnectionAsyncWaitDeliveryNoPending t + it "delete waiting for delivery - should delete connection after waiting for delivery to complete" $ + testDeleteConnectionAsyncWaitDelivery t + it "delete waiting for delivery - should delete connection if message can't be delivered due to AUTH error" $ + testDeleteConnectionAsyncWaitDeliveryAUTHErr t + it "delete waiting for delivery - should delete connection by timeout even if message wasn't delivered" $ + testDeleteConnectionAsyncWaitDeliveryTimeout t + it "delete waiting for delivery - should delete connection by timeout, message in progress can be delivered" $ + testDeleteConnectionAsyncWaitDeliveryTimeout2 t it "join connection when reply queue creation fails" $ testJoinConnectionAsyncReplyError t describe "Users" $ do @@ -381,7 +391,7 @@ testRatchetMatrix2 t runTest = do pendingV "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 runTest pendingV "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 runTest where - pendingV = + pendingV = let vr = e2eEncryptVRange agentCfg in if minVersion vr == maxVersion vr then xit else it @@ -1428,7 +1438,7 @@ testAsyncCommands = ] ackMessageAsync alice "7" bobId (baseId + 4) Nothing get alice =##> \case ("7", _, OK) -> True; _ -> False - deleteConnectionAsync alice bobId + deleteConnectionAsync alice False bobId get alice =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bobId; _ -> False get alice =##> \case ("", c, DEL_CONN) -> c == bobId; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" @@ -1498,7 +1508,7 @@ testDeleteConnectionAsync t = do (bId3, _inv) <- createConnection a 1 True SCMInvitation Nothing SMSubscribe pure ([bId1, bId2, bId3] :: [ConnId]) runRight_ $ do - deleteConnectionsAsync a connIds + deleteConnectionsAsync a False connIds get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False @@ -1508,6 +1518,232 @@ testDeleteConnectionAsync t = do liftIO $ noMessages a "nothing else should be delivered to alice" disconnectAgentClient a +testDeleteConnectionAsyncWaitDeliveryNoPending :: ATransport -> IO () +testDeleteConnectionAsyncWaitDeliveryNoPending t = do + alice <- getSMPAgentClient' 1 agentCfg initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ Nothing) -> cId == bobId; _ -> False + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + + 3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + get bob ##> ("", aliceId, MERR (baseId + 3) (SMP AUTH)) + + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testDeleteConnectionAsyncWaitDelivery :: ATransport -> IO () +testDeleteConnectionAsyncWaitDelivery t = do + alice <- getSMPAgentClient' 1 agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + ("", "", DOWN _ _) <- nGet bob + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do + get alice ##> ("", bobId, SENT $ baseId + 3) + get alice ##> ("", bobId, SENT $ baseId + 4) + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + + liftIO $ + getInAnyOrder + bob + [ \case ("", "", APC SAENone (UP _ [cId])) -> cId == aliceId; _ -> False, + \case ("", cId, APC SAEConn (Msg "how are you?")) -> cId == aliceId; _ -> False + ] + ackMessage bob aliceId (baseId + 3) Nothing + get bob =##> \case ("", c, Msg "message 1") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 4) Nothing + + -- queue wasn't deleted (DEL never reached server, see DEL_RCVQ with error), so bob can send message + 5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + get bob ##> ("", aliceId, SENT $ baseId + 5) + + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testDeleteConnectionAsyncWaitDeliveryAUTHErr :: ATransport -> IO () +testDeleteConnectionAsyncWaitDeliveryAUTHErr t = do + alice <- getSMPAgentClient' 1 agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (_aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + deleteConnectionsAsync bob False [aliceId] + get bob =##> \case ("", cId, DEL_RCVQ _ _ Nothing) -> cId == aliceId; _ -> False + get bob =##> \case ("", cId, DEL_CONN) -> cId == aliceId; _ -> False + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + withSmpServerStoreLogOn t testPort $ \_ -> do + get alice ##> ("", bobId, MERR (baseId + 3) (SMP AUTH)) + get alice ##> ("", bobId, MERR (baseId + 4) (SMP AUTH)) + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testDeleteConnectionAsyncWaitDeliveryTimeout :: ATransport -> IO () +testDeleteConnectionAsyncWaitDeliveryTimeout t = do + alice <- getSMPAgentClient' 1 agentCfg {connDeleteDeliveryTimeout = 1, initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + ("", "", DOWN _ _) <- nGet bob + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + withSmpServerStoreLogOn t testPort $ \_ -> do + nGet bob =##> \case ("", "", UP _ [cId]) -> cId == aliceId; _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testDeleteConnectionAsyncWaitDeliveryTimeout2 :: ATransport -> IO () +testDeleteConnectionAsyncWaitDeliveryTimeout2 t = do + alice <- getSMPAgentClient' 1 agentCfg {connDeleteDeliveryTimeout = 2, messageRetryInterval = fastMessageRetryInterval, initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + ("", "", DOWN _ _) <- nGet bob + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + withSmpServerStoreLogOn t testPort $ \_ -> do + get alice ##> ("", bobId, SENT $ baseId + 3) + -- "message 1" not delivered + + liftIO $ + getInAnyOrder + bob + [ \case ("", "", APC SAENone (UP _ [cId])) -> cId == aliceId; _ -> False, + \case ("", cId, APC SAEConn (Msg "how are you?")) -> cId == aliceId; _ -> False + ] + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + testJoinConnectionAsyncReplyError :: HasCallStack => ATransport -> IO () testJoinConnectionAsyncReplyError t = do let initAgentServersSrv2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer2]} @@ -1714,7 +1950,7 @@ testSwitchDelete servers = do stats <- switchConnectionAsync a "" bId liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted] phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing] - deleteConnectionAsync a bId + deleteConnectionAsync a False bId get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False get a =##> \case ("", c, DEL_CONN) -> c == bId; _ -> False diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 5799a0492..87b16a2a6 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -312,8 +312,8 @@ testDeleteRcvConn = Right (_, rq) <- createRcvConn db g cData1 rcvQueue1 SCMInvitation getConn db "conn1" `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rq)) - deleteConn db "conn1" - `shouldReturn` () + deleteConn db Nothing "conn1" + `shouldReturn` Just "conn1" getConn db "conn1" `shouldReturn` Left SEConnNotFound @@ -324,8 +324,8 @@ testDeleteSndConn = Right (_, sq) <- createSndConn db g cData1 sndQueue1 getConn db "conn1" `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq)) - deleteConn db "conn1" - `shouldReturn` () + deleteConn db Nothing "conn1" + `shouldReturn` Just "conn1" getConn db "conn1" `shouldReturn` Left SEConnNotFound @@ -337,8 +337,8 @@ testDeleteDuplexConn = Right sq <- upgradeRcvConnToDuplex db "conn1" sndQueue1 getConn db "conn1" `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq])) - deleteConn db "conn1" - `shouldReturn` () + deleteConn db Nothing "conn1" + `shouldReturn` Just "conn1" getConn db "conn1" `shouldReturn` Left SEConnNotFound From ce78646c7faabd30d24db05fd2854a64ca7a912b Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sat, 2 Mar 2024 18:27:51 +0000 Subject: [PATCH 02/30] refactor creating connection record (#1021) --- src/Simplex/Messaging/Agent/Store/SQLite.hs | 29 ++++++++++----------- tests/AgentTests/SQLiteTests.hs | 10 ++++++- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index ab41d096e..647e8bd03 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -50,7 +50,6 @@ module Simplex.Messaging.Agent.Store.SQLite createNewConn, updateNewConnRcv, updateNewConnSnd, - createRcvConn, -- no longer used createSndConn, getConn, getDeletedConn, @@ -543,11 +542,8 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of ConnData {connId} -> Right . (connId,) <$> create connId createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId) -createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} cMode = do - fst <$$> createConn_ gVar cData create - where - create connId = - DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True) +createNewConn db gVar cData cMode = do + fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode) updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) updateNewConnRcv db connId rq = @@ -569,22 +565,25 @@ updateNewConnSnd db connId sq = updateConn :: IO (Either StoreError SndQueue) updateConn = Right <$> addConnSndQueue_ db connId sq -createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue)) -createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@RcvQueue {server} cMode = - createConn_ gVar cData $ \connId -> do - serverKeyHash_ <- createServer_ db server - DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True) - insertRcvQueue_ db connId q serverKeyHash_ - createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue)) -createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@SndQueue {server} = +createSndConn db gVar cData q@SndQueue {server} = -- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_ ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $ createConn_ gVar cData $ \connId -> do serverKeyHash_ <- createServer_ db server - DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, True) + createConnRecord db connId cData SCMInvitation insertSndQueue_ db connId q serverKeyHash_ +createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO () +createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs} cMode = + DB.execute + db + [sql| + INSERT INTO connections + (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?) + |] + (userId, connId, cMode, connAgentVersion, enableNtfs, True) + checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do fromMaybe False diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 87b16a2a6..714b7e15e 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -15,6 +15,8 @@ import Control.Concurrent.Async (concurrently_) import Control.Concurrent.STM import Control.Exception (SomeException) import Control.Monad (replicateM_) +import Control.Monad.Trans.Except +import Crypto.Random (ChaChaDRG) import Data.ByteArray (ScrubbedBytes) import Data.ByteString.Char8 (ByteString) import Data.List (isInfixOf) @@ -91,7 +93,7 @@ storeTests = do testForeignKeysEnabled describe "db methods" $ do describe "Queue and Connection management" $ do - describe "createRcvConn" $ do + describe "create Rcv connection" $ do testCreateRcvConn testCreateRcvConnRandomId testCreateRcvConnDuplicate @@ -227,6 +229,12 @@ sndQueue1 = smpClientVersion = 1 } +createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue)) +createRcvConn db g cData rq cMode = runExceptT $ do + connId <- ExceptT $ createNewConn db g cData cMode + rq' <- ExceptT $ updateNewConnRcv db connId rq + pure (connId, rq') + testCreateRcvConn :: SpecWith SQLiteStore testCreateRcvConn = it "should create RcvConnection and add SndQueue" . withStoreTransaction $ \db -> do From 246a0d10c22ebe02af2eb34773b77cce10247459 Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Sat, 2 Mar 2024 20:46:05 +0200 Subject: [PATCH 03/30] xftp: raise internal upload limit to 5gb (#1020) * xftp: raise internal upload limit to 5gb * extract hard limit from agent --- src/Simplex/FileTransfer/Agent.hs | 2 +- src/Simplex/FileTransfer/Client/Main.hs | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 9a789a104..f45389462 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -393,7 +393,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do let CryptoFile {filePath} = srcFile fileName = takeFileName filePath fileSize <- liftIO $ fromInteger <$> CF.getFileContentsSize srcFile - when (fileSize > maxFileSize) $ throwError $ INTERNAL "max file size exceeded" + when (fileSize > maxFileSizeHard) $ throwError $ INTERNAL "max file size exceeded" let fileHdr = smpEncode FileHeader {fileName, fileExtra = Nothing} fileSize' = fromIntegral (B.length fileHdr) + fileSize chunkSizes = prepareChunkSizes $ fileSize' + fileSizeLen + authTagSize diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index e1ef9f0d8..c0277cd9f 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -19,6 +19,7 @@ module Simplex.FileTransfer.Client.Main prepareChunkSizes, prepareChunkSpecs, maxFileSize, + maxFileSizeHard, fileSizeLen, getChunkDigest, SentRecipientReplica (..), @@ -76,12 +77,17 @@ import UnliftIO.Directory xftpClientVersion :: String xftpClientVersion = "1.0.1" +-- | Soft limit for XFTP clients. Should be checked and reported to user. maxFileSize :: Int64 maxFileSize = gb 1 maxFileSizeStr :: String maxFileSizeStr = B.unpack . strEncode $ FileSize maxFileSize +-- | Hard internal limit for XFTP agent after which it refuses to prepare chunks. +maxFileSizeHard :: Int64 +maxFileSizeHard = gb 5 + fileSizeLen :: Int64 fileSizeLen = 8 From 30fd4065d9d5b8529520a5feb1f5ffab72323b66 Mon Sep 17 00:00:00 2001 From: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> Date: Sun, 3 Mar 2024 12:56:54 +0400 Subject: [PATCH 04/30] rename delete waiting delivery tests (#1022) --- tests/AgentTests/FunctionalAPITests.hs | 41 +++++++++++++------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index d4be8304d..5870266a7 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -271,18 +271,19 @@ functionalAPITests t = do withSmpServer t testAcceptContactAsync it "should delete connections using async command when server connection fails" $ testDeleteConnectionAsync t - it "delete waiting for delivery - should delete connection immediately if there are no pending messages" $ - testDeleteConnectionAsyncWaitDeliveryNoPending t - it "delete waiting for delivery - should delete connection after waiting for delivery to complete" $ - testDeleteConnectionAsyncWaitDelivery t - it "delete waiting for delivery - should delete connection if message can't be delivered due to AUTH error" $ - testDeleteConnectionAsyncWaitDeliveryAUTHErr t - it "delete waiting for delivery - should delete connection by timeout even if message wasn't delivered" $ - testDeleteConnectionAsyncWaitDeliveryTimeout t - it "delete waiting for delivery - should delete connection by timeout, message in progress can be delivered" $ - testDeleteConnectionAsyncWaitDeliveryTimeout2 t it "join connection when reply queue creation fails" $ testJoinConnectionAsyncReplyError t + describe "delete connection waiting for delivery" $ do + it "should delete connection immediately if there are no pending messages" $ + testWaitDeliveryNoPending t + it "should delete connection after waiting for delivery to complete" $ + testWaitDelivery t + it "should delete connection if message can't be delivered due to AUTH error" $ + testWaitDeliveryAUTHErr t + it "should delete connection by timeout even if message wasn't delivered" $ + testWaitDeliveryTimeout t + it "should delete connection by timeout, message in progress can be delivered" $ + testWaitDeliveryTimeout2 t describe "Users" $ do it "should create and delete user with connections" $ withSmpServer t testUsers @@ -1518,8 +1519,8 @@ testDeleteConnectionAsync t = do liftIO $ noMessages a "nothing else should be delivered to alice" disconnectAgentClient a -testDeleteConnectionAsyncWaitDeliveryNoPending :: ATransport -> IO () -testDeleteConnectionAsyncWaitDeliveryNoPending t = do +testWaitDeliveryNoPending :: ATransport -> IO () +testWaitDeliveryNoPending t = do alice <- getSMPAgentClient' 1 agentCfg initAgentServers testDB bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do @@ -1551,8 +1552,8 @@ testDeleteConnectionAsyncWaitDeliveryNoPending t = do baseId = 3 msgId = subtract baseId -testDeleteConnectionAsyncWaitDelivery :: ATransport -> IO () -testDeleteConnectionAsyncWaitDelivery t = do +testWaitDelivery :: ATransport -> IO () +testWaitDelivery t = do alice <- getSMPAgentClient' 1 agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do @@ -1608,8 +1609,8 @@ testDeleteConnectionAsyncWaitDelivery t = do baseId = 3 msgId = subtract baseId -testDeleteConnectionAsyncWaitDeliveryAUTHErr :: ATransport -> IO () -testDeleteConnectionAsyncWaitDeliveryAUTHErr t = do +testWaitDeliveryAUTHErr :: ATransport -> IO () +testWaitDeliveryAUTHErr t = do alice <- getSMPAgentClient' 1 agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (_aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do @@ -1654,8 +1655,8 @@ testDeleteConnectionAsyncWaitDeliveryAUTHErr t = do baseId = 3 msgId = subtract baseId -testDeleteConnectionAsyncWaitDeliveryTimeout :: ATransport -> IO () -testDeleteConnectionAsyncWaitDeliveryTimeout t = do +testWaitDeliveryTimeout :: ATransport -> IO () +testWaitDeliveryTimeout t = do alice <- getSMPAgentClient' 1 agentCfg {connDeleteDeliveryTimeout = 1, initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do @@ -1695,8 +1696,8 @@ testDeleteConnectionAsyncWaitDeliveryTimeout t = do baseId = 3 msgId = subtract baseId -testDeleteConnectionAsyncWaitDeliveryTimeout2 :: ATransport -> IO () -testDeleteConnectionAsyncWaitDeliveryTimeout2 t = do +testWaitDeliveryTimeout2 :: ATransport -> IO () +testWaitDeliveryTimeout2 t = do alice <- getSMPAgentClient' 1 agentCfg {connDeleteDeliveryTimeout = 2, messageRetryInterval = fastMessageRetryInterval, initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do From e06e22328f39c652dabfbdb42cc77d29fcf32e80 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sun, 3 Mar 2024 19:40:49 +0000 Subject: [PATCH 05/30] agent: quantum-resistant double ratchet encryption (#939) * doc * diff * ratchet header * types * ratchet step with PQ KEM, message header with KEM * comment * update types, remove Eq instances, store KEM keys to database * pqx3dh * PQ double ratchet test * pqdr tests pass * fix most tests * refactor * allow KEM proposals from both sides * test names * agent API parameters to use PQ KEM * initialize ratchet state for enabling KEM * fix/test KEM state machine to support disabling/enabling via messages * more tests * diff * diff2 * refactor * refactor * refactor * refactor * remove Maybe * rename * add PQ encryption status to CON, MID and MSG events and sendMessage API results * different PQ parameter when creating connection * rename/reorganize types for PQ encryption modes * rename * fix testWaitDeliveryTimeout * rename * rename2 * ghc8107 * rename * increase timeouts for concurrent send/receive test * enable all tests --------- Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> --- rfcs/2023-12-29-pqdr.md | 36 + simplexmq.cabal | 3 + src/Simplex/FileTransfer/Client/Main.hs | 4 +- src/Simplex/FileTransfer/Protocol.hs | 2 +- src/Simplex/FileTransfer/Server/Store.hs | 1 - src/Simplex/FileTransfer/Types.hs | 18 +- src/Simplex/Messaging/Agent.hs | 341 +++++----- src/Simplex/Messaging/Agent/Protocol.hs | 115 ++-- src/Simplex/Messaging/Agent/Store.hs | 24 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 103 ++- .../Agent/Store/SQLite/Migrations.hs | 4 +- .../Migrations/M20240225_ratchet_kem.hs | 22 + .../Store/SQLite/Migrations/agent_schema.sql | 7 +- src/Simplex/Messaging/Crypto.hs | 42 +- src/Simplex/Messaging/Crypto/Ratchet.hs | 622 +++++++++++++++--- .../Messaging/Crypto/SNTRUP761/Bindings.hs | 57 +- src/Simplex/Messaging/Encoding/String.hs | 6 + src/Simplex/Messaging/Protocol.hs | 9 +- src/Simplex/Messaging/Server/QueueStore.hs | 4 +- tests/AgentTests.hs | 255 ++++--- tests/AgentTests/ConnectionRequestTests.hs | 27 +- tests/AgentTests/DoubleRatchetTests.hs | 482 ++++++++++++-- tests/AgentTests/EqInstances.hs | 25 + tests/AgentTests/FunctionalAPITests.hs | 164 +++-- tests/AgentTests/NotificationTests.hs | 30 +- tests/AgentTests/SQLiteTests.hs | 25 +- tests/CoreTests/CryptoTests.hs | 13 + tests/CoreTests/TRcvQueuesTests.hs | 1 + tests/SMPClient.hs | 5 +- tests/ServerTests.hs | 12 + tests/Util.hs | 6 + 31 files changed, 1776 insertions(+), 689 deletions(-) create mode 100644 rfcs/2023-12-29-pqdr.md create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs create mode 100644 tests/AgentTests/EqInstances.hs create mode 100644 tests/Util.hs diff --git a/rfcs/2023-12-29-pqdr.md b/rfcs/2023-12-29-pqdr.md new file mode 100644 index 000000000..7fd88ffbb --- /dev/null +++ b/rfcs/2023-12-29-pqdr.md @@ -0,0 +1,36 @@ +# Post-quantum double ratchet implementation + +See [the previous doc](https://github.com/simplex-chat/simplex-chat/blob/stable/docs/rfcs/2023-09-30-pq-double-ratchet.md). + +The main implementation consideration is that it should be both backwards and forwards compatible, to allow changing the connection DR to/from using PQ primitives (although client version downgrade may be impossible in this case), and also to decide whether to use PQ primitive on per-connection basis: +- use without links (in SMP confirmation or in SMP invitation via address or via member), don't use with links (as they would be too large). +- use in small groups, don't use in large groups. + +Also note that for DR to work we need to have 2 KEMs running in parallel. + +Possible combinations (assuming both clients support PQ): + +| Stage | No PQ kem | PQ key sent | PQ key + PQ ct sent | +|:------------:|:---------:|:-----------:|:-------------------:| +| inv | + | + | - | +| conf, in reply to:
no-pq inv
pq inv |  
+
+ |  
+
- |  
-
+ | +| 1st msg, in reply to:
no-pq conf
pq/pq+ct conf |  
+
+ |  
+
- |  
-
+ | +| Nth msg, in reply to:
no-pq msg
pq/pq+ct msg |  
+
+ |  
+
- |  
-
+ | + +These rules can be reduced to: +1. initial invitation optionally has PQ key, but must not have ciphertext. +2. all subsequent messages should be allowed without PQ key/ciphertext, but: + - if the previous message had PQ key or PQ key with ciphertext, they must either have no PQ key, or have PQ key with ciphertext (PQ key without ciphertext is an error). + - if the previous message had no PQ key, they must either have no PQ key, or have PQ key without ciphertext (PQ key with ciphertext is an error). + +The rules for calculating the shared secret for received/sent messages are (assuming received message is valid according to the above rules): + +| sent msg >
V received msg | no-pq | pq | pq+ct | +|:------------------------------:|:-----------:|:-------:|:---------------:| +| no-pq | DH / DH | DH / DH | err | +| pq (sent msg was NOT pq) | DH / DH | err | DH / DH+KEM | +| pq+ct (sent msg was NOT no-pq) | DH+KEM / DH | err | DH+KEM / DH+KEM | + +To summarize, the upgrade to DH+KEM secret happens in a sent message that has PQ key with ciphertext sent in reply to message with PQ key only (without ciphertext), and the downgrade to DH secret happens in the message that has no PQ key. + +The type for sending PQ key with optional ciphertext is `Maybe E2ERachetKEM` where `data E2ERachetKEM = E2ERachetKEM KEMPublicKey (Maybe KEMCiphertext)`, and for SMP invitation it will be simply `Maybe KEMPublicKey`. Possibly, there is a way to encode the rules above in the types, these types don't constrain possible transitions to valid ones. diff --git a/simplexmq.cabal b/simplexmq.cabal index f1d7c1bec..535e8fd2e 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -104,6 +104,7 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent @@ -607,6 +608,7 @@ test-suite simplexmq-test AgentTests AgentTests.ConnectionRequestTests AgentTests.DoubleRatchetTests + AgentTests.EqInstances AgentTests.FunctionalAPITests AgentTests.MigrationTests AgentTests.NotificationTests @@ -629,6 +631,7 @@ test-suite simplexmq-test ServerTests SMPAgentClient SMPClient + Util XFTPAgent XFTPCLI XFTPClient diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index c0277cd9f..d0c867b27 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -220,13 +220,13 @@ data SentFileChunk = SentFileChunk digest :: FileDigest, replicas :: [SentFileChunkReplica] } - deriving (Eq, Show) + deriving (Show) data SentFileChunkReplica = SentFileChunkReplica { server :: XFTPServer, recipients :: [(ChunkReplicaId, C.APrivateAuthKey)] } - deriving (Eq, Show) + deriving (Show) data SentRecipientReplica = SentRecipientReplica { chunkNo :: Int, diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index a9de56ddb..e9988b56a 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -171,7 +171,7 @@ data FileInfo = FileInfo size :: Word32, digest :: ByteString } - deriving (Eq, Show) + deriving (Show) type XFTPFileId = ByteString diff --git a/src/Simplex/FileTransfer/Server/Store.hs b/src/Simplex/FileTransfer/Server/Store.hs index 031c46f5b..8c198690e 100644 --- a/src/Simplex/FileTransfer/Server/Store.hs +++ b/src/Simplex/FileTransfer/Server/Store.hs @@ -49,7 +49,6 @@ data FileRec = FileRec recipientIds :: TVar (Set RecipientId), createdAt :: SystemTime } - deriving (Eq) data FileRecipient = FileRecipient RecipientId RcvPublicAuthKey diff --git a/src/Simplex/FileTransfer/Types.hs b/src/Simplex/FileTransfer/Types.hs index 21967a3cd..ba306a6c6 100644 --- a/src/Simplex/FileTransfer/Types.hs +++ b/src/Simplex/FileTransfer/Types.hs @@ -55,7 +55,7 @@ data RcvFile = RcvFile status :: RcvFileStatus, deleted :: Bool } - deriving (Eq, Show) + deriving (Show) data RcvFileStatus = RFSReceiving @@ -96,7 +96,7 @@ data RcvFileChunk = RcvFileChunk fileTmpPath :: FilePath, chunkTmpPath :: Maybe FilePath } - deriving (Eq, Show) + deriving (Show) data RcvFileChunkReplica = RcvFileChunkReplica { rcvChunkReplicaId :: Int64, @@ -107,14 +107,14 @@ data RcvFileChunkReplica = RcvFileChunkReplica delay :: Maybe Int64, retries :: Int } - deriving (Eq, Show) + deriving (Show) data RcvFileRedirect = RcvFileRedirect { redirectDbId :: DBRcvFileId, redirectEntityId :: RcvFileId, redirectFileInfo :: RedirectFileInfo } - deriving (Eq, Show) + deriving (Show) -- Sending files @@ -135,7 +135,7 @@ data SndFile = SndFile deleted :: Bool, redirect :: Maybe RedirectFileInfo } - deriving (Eq, Show) + deriving (Show) sndFileEncPath :: FilePath -> FilePath sndFileEncPath prefixPath = prefixPath "xftp.encrypted" @@ -182,7 +182,7 @@ data SndFileChunk = SndFileChunk digest :: FileDigest, replicas :: [SndFileChunkReplica] } - deriving (Eq, Show) + deriving (Show) sndChunkSize :: SndFileChunk -> Word32 sndChunkSize SndFileChunk {chunkSpec = XFTPChunkSpec {chunkSize}} = chunkSize @@ -193,7 +193,7 @@ data NewSndChunkReplica = NewSndChunkReplica replicaKey :: C.APrivateAuthKey, rcvIdsKeys :: [(ChunkReplicaId, C.APrivateAuthKey)] } - deriving (Eq, Show) + deriving (Show) data SndFileChunkReplica = SndFileChunkReplica { sndChunkReplicaId :: Int64, @@ -205,7 +205,7 @@ data SndFileChunkReplica = SndFileChunkReplica delay :: Maybe Int64, retries :: Int } - deriving (Eq, Show) + deriving (Show) data SndFileReplicaStatus = SFRSCreated @@ -235,4 +235,4 @@ data DeletedSndChunkReplica = DeletedSndChunkReplica delay :: Maybe Int64, retries :: Int } - deriving (Eq, Show) + deriving (Show) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 4c2db7322..3f0ca12b4 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -218,20 +218,20 @@ deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m () deleteUser c = withAgentEnv c .: deleteUser' c -- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id -createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId -createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .: newConnAsync c userId aCorrId enableNtfs +createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId +createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs -- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id -joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. joinConnAsync c userId aCorrId enableNtfs +joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:: joinConnAsync c userId aCorrId enableNtfs -- | Allow connection to continue after CONF notification (LET command), no synchronous response allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c -- | Accept contact after REQ notification (ACPT command) asynchronously, synchronous response is new connection id -acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:. acceptContactAsync' c aCorrId enableNtfs +acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:: acceptContactAsync' c aCorrId enableNtfs -- | Acknowledge message (ACK command) asynchronously, no synchronous response ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () @@ -250,20 +250,20 @@ deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> Bool -> [ConnId] - deleteConnectionsAsync c waitDelivery = withAgentEnv c . deleteConnectionsAsync' c waitDelivery -- | Create SMP agent connection (NEW command) -createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) -createConnection c userId enableNtfs = withAgentEnv c .:. newConn c userId "" enableNtfs +createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +createConnection c userId enableNtfs = withAgentEnv c .:: newConn c userId "" enableNtfs -- | Join SMP agent connection (JOIN command) -joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnection c userId enableNtfs = withAgentEnv c .:. joinConn c userId "" enableNtfs +joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConnection c userId enableNtfs = withAgentEnv c .:: joinConn c userId "" enableNtfs -- | Allow connection to continue after CONF notification (LET command) allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () allowConnection c = withAgentEnv c .:. allowConnection' c -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContact c enableNtfs = withAgentEnv c .:. acceptContact' c "" enableNtfs +acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContact c enableNtfs = withAgentEnv c .:: acceptContact' c "" enableNtfs -- | Reject contact (RJCT command) rejectContact :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> m () @@ -292,17 +292,17 @@ resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map resubscribeConnections c = withAgentEnv c . resubscribeConnections' c -- | Send message to the connection (SEND command) -sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId -sendMessage c = withAgentEnv c .:. sendMessage' c +sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> CR.PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, CR.PQEncryption) +sendMessage c = withAgentEnv c .:: sendMessage' c type MsgReq = (ConnId, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) -sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] -sendMessages c = withAgentEnv c . sendMessages' c +sendMessages :: MonadUnliftIO m => AgentClient -> CR.PQEncryption -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] +sendMessages c = withAgentEnv c .: sendMessages' c -sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType AgentMsgId)) -sendMessagesB c = withAgentEnv c . sendMessagesB' c +sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> CR.PQEncryption -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +sendMessagesB c = withAgentEnv c .: sendMessagesB' c ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () ackMessage c = withAgentEnv c .:. ackMessage' c @@ -316,8 +316,8 @@ abortConnectionSwitch :: AgentErrorMonad m => AgentClient -> ConnId -> m Connect abortConnectionSwitch c = withAgentEnv c . abortConnectionSwitch' c -- | Re-synchronize connection ratchet keys -synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> Bool -> m ConnectionStats -synchronizeRatchet c = withAgentEnv c .: synchronizeRatchet' c +synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> CR.PQEncryption -> Bool -> m ConnectionStats +synchronizeRatchet c = withAgentEnv c .:. synchronizeRatchet' c -- | Suspend SMP agent connection (OFF command) suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () @@ -514,13 +514,13 @@ client c@AgentClient {rcvQ, subQ} = forever $ do processCommand :: forall m. AgentMonad m => AgentClient -> (EntityId, APartyCmd 'Client) -> m (EntityId, APartyCmd 'Agent) processCommand c (connId, APC e cmd) = second (APC e) <$> case cmd of - NEW enableNtfs (ACM cMode) subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing subMode - JOIN enableNtfs (ACR _ cReq) subMode connInfo -> (,OK) <$> joinConn c userId connId enableNtfs cReq connInfo subMode + NEW enableNtfs (ACM cMode) pqIK subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing pqIK subMode + JOIN enableNtfs (ACR _ cReq) pqEnc subMode connInfo -> (,OK) <$> joinConn c userId connId enableNtfs cReq connInfo pqEnc subMode LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK) - ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo SMSubscribe + ACPT invId pqEnc ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo pqEnc SMSubscribe RJCT invId -> rejectContact' c connId invId $> (connId, OK) SUB -> subscribeConnection' c connId $> (connId, OK) - SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody + SEND pqEnc msgFlags msgBody -> (connId,) . uncurry MID <$> sendMessage' c connId pqEnc msgFlags msgBody ACK msgId rcptInfo_ -> ackMessage' c connId msgId rcptInfo_ $> (connId, OK) SWCH -> switchConnection' c connId $> (connId, OK) OFF -> suspendConnection' c connId $> (connId, OK) @@ -549,32 +549,32 @@ deleteUser' c userId delSMPQueues = do whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId) -newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId -newConnAsync c userId corrId enableNtfs cMode subMode = do - connId <- newConnNoQueues c userId "" enableNtfs cMode - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) subMode +newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId +newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do + connId <- newConnNoQueues c userId "" enableNtfs cMode (CR.connPQEncryption pqInitKeys) + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) pqInitKeys subMode pure connId -newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId -newConnNoQueues c userId connId enableNtfs cMode = do +newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> CR.PQEncryption -> m ConnId +newConnNoQueues c userId connId enableNtfs cMode pqEncryption = do g <- asks random connAgentVersion <- asks $ maxVersion . smpAgentVRange . config - let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} withStore c $ \db -> createNewConn db g cData cMode -joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do +joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo pqEncryption subMode = do withInvLock c (strEncode cReqUri) "joinConnAsync" $ do aVRange <- asks $ smpAgentVRange . config case crAgentVRange `compatibleVersion` aVRange of Just (Compatible connAgentVersion) -> do g <- asks random - let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqEncryption subMode cInfo pure connId _ -> throwError $ AGENT A_VERSION -joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo = +joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo _pqEncryption = throwError $ CMD PROHIBITED allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () @@ -584,13 +584,13 @@ allowConnectionAsync' c corrId connId confId ownConnInfo = enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo _ -> throwError $ CMD PROHIBITED -acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContactAsync' c corrId enableNtfs invId ownConnInfo subMode = do +acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqEnc subMode = do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConnAsync c userId corrId enableNtfs connReq ownConnInfo subMode `catchAgentError` \err -> do + joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqEnc subMode `catchAgentError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -644,17 +644,20 @@ switchConnectionAsync' c corrId connId = pure . connectionStats $ DuplexConnection cData rqs' sqs _ -> throwError $ CMD PROHIBITED -newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) -newConn c userId connId enableNtfs cMode clientData subMode = - getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData subMode +newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +newConn c userId connId enableNtfs cMode clientData pqInitKeys subMode = + getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode -newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) -newConnSrv c userId connId enableNtfs cMode clientData subMode srv = do - connId' <- newConnNoQueues c userId connId enableNtfs cMode - newRcvConnSrv c userId connId' enableNtfs cMode clientData subMode srv +newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) +newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv = do + connId' <- newConnNoQueues c userId connId enableNtfs cMode (CR.connPQEncryption pqInitKeys) + newRcvConnSrv c userId connId' enableNtfs cMode clientData pqInitKeys subMode srv -newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) -newRcvConnSrv c userId connId enableNtfs cMode clientData subMode srv = do +newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) +newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv = do + case (cMode, pqInitKeys) of + (SCMContact, CR.IKUsePQ) -> throwError $ CMD PROHIBITED + _ -> pure () AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config (rq, qUri) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwError e rq' <- withStore c $ \db -> updateNewConnRcv db connId rq @@ -669,71 +672,73 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData subMode srv = do SCMContact -> pure (connId, CRContactUri crData) SCMInvitation -> do g <- asks random - (pk1, pk2, e2eRcvParams) <- atomically . CR.generateE2EParams g $ maxVersion e2eEncryptVRange - withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 + (pk1, pk2, pKem, e2eRcvParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eEncryptVRange) (CR.initialPQEncryption pqInitKeys) + withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange) -joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConn c userId connId enableNtfs cReq cInfo subMode = do +joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> getNextServer c userId [qServer q] _ -> getSMPServer c userId - joinConnSrv c userId connId enableNtfs cReq cInfo subMode srv + joinConnSrv c userId connId enableNtfs cReq cInfo pqEnc subMode srv -startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> m (Compatible Version, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.E2ERatchetParams 'C.X448) -startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) = do +startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> CR.PQEncryption -> m (Compatible Version, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) +startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqEncryption = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config case ( qUri `compatibleVersion` smpClientVRange, e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange, crAgentVRange `compatibleVersion` smpAgentVRange ) of - (Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams _ _ rcDHRr)), Just aVersion@(Compatible connAgentVersion)) -> do + (Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), Just aVersion@(Compatible connAgentVersion)) -> do g <- asks random - (pk1, pk2, e2eSndParams) <- atomically . CR.generateE2EParams g $ version e2eRcvParams + (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ pqEncryption kem_) (_, rcDHRs) <- atomically $ C.generateKeyPair g - let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams + -- TODO PQ generate KEM keypair if needed - is it done? + rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams + let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs rcParams q <- newSndQueue userId "" qInfo - let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} pure (aVersion, cData, q, rc, e2eSndParams) _ -> throwError $ AGENT A_VERSION -joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId -joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = +joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m ConnId +joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do - (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv + (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqEnc g <- asks random (connId', sq) <- withStore c $ \db -> runExceptT $ do r@(connId', _) <- ExceptT $ createSndConn db g cData q liftIO $ createRatchet db connId' rc pure r let cData' = (cData :: ConnData) {connId = connId'} - tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case + tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) (Just pqEnc) subMode) >>= \case Right _ -> pure connId' Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md void $ withStore' c $ \db -> deleteConn db Nothing connId' throwError e -joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do +joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo pqEnc subMode srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config case ( qUri `compatibleVersion` clientVRange, crAgentVRange `compatibleVersion` aVRange ) of (Just qInfo, Just vrsn) -> do - (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing subMode srv + (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.joinContactInitialKeys pqEnc) subMode srv sendInvitation c userId qInfo vrsn cReq cInfo pure connId' _ -> throwError $ AGENT A_VERSION -joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m () -joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do - (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv +joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m () +joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMode srv = do + (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqEnc q' <- withStore c $ \db -> runExceptT $ do liftIO $ createRatchet db connId rc ExceptT $ updateNewConnSnd db connId q - confirmQueueAsync c cData q' srv cInfo (Just e2eSndParams) subMode -joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _srv = do + confirmQueueAsync c cData q' srv cInfo (Just e2eSndParams) pqEnc subMode +joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _pqEnc _srv = do throwError $ CMD PROHIBITED createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> m SMPQueueInfo @@ -764,13 +769,13 @@ allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConne _ -> throwError $ CMD PROHIBITED -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContact' c connId enableNtfs invId ownConnInfo subMode = withConnLock c connId "acceptContact" $ do +acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContact' c connId enableNtfs invId ownConnInfo pqEnc subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConn c userId connId enableNtfs connReq ownConnInfo subMode `catchAgentError` \err -> do + joinConn c userId connId enableNtfs connReq ownConnInfo pqEnc subMode `catchAgentError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -905,18 +910,18 @@ getNotificationMessage' c nonce encNtfInfo = do Nothing -> SMP.notification msgFlags -- | Send message to the connection (SEND command) in Reader monad -sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId -sendMessage' c connId msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c (Identity (Right (connId, msgFlags, msg))) +sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> CR.PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, CR.PQEncryption) +sendMessage' c connId pqEnc msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c pqEnc (Identity (Right (connId, msgFlags, msg))) -- | Send multiple messages to different connections (SEND command) in Reader monad -sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] -sendMessages' c = sendMessagesB' c . map Right +sendMessages' :: forall m. AgentMonad' m => AgentClient -> CR.PQEncryption -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] +sendMessages' c pqEnc = sendMessagesB' c pqEnc . map Right -sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType AgentMsgId)) -sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do +sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> CR.PQEncryption -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +sendMessagesB' c pqEnc reqs = withConnLocks c connIds "sendMessages" $ do reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) let reqs'' = fmap (>>= prepareConn) reqs' - enqueueMessagesB c reqs'' + enqueueMessagesB c (Just pqEnc) reqs'' where prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) prepareConn ((_, msgFlags, msg), SomeConn _ conn) = case conn of @@ -965,16 +970,16 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do processCmd :: RetryInterval -> PendingCommand -> m () processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} = case command of AClientCommand (APC _ cmd) -> case cmd of - NEW enableNtfs (ACM cMode) subMode -> noServer $ do + NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) tryCommand . withNextSrv c userId usedSrvs [] $ \srv -> do - (_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing subMode srv + (_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing pqEnc subMode srv notify $ INV (ACR cMode cReq) - JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) subMode connInfo -> noServer $ do + JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) pqEnc subMode connInfo -> noServer $ do let initUsed = [qServer q] usedSrvs <- newTVarIO initUsed tryCommand . withNextSrv c userId usedSrvs initUsed $ \srv -> do - joinConnSrvAsync c userId connId enableNtfs cReq connInfo subMode srv + joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv notify OK LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK @@ -999,7 +1004,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do _ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd) ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do secure rq senderKey - void $ enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO + void $ enqueueMessage c cData sq Nothing SMP.MsgFlags {notification = True} HELLO -- ICDeleteConn is no longer used, but it can be present in old client databases ICDeleteConn -> withStore' c (`deleteCommand` cmdId) ICDeleteRcvQueue rId -> withServer $ \srv -> tryWithLock "ICDeleteRcvQueue" $ do @@ -1014,7 +1019,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do Just rq1 -> when (status == Confirmed) $ do secureQueue c rq' senderKey withStore' c $ \db -> setRcvQueueStatus db rq' Secured - void . enqueueMessages c cData sqs SMP.noMsgFlags $ QUSE [((server, sndId), True)] + void . enqueueMessages c cData sqs Nothing SMP.noMsgFlags $ QUSE [((server, sndId), True)] rq1' <- withStore' c $ \db -> setRcvSwitchStatus db rq1 $ Just RSSendingQUSE let rqs' = updatedQs rq1' rqs conn' = DuplexConnection cData rqs' sqs @@ -1077,39 +1082,39 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd) -- ^ ^ ^ async command processing / -enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId -enqueueMessages c cData sqs msgFlags aMessage = do +enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) +enqueueMessages c cData sqs pqEnc_ msgFlags aMessage = do when (ratchetSyncSendProhibited cData) $ throwError $ INTERNAL "enqueueMessages: ratchet is not synchronized" - enqueueMessages' c cData sqs msgFlags aMessage + enqueueMessages' c cData sqs pqEnc_ msgFlags aMessage -enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId -enqueueMessages' c cData sqs msgFlags aMessage = - liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, msgFlags, aMessage))) +enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) +enqueueMessages' c cData sqs pqEnc_ msgFlags aMessage = + liftEither . runIdentity =<< enqueueMessagesB c pqEnc_ (Identity (Right (cData, sqs, msgFlags, aMessage))) -enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType AgentMsgId)) -enqueueMessagesB c reqs = do - reqs' <- enqueueMessageB c reqs +enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> Maybe CR.PQEncryption -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +enqueueMessagesB c pqEnc_ reqs = do + reqs' <- enqueueMessageB c pqEnc_ reqs enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs' pure $ fst <$$> reqs' isActiveSndQ :: SndQueue -> Bool isActiveSndQ SndQueue {status} = status == Secured || status == Active -enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId -enqueueMessage c cData sq msgFlags aMessage = - liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], msgFlags, aMessage))) +enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) +enqueueMessage c cData sq pqEnc_ msgFlags aMessage = + liftEither . fmap fst . runIdentity =<< enqueueMessageB c pqEnc_ (Identity (Right (cData, [sq], msgFlags, aMessage))) -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId)))) -enqueueMessageB c reqs = do +enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> Maybe CR.PQEncryption -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, CR.PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) +enqueueMessageB c pqEnc_ reqs = do aVRange <- asks $ maxVersion . smpAgentVRange . config reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db aVRange) reqs - forME reqMids $ \((cData, sq :| sqs, _, _), InternalId msgId) -> do + forME reqMids $ \((cData, sq :| sqs, _, _), InternalId msgId, pqSecr) -> do submitPendingMsg c cData sq let sqs' = filter isActiveSndQ sqs - pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId)) + pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId)) + storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId, CR.PQEncryption)) storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId @@ -1117,13 +1122,13 @@ enqueueMessageB c reqs = do agentMsg = AgentMessage privHeader aMessage agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - encAgentMessage <- agentRatchetEncrypt db connId agentMsgStr e2eEncUserMsgLength + (encAgentMessage, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr e2eEncUserMsgLength pqEnc_ let msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash} + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, pqEncryption, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId - pure (req, internalId) + pure (req, internalId, pqEncryption) enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId) @@ -1166,7 +1171,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork atomically $ throwWhenNoDelivery c sq atomically $ beginAgentOperation c AOSndNetwork withWork c doWork (\db -> getPendingQueueMsg db connId sq) $ - \(rq_, PendingMsgData {msgId, msgType, msgBody, msgFlags, msgRetryState, internalTs}) -> do + \(rq_, PendingMsgData {msgId, msgType, msgBody, pqEncryption, msgFlags, msgRetryState, internalTs}) -> do atomically $ endAgentOperation c AOMsgDelivery -- this operation begins in submitPendingMsg let mId = unId msgId ri' = maybe id updateRetryInterval2 msgRetryState ri @@ -1236,7 +1241,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork -- it would lead to the non-deterministic internal ID of the first sent message, at to some other race conditions, -- because it can be sent before HELLO is received -- With `status == Active` condition, CON is sent here only by the accepting party, that previously received HELLO - when (status == Active) $ notify CON + when (status == Active) $ notify $ CON pqEncryption -- this branch should never be reached as receive queue is created before the confirmation, _ -> logError "HELLO sent without receive queue" AM_A_MSG_ -> notify $ SENT mId @@ -1335,7 +1340,7 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do when (connAgentVersion >= deliveryRcptsSMPAgentVersion) $ do let RcvMsg {msgMeta = MsgMeta {sndMsgId}, internalHash} = msg rcpt = A_RCVD [AMessageReceipt {agentMsgId = sndMsgId, msgHash = internalHash, rcptInfo}] - void $ enqueueMessages c cData sqs SMP.MsgFlags {notification = False} rcpt + void $ enqueueMessages c cData sqs Nothing SMP.MsgFlags {notification = False} rcpt Nothing -> case (msgType, msgReceipt) of -- only remove sent message if receipt hash was Ok, both to debug and for future redundancy (AM_A_RCVD_, Just MsgReceipt {agentMsgId = sndMsgId, msgRcptStatus = MROk}) -> @@ -1365,7 +1370,7 @@ switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs s let rq' = (q :: NewRcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} rq'' <- withStore c $ \db -> addConnRcvQueue db connId rq' addSubscription c rq'' - void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] + void . enqueueMessages c cData sqs Nothing SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSendingQADD let rqs' = updatedQs rq1 rqs <> [rq''] pure . connectionStats $ DuplexConnection cData rqs' sqs @@ -1394,19 +1399,19 @@ abortConnectionSwitch' c connId = _ -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED -synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> Bool -> m ConnectionStats -synchronizeRatchet' c connId force = withConnLock c connId "synchronizeRatchet" $ do +synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> CR.PQEncryption -> Bool -> m ConnectionStats +synchronizeRatchet' c connId pqEnc force = withConnLock c connId "synchronizeRatchet" $ do withStore c (`getConn` connId) >>= \case SomeConn _ (DuplexConnection cData rqs sqs) | ratchetSyncAllowed cData || force -> do -- check queues are not switching? AgentConfig {e2eEncryptVRange} <- asks config g <- asks random - (pk1, pk2, e2eParams) <- atomically . CR.generateE2EParams g $ maxVersion e2eEncryptVRange + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eEncryptVRange) pqEnc enqueueRatchetKeyMsgs c cData sqs e2eParams withStore' c $ \db -> do setConnRatchetSync db connId RSStarted - setRatchetX3dhKeys db connId pk1 pk2 + setRatchetX3dhKeys db connId pk1 pk2 pKem let cData' = cData {ratchetSyncState = RSStarted} :: ConnData conn' = DuplexConnection cData' rqs sqs pure $ connectionStats conn' @@ -1938,7 +1943,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') where queueDrained = case conn of - DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq) + DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs Nothing SMP.noMsgFlags $ QCONT (sndAddress rq) _ -> pure () processClientMsg srvTs msgFlags msgBody = do clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <- @@ -1981,7 +1986,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do conn'' <- resetRatchetSync case aMessage of - HELLO -> helloMsg srvMsgId conn'' >> ackDel msgId + HELLO -> helloMsg srvMsgId msgMeta conn'' >> ackDel msgId -- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command A_MSG body -> do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId @@ -2041,7 +2046,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, agentClientMsg :: TVar ChaChaDRG -> ByteString -> m (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448)) agentClientMsg g encryptedMsgHash = withStore c $ \db -> runExceptT $ do rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY - agentMsgBody <- agentRatchetDecrypt' g db connId rc encAgentMessage + (agentMsgBody, pqEncryption) <- agentRatchetDecrypt' g db connId rc encAgentMessage liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do let msgType = agentMessageType agentMsg @@ -2051,7 +2056,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash recipient = (unId internalId, internalTs) broker = (srvMsgId, systemToUTCTime srvTs) - msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId} + msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId, pqEncryption} rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash, encryptedMsgHash} liftIO $ createRcvMsg db connId rq rcvMsg pure $ Just (internalId, msgMeta, aMessage, rc) @@ -2121,7 +2126,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, parseMessage :: Encoding a => ByteString -> m a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) - smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.E2ERatchetParams 'C.X448) -> ByteString -> Version -> Version -> m () + smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> Version -> Version -> m () smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config @@ -2131,10 +2136,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, case status of New -> case (conn', e2eEncryption) of -- party initiating connection - (RcvConnection {}, Just e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _)) -> do + (RcvConnection ConnData {pqEncryption} _, Just (CR.AE2ERatchetParams _ e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _ _))) -> do unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION) - (pk1, rcDHRs) <- withStore c (`getRatchetX3dhKeys` connId) - let rc = CR.initRcvRatchet e2eEncryptVRange rcDHRs $ CR.x3dhRcv pk1 rcDHRs e2eSndParams + (pk1, rcDHRs, pKem) <- withStore c (`getRatchetX3dhKeys` connId) + rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams + let rc = CR.initRcvRatchet e2eEncryptVRange rcDHRs rcParams pqEncryption g <- asks random (agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt g rc M.empty encConnInfo case (agentMsgBody_, skipped) of @@ -2155,7 +2161,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- party accepting connection (DuplexConnection _ (RcvQueue {smpClientVersion = v'} :| _) _, Nothing) -> do g <- asks random - withStore c (\db -> runExceptT $ agentRatchetDecrypt g db connId encConnInfo) >>= parseMessage >>= \case + withStore c (\db -> runExceptT $ agentRatchetDecrypt g db connId encConnInfo) >>= parseMessage . fst >>= \case AgentConnInfo connInfo -> do notify $ INFO connInfo let dhSecret = C.dh' e2ePubKey e2ePrivKey @@ -2165,8 +2171,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> prohibited _ -> prohibited - helloMsg :: SMP.MsgId -> Connection c -> m () - helloMsg srvMsgId conn' = do + helloMsg :: SMP.MsgId -> MsgMeta -> Connection c -> m () + helloMsg srvMsgId MsgMeta {pqEncryption} conn' = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId case status of Active -> prohibited @@ -2176,14 +2182,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- `sndStatus == Active` when HELLO was previously sent, and this is the reply HELLO -- this branch is executed by the accepting party in duplexHandshake mode (v2) -- (was executed by initiating party in v1 that is no longer supported) - | sndStatus == Active -> notify CON + -- + -- TODO PQ encryption mode + | sndStatus == Active -> notify $ CON pqEncryption | otherwise -> enqueueDuplexHello sq _ -> pure () where enqueueDuplexHello :: SndQueue -> m () enqueueDuplexHello sq = do let cData' = toConnData conn' - void $ enqueueMessage c cData' sq SMP.MsgFlags {notification = True} HELLO + void $ enqueueMessage c cData' sq Nothing SMP.MsgFlags {notification = True} HELLO continueSending :: SMP.MsgId -> (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m () continueSending srvMsgId addr (DuplexConnection _ _ sqs) = @@ -2240,7 +2248,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, (Just sndPubKey, Just dhPublicKey) -> do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId <> " " <> logSecret (senderId queueAddress) let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}} - void . enqueueMessages c cData' sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)] + void . enqueueMessages c cData' sqs Nothing SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)] sq1 <- withStore' c $ \db -> setSndSwitchStatus db sq $ Just SSSendingQKEY let sqs'' = updatedQs sq1 sqs' <> [sq2] conn' = DuplexConnection cData' rqs sqs'' @@ -2285,7 +2293,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, withStore' c $ \db -> setSndQueueStatus db sq' Secured let sq'' = (sq' :: SndQueue) {status = Secured} -- sending QTEST to the new queue only, the old one will be removed if sent successfully - void $ enqueueMessages c cData' [sq''] SMP.noMsgFlags $ QTEST [addr] + void $ enqueueMessages c cData' [sq''] Nothing SMP.noMsgFlags $ QTEST [addr] sq1' <- withStore' c $ \db -> setSndSwitchStatus db sq1 $ Just SSSendingQTEST let sqs' = updatedQs sq1' sqs conn' = DuplexConnection cData' rqs sqs' @@ -2301,7 +2309,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, let CR.Ratchet {rcSnd} = rcPrev -- if ratchet was initialized as receiving, it means EREADY wasn't sent on key negotiation when (isNothing rcSnd) . void $ - enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} (EREADY lastExternalSndId) + enqueueMessages' c cData' sqs Nothing SMP.MsgFlags {notification = True} (EREADY lastExternalSndId) smpInvitation :: SMP.MsgId -> Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () smpInvitation srvMsgId conn' connReq@(CRInvitationUri crData _) cInfo = do @@ -2320,8 +2328,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, DuplexConnection {} -> action conn' _ -> qError $ name <> ": message must be sent to duplex connection" - newRatchetKey :: CR.E2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () - newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = + newRatchetKey :: CR.RcvE2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () + newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv kem_) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = unlessM ratchetExists $ do AgentConfig {e2eEncryptVRange} <- asks config unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION) @@ -2336,7 +2344,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, exists <- checkRatchetKeyHashExists db connId rkHashRcv unless exists $ addProcessedRatchetKeyHash db connId rkHashRcv pure exists - getSendRatchetKeys :: m (C.PrivateKeyX448, C.PrivateKeyX448) + getSendRatchetKeys :: m (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) getSendRatchetKeys = case rss of RSOk -> sendReplyKey -- receiving client RSAllowed -> sendReplyKey @@ -2352,9 +2360,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, where sendReplyKey = do g <- asks random - (pk1, pk2, e2eParams) <- atomically . CR.generateE2EParams g $ version e2eOtherPartyParams + -- TODO PQ the decision to use KEM should depend on connection + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g e2eVersion CR.PQEncOn enqueueRatchetKeyMsgs c cData' sqs e2eParams - pure (pk1, pk2) + pure (pk1, pk2, pKem) notifyRatchetSyncError = do let cData'' = cData' {ratchetSyncState = RSRequired} :: ConnData conn'' = updateConnection cData'' conn' @@ -2371,14 +2380,17 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, createRatchet db connId rc -- compare public keys `k1` in AgentRatchetKey messages sent by self and other party -- to determine ratchet initilization ordering - initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448) -> m () - initRatchet e2eEncryptVRange (pk1, pk2) + initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m () + initRatchet e2eEncryptVRange (pk1, pk2, pKem) | rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do - recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams + rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 pk2 pKem e2eOtherPartyParams + -- TODO PQ the decision to use KEM should either depend on the global setting or on whether it was enabled in connection before + recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 rcParams $ CR.PQEncryption (isJust kem_) | otherwise = do (_, rcDHRs) <- atomically . C.generateKeyPair =<< asks random - recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams - void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId + rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 (CR.APRKP CR.SRKSProposed <$> pKem) e2eOtherPartyParams + recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs rcParams + void . enqueueMessages' c cData' sqs Nothing SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash @@ -2412,15 +2424,15 @@ connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = Just qInfo' -> do sq <- newSndQueue userId connId qInfo' sq' <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq - enqueueConfirmation c cData sq' ownConnInfo Nothing + enqueueConfirmation c cData sq' ownConnInfo Nothing Nothing -confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m () -confirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do - storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode +confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> CR.PQEncryption -> SubscriptionMode -> m () +confirmQueueAsync c cData sq srv connInfo e2eEncryption_ pqEnc subMode = do + storeConfirmation c cData sq e2eEncryption_ (Just pqEnc) =<< mkAgentConfirmation c cData sq srv connInfo subMode submitPendingMsg c cData sq -confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ subMode = do +confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> SubscriptionMode -> m () +confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ pqEnc_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed @@ -2428,7 +2440,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo mkConfirmation :: AgentMessage -> m MsgBody mkConfirmation aMessage = withStore c $ \db -> runExceptT $ do void . liftIO $ updateSndIds db connId - encConnInfo <- agentRatchetEncrypt db connId (smpEncode aMessage) e2eEncConnInfoLength + (encConnInfo, _) <- agentRatchetEncrypt db connId (smpEncode aMessage) e2eEncConnInfoLength pqEnc_ pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} mkAgentConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> m AgentMessage @@ -2436,30 +2448,30 @@ mkAgentConfirmation c cData sq srv connInfo subMode = do qInfo <- createReplyQueue c cData sq subMode srv pure $ AgentConnInfoReply (qInfo :| []) connInfo -enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () -enqueueConfirmation c cData sq connInfo e2eEncryption_ = do - storeConfirmation c cData sq e2eEncryption_ $ AgentConnInfo connInfo +enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> m () +enqueueConfirmation c cData sq connInfo e2eEncryption_ pqEnc_ = do + storeConfirmation c cData sq e2eEncryption_ pqEnc_ $ AgentConnInfo connInfo submitPendingMsg c cData sq -storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.E2ERatchetParams 'C.X448) -> AgentMessage -> m () -storeConfirmation c ConnData {connId, connAgentVersion} sq e2eEncryption_ agentMsg = withStore c $ \db -> runExceptT $ do +storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> AgentMessage -> m () +storeConfirmation c ConnData {connId, connAgentVersion} sq e2eEncryption_ pqEnc_ agentMsg = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - encConnInfo <- agentRatchetEncrypt db connId agentMsgStr e2eEncConnInfoLength + (encConnInfo, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr e2eEncConnInfoLength pqEnc_ let msgBody = smpEncode $ AgentConfirmation {agentVersion = connAgentVersion, e2eEncryption_, encConnInfo} msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId -enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.E2ERatchetParams 'C.X448 -> m () +enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m () enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do msgId <- enqueueRatchetKey c cData sq e2eEncryption mapM_ (enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs -enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId +enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m AgentMsgId enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do aVRange <- asks $ smpAgentVRange . config msgId <- storeRatchetKey $ maxVersion aVRange @@ -2475,31 +2487,32 @@ enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do internalHash = C.sha256Hash agentMsgStr let msgBody = smpEncode $ AgentRatchetKey {agentVersion, e2eEncryption, info = agentMsgStr} msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} + -- TODO PQ set pqEncryption based on connection mode + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption = CR.PQEncOff, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId pure internalId -- encoded AgentMessage -> encoded EncAgentMessage -agentRatchetEncrypt :: DB.Connection -> ConnId -> ByteString -> Int -> ExceptT StoreError IO ByteString -agentRatchetEncrypt db connId msg paddedLen = do +agentRatchetEncrypt :: DB.Connection -> ConnId -> ByteString -> Int -> Maybe CR.PQEncryption -> ExceptT StoreError IO (ByteString, CR.PQEncryption) +agentRatchetEncrypt db connId msg paddedLen pqEnc_ = do rc <- ExceptT $ getRatchet db connId - (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg + (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ liftIO $ updateRatchet db connId rc' CR.SMDNoChange - pure encMsg + pure (encMsg, CR.rcSndKEM rc') -- encoded EncAgentMessage -> encoded AgentMessage -agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO ByteString +agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, CR.PQEncryption) agentRatchetDecrypt g db connId encAgentMsg = do rc <- ExceptT $ getRatchet db connId agentRatchetDecrypt' g db connId rc encAgentMsg -agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO ByteString +agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, CR.PQEncryption) agentRatchetDecrypt' g db connId rc encAgentMsg = do skipped <- liftIO $ getSkippedMsgKeys db connId (agentMsgBody_, rc', skippedDiff) <- liftE (SEAgentError . cryptoError) $ CR.rcDecrypt g rc skipped encAgentMsg liftIO $ updateRatchet db connId rc' skippedDiff - liftEither $ first (SEAgentError . cryptoError) agentMsgBody_ + liftEither $ bimap (SEAgentError . cryptoError) (,CR.rcRcvKEM rc') agentMsgBody_ newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m NewSndQueue newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 6129b8503..c17ecf4cf 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -166,7 +166,7 @@ import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map (Map) import qualified Data.Map as M -import Data.Maybe (fromMaybe, isJust) +import Data.Maybe (fromMaybe) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -182,7 +182,7 @@ import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType) import Simplex.Messaging.Agent.QueryString import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (E2ERatchetParams, E2ERatchetParamsUri) +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOff, RcvE2ERatchetParams, RcvE2ERatchetParamsUri, SndE2ERatchetParams) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers @@ -243,11 +243,15 @@ supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentS -- it is shorter to allow all handshake headers, -- including E2E (double-ratchet) parameters and -- signing key of the sender for the server +-- TODO PQ this should be version-dependent +-- previously it was 14848, reduced by 3700 (roughly the increase of message ratchet header size + key and ciphertext in reply link) e2eEncConnInfoLength :: Int -e2eEncConnInfoLength = 14848 +e2eEncConnInfoLength = 11148 +-- TODO PQ this should be version-dependent +-- previously it was 15856, reduced by 2200 (roughly the increase of message ratchet header size) e2eEncUserMsgLength :: Int -e2eEncUserMsgLength = 15856 +e2eEncUserMsgLength = 13656 -- | Raw (unparsed) SMP agent protocol transmission. type ARawTransmission = (ByteString, ByteString, ByteString) @@ -273,8 +277,6 @@ data SAParty :: AParty -> Type where deriving instance Show (SAParty p) -deriving instance Eq (SAParty p) - instance TestEquality SAParty where testEquality SAgent SAgent = Just Refl testEquality SClient SClient = Just Refl @@ -297,8 +299,6 @@ data SAEntity :: AEntity -> Type where deriving instance Show (SAEntity e) -deriving instance Eq (SAEntity e) - instance TestEquality SAEntity where testEquality SAEConn SAEConn = Just Refl testEquality SAERcvFile SAERcvFile = Just Refl @@ -322,27 +322,22 @@ deriving instance Show ACmd data APartyCmd p = forall e. AEntityI e => APC (SAEntity e) (ACommand p e) -instance Eq (APartyCmd p) where - APC e cmd == APC e' cmd' = case testEquality e e' of - Just Refl -> cmd == cmd' - Nothing -> False - deriving instance Show (APartyCmd p) type ConnInfo = ByteString -- | Parameterized type for SMP agent protocol commands and responses from all participants. data ACommand (p :: AParty) (e :: AEntity) where - NEW :: Bool -> AConnectionMode -> SubscriptionMode -> ACommand Client AEConn -- response INV + NEW :: Bool -> AConnectionMode -> InitialKeys -> SubscriptionMode -> ACommand Client AEConn -- response INV INV :: AConnectionRequestUri -> ACommand Agent AEConn - JOIN :: Bool -> AConnectionRequestUri -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK + JOIN :: Bool -> AConnectionRequestUri -> PQEncryption -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake LET :: ConfirmationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender - ACPT :: InvitationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client + ACPT :: InvitationId -> PQEncryption -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client RJCT :: InvitationId -> ACommand Client AEConn INFO :: ConnInfo -> ACommand Agent AEConn - CON :: ACommand Agent AEConn -- notification that connection is established + CON :: PQEncryption -> ACommand Agent AEConn -- notification that connection is established SUB :: ACommand Client AEConn END :: ACommand Agent AEConn CONNECT :: AProtocolType -> TransportHost -> ACommand Agent AENone @@ -351,8 +346,8 @@ data ACommand (p :: AParty) (e :: AEntity) where UP :: SMPServer -> [ConnId] -> ACommand Agent AENone SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> ACommand Agent AEConn RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> ACommand Agent AEConn - SEND :: MsgFlags -> MsgBody -> ACommand Client AEConn - MID :: AgentMsgId -> ACommand Agent AEConn + SEND :: PQEncryption -> MsgFlags -> MsgBody -> ACommand Client AEConn + MID :: AgentMsgId -> PQEncryption -> ACommand Agent AEConn SENT :: AgentMsgId -> ACommand Agent AEConn MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent AEConn MERRS :: NonEmpty AgentMsgId -> AgentErrorType -> ACommand Agent AEConn @@ -379,19 +374,12 @@ data ACommand (p :: AParty) (e :: AEntity) where SFDONE :: ValidFileDescription 'FSender -> [ValidFileDescription 'FRecipient] -> ACommand Agent AESndFile SFERR :: AgentErrorType -> ACommand Agent AESndFile -deriving instance Eq (ACommand p e) - deriving instance Show (ACommand p e) data ACmdTag = forall p e. (APartyI p, AEntityI e) => ACmdTag (SAParty p) (SAEntity e) (ACommandTag p e) data APartyCmdTag p = forall e. AEntityI e => APCT (SAEntity e) (ACommandTag p e) -instance Eq (APartyCmdTag p) where - APCT e cmd == APCT e' cmd' = case testEquality e e' of - Just Refl -> cmd == cmd' - Nothing -> False - deriving instance Show (APartyCmdTag p) data ACommandTag (p :: AParty) (e :: AEntity) where @@ -441,8 +429,6 @@ data ACommandTag (p :: AParty) (e :: AEntity) where SFDONE_ :: ACommandTag Agent AESndFile SFERR_ :: ACommandTag Agent AESndFile -deriving instance Eq (ACommandTag p e) - deriving instance Show (ACommandTag p e) aPartyCmdTag :: APartyCmd p -> APartyCmdTag p @@ -458,8 +444,8 @@ aCommandTag = \case REQ {} -> REQ_ ACPT {} -> ACPT_ RJCT _ -> RJCT_ - INFO _ -> INFO_ - CON -> CON_ + INFO {} -> INFO_ + CON _ -> CON_ SUB -> SUB_ END -> END_ CONNECT {} -> CONNECT_ @@ -469,7 +455,7 @@ aCommandTag = \case SWITCH {} -> SWITCH_ RSYNC {} -> RSYNC_ SEND {} -> SEND_ - MID _ -> MID_ + MID {} -> MID_ SENT _ -> SENT_ MERR {} -> MERR_ MERRS {} -> MERRS_ @@ -726,8 +712,6 @@ data SConnectionMode (m :: ConnectionMode) where SCMInvitation :: SConnectionMode CMInvitation SCMContact :: SConnectionMode CMContact -deriving instance Eq (SConnectionMode m) - deriving instance Show (SConnectionMode m) instance TestEquality SConnectionMode where @@ -737,9 +721,6 @@ instance TestEquality SConnectionMode where data AConnectionMode = forall m. ConnectionModeI m => ACM (SConnectionMode m) -instance Eq AConnectionMode where - ACM m == ACM m' = isJust $ testEquality m m' - cmInvitation :: AConnectionMode cmInvitation = ACM SCMInvitation @@ -769,17 +750,19 @@ data MsgMeta = MsgMeta { integrity :: MsgIntegrity, recipient :: (AgentMsgId, UTCTime), broker :: (MsgId, UTCTime), - sndMsgId :: AgentMsgId + sndMsgId :: AgentMsgId, + pqEncryption :: PQEncryption } deriving (Eq, Show) instance StrEncoding MsgMeta where - strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId} = + strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId, pqEncryption} = B.unwords [ strEncode integrity, "R=" <> bshow rmId <> "," <> showTs rTs, "B=" <> encode bmId <> "," <> showTs bTs, - "S=" <> bshow sndMsgId + "S=" <> bshow sndMsgId, + "PQ=" <> strEncode pqEncryption ] where showTs = B.pack . formatISO8601Millis @@ -788,7 +771,8 @@ instance StrEncoding MsgMeta where recipient <- " R=" *> partyMeta A.decimal broker <- " B=" *> partyMeta base64P sndMsgId <- " S=" *> A.decimal - pure MsgMeta {integrity, recipient, broker, sndMsgId} + pqEncryption <- " PQ=" *> strP + pure MsgMeta {integrity, recipient, broker, sndMsgId, pqEncryption} where partyMeta idParser = (,) <$> idParser <* A.char ',' <*> tsISO8601P @@ -809,7 +793,7 @@ data SMPConfirmation = SMPConfirmation data AgentMsgEnvelope = AgentConfirmation { agentVersion :: Version, - e2eEncryption_ :: Maybe (E2ERatchetParams 'C.X448), + e2eEncryption_ :: Maybe (SndE2ERatchetParams 'C.X448), encConnInfo :: ByteString } | AgentMsgEnvelope @@ -823,7 +807,7 @@ data AgentMsgEnvelope } | AgentRatchetKey { agentVersion :: Version, - e2eEncryption :: E2ERatchetParams 'C.X448, + e2eEncryption :: RcvE2ERatchetParams 'C.X448, info :: ByteString } deriving (Show) @@ -1115,7 +1099,7 @@ instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) whe CRInvitationUri crData e2eParams -> crEncode "invitation" crData (Just e2eParams) CRContactUri crData -> crEncode "contact" crData Nothing where - crEncode :: ByteString -> ConnReqUriData -> Maybe (E2ERatchetParamsUri 'C.X448) -> ByteString + crEncode :: ByteString -> ConnReqUriData -> Maybe (RcvE2ERatchetParamsUri 'C.X448) -> ByteString crEncode crMode ConnReqUriData {crScheme, crAgentVRange, crSmpQueues, crClientData} e2eParams = strEncode crScheme <> "/" <> crMode <> "#/?" <> queryStr where @@ -1324,22 +1308,15 @@ instance Encoding SMPQueueUri where pure $ SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey} data ConnectionRequestUri (m :: ConnectionMode) where - CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation - -- contact connection request does NOT contain E2E encryption parameters - + CRInvitationUri :: ConnReqUriData -> RcvE2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation + -- contact connection request does NOT contain E2E encryption parameters for double ratchet - -- they are passed in AgentInvitation message CRContactUri :: ConnReqUriData -> ConnectionRequestUri CMContact -deriving instance Eq (ConnectionRequestUri m) - deriving instance Show (ConnectionRequestUri m) data AConnectionRequestUri = forall m. ConnectionModeI m => ACR (SConnectionMode m) (ConnectionRequestUri m) -instance Eq AConnectionRequestUri where - ACR m cr == ACR m' cr' = case testEquality m m' of - Just Refl -> cr == cr' - _ -> False - deriving instance Show AConnectionRequestUri data ConnReqUriData = ConnReqUriData @@ -1713,13 +1690,13 @@ commandP binaryP = >>= \case ACmdTag SClient e cmd -> ACmd SClient e <$> case cmd of - NEW_ -> s (NEW <$> strP_ <*> strP_ <*> (strP <|> pure SMP.SMSubscribe)) - JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP) + NEW_ -> s (NEW <$> strP_ <*> strP_ <*> pqIKP <*> (strP <|> pure SMP.SMSubscribe)) + JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> pqEncP <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP) LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP) - ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> binaryP) + ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> pqEncP <*> binaryP) RJCT_ -> s (RJCT <$> A.takeByteString) SUB_ -> pure SUB - SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP) + SEND_ -> s (SEND <$> pqEncP <*> smpP <* A.space <*> binaryP) ACK_ -> s (ACK <$> A.decimal <*> optional (A.space *> binaryP)) SWCH_ -> pure SWCH OFF_ -> pure OFF @@ -1731,7 +1708,7 @@ commandP binaryP = CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP) REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP) INFO_ -> s (INFO <$> binaryP) - CON_ -> pure CON + CON_ -> s (CON <$> strP) END_ -> pure END CONNECT_ -> s (CONNECT <$> strP_ <*> strP) DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP) @@ -1739,7 +1716,7 @@ commandP binaryP = UP_ -> s (UP <$> strP_ <*> connections) SWITCH_ -> s (SWITCH <$> strP_ <*> strP_ <*> strP) RSYNC_ -> s (RSYNC <$> strP_ <*> strP <*> strP) - MID_ -> s (MID <$> A.decimal) + MID_ -> s (MID <$> A.decimal <*> _strP) SENT_ -> s (SENT <$> A.decimal) MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP) MERRS_ -> s (MERRS <$> strP_ <*> strP) @@ -1762,6 +1739,10 @@ commandP binaryP = where s :: Parser a -> Parser a s p = A.space *> p + pqIKP :: Parser InitialKeys + pqIKP = strP_ <|> pure (IKNoPQ PQEncOff) + pqEncP :: Parser PQEncryption + pqEncP = strP_ <|> pure PQEncOff connections :: Parser [ConnId] connections = strP `A.sepBy'` A.char ',' sfDone :: Text -> Either String (ACommand 'Agent 'AESndFile) @@ -1777,13 +1758,13 @@ parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX -- | Serialize SMP agent command. serializeCommand :: ACommand p e -> ByteString serializeCommand = \case - NEW ntfs cMode subMode -> s (NEW_, ntfs, cMode, subMode) + NEW ntfs cMode pqIK subMode -> s (NEW_, ntfs, cMode, pqIK, subMode) INV cReq -> s (INV_, cReq) - JOIN ntfs cReq subMode cInfo -> s (JOIN_, ntfs, cReq, subMode, Str $ serializeBinary cInfo) + JOIN ntfs cReq pqEnc subMode cInfo -> s (JOIN_, ntfs, cReq, pqEnc, subMode, Str $ serializeBinary cInfo) CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo] LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo] REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo] - ACPT invId cInfo -> B.unwords [s ACPT_, invId, serializeBinary cInfo] + ACPT invId pqEnc cInfo -> B.unwords [s ACPT_, invId, s pqEnc, serializeBinary cInfo] RJCT invId -> B.unwords [s RJCT_, invId] INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo] SUB -> s SUB_ @@ -1794,8 +1775,8 @@ serializeCommand = \case UP srv conns -> B.unwords [s UP_, s srv, connections conns] SWITCH dir phase srvs -> s (SWITCH_, dir, phase, srvs) RSYNC rrState cryptoErr cstats -> s (RSYNC_, rrState, cryptoErr, cstats) - SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody] - MID mId -> s (MID_, mId) + SEND pqEnc msgFlags msgBody -> B.unwords [s SEND_, s pqEnc, smpEncode msgFlags, serializeBinary msgBody] + MID mId pqEnc -> s (MID_, mId, pqEnc) SENT mId -> s (SENT_, mId) MERR mId e -> s (MERR_, mId, e) MERRS mIds e -> s (MERRS_, mIds, e) @@ -1811,7 +1792,7 @@ serializeCommand = \case DEL_USER userId -> s (DEL_USER_, userId) CHK -> s CHK_ STAT srvs -> s (STAT_, srvs) - CON -> s CON_ + CON pqEnc -> s (CON_, pqEnc) ERR e -> s (ERR_, e) OK -> s OK_ SUSPENDED -> s SUSPENDED_ @@ -1884,13 +1865,13 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p)) cmdWithMsgBody (APC e cmd) = APC e <$$> case cmd of - SEND msgFlags body -> SEND msgFlags <$$> getBody body + SEND kem msgFlags body -> SEND kem msgFlags <$$> getBody body MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body - JOIN ntfs qUri subMode cInfo -> JOIN ntfs qUri subMode <$$> getBody cInfo + JOIN ntfs qUri kem subMode cInfo -> JOIN ntfs qUri kem subMode <$$> getBody cInfo CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo LET confId cInfo -> LET confId <$$> getBody cInfo REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo - ACPT invId cInfo -> ACPT invId <$$> getBody cInfo + ACPT invId kem cInfo -> ACPT invId kem <$$> getBody cInfo INFO cInfo -> INFO <$$> getBody cInfo _ -> pure $ Right cmd diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 8f67c74c2..07112a836 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -30,7 +30,7 @@ import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval (RI2State) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (RatchetX448) +import Simplex.Messaging.Crypto.Ratchet (RatchetX448, PQEncryption) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol ( MsgBody, @@ -61,8 +61,6 @@ data DBQueueId (q :: QueueStored) where DBQueueId :: Int64 -> DBQueueId 'QSStored DBNewQueue :: DBQueueId 'QSNew -deriving instance Eq (DBQueueId q) - deriving instance Show (DBQueueId q) type RcvQueue = StoredRcvQueue 'QSStored @@ -101,7 +99,7 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue clientNtfCreds :: Maybe ClientNtfCreds, deleteErrors :: Int } - deriving (Eq, Show) + deriving (Show) rcvQueueInfo :: RcvQueue -> RcvQueueInfo rcvQueueInfo rq@RcvQueue {server, rcvSwchStatus} = @@ -128,7 +126,7 @@ data ClientNtfCreds = ClientNtfCreds -- | shared DH secret used to encrypt/decrypt notification metadata (NMsgMeta) from server to recipient rcvNtfDhSecret :: RcvNtfDhSecret } - deriving (Eq, Show) + deriving (Show) type SndQueue = StoredSndQueue 'QSStored @@ -161,7 +159,7 @@ data StoredSndQueue (q :: QueueStored) = SndQueue -- | SMP client version smpClientVersion :: Version } - deriving (Eq, Show) + deriving (Show) sndQueueInfo :: SndQueue -> SndQueueInfo sndQueueInfo SndQueue {server, sndSwchStatus} = @@ -256,8 +254,6 @@ data Connection (d :: ConnType) where DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex ContactConnection :: ConnData -> RcvQueue -> Connection CContact -deriving instance Eq (Connection d) - deriving instance Show (Connection d) toConnData :: Connection d -> ConnData @@ -290,8 +286,6 @@ connType SCSnd = CSnd connType SCDuplex = CDuplex connType SCContact = CContact -deriving instance Eq (SConnType d) - deriving instance Show (SConnType d) instance TestEquality SConnType where @@ -305,11 +299,6 @@ instance TestEquality SConnType where -- Used to refer to an arbitrary connection when retrieving from store. data SomeConn = forall d. SomeConn (SConnType d) (Connection d) -instance Eq SomeConn where - SomeConn d c == SomeConn d' c' = case testEquality d d' of - Just Refl -> c == c' - _ -> False - deriving instance Show SomeConn data ConnData = ConnData @@ -319,7 +308,8 @@ data ConnData = ConnData enableNtfs :: Bool, lastExternalSndId :: PrevExternalSndId, deleted :: Bool, - ratchetSyncState :: RatchetSyncState + ratchetSyncState :: RatchetSyncState, + pqEncryption :: PQEncryption } deriving (Eq, Show) @@ -534,6 +524,7 @@ data SndMsgData = SndMsgData msgType :: AgentMessageType, msgFlags :: MsgFlags, msgBody :: MsgBody, + pqEncryption :: PQEncryption, internalHash :: MsgHash, prevMsgHash :: MsgHash } @@ -551,6 +542,7 @@ data PendingMsgData = PendingMsgData msgType :: AgentMessageType, msgFlags :: MsgFlags, msgBody :: MsgBody, + pqEncryption :: PQEncryption, msgRetryState :: Maybe RI2State, internalTs :: InternalTs } diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 647e8bd03..63ac4a280 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -269,6 +269,7 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) @@ -575,14 +576,14 @@ createSndConn db gVar cData q@SndQueue {server} = insertSndQueue_ db connId q serverKeyHash_ createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO () -createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs} cMode = +createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqEncryption} cMode = DB.execute db [sql| INSERT INTO connections - (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?) + (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_encryption, duplex_handshake) VALUES (?,?,?,?,?,?,?) |] - (userId, connId, cMode, connAgentVersion, enableNtfs, True) + (userId, connId, cMode, connAgentVersion, enableNtfs, pqEncryption, True) checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do @@ -1028,18 +1029,18 @@ getPendingQueueMsg db connId SndQueue {dbQueueId} = DB.query db [sql| - SELECT m.msg_type, m.msg_flags, m.msg_body, m.internal_ts, s.retry_int_slow, s.retry_int_fast + SELECT m.msg_type, m.msg_flags, m.msg_body, m.pq_encryption, m.internal_ts, s.retry_int_slow, s.retry_int_fast FROM messages m JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id WHERE m.conn_id = ? AND m.internal_id = ? |] (connId, msgId) err = SEInternal $ "msg delivery " <> bshow msgId <> " returned []" - pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData - pendingMsgData (msgType, msgFlags_, msgBody, internalTs, riSlow_, riFast_) = + pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, CR.PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData + pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) = let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_ msgRetryState = RI2State <$> riSlow_ <*> riFast_ - in PendingMsgData {msgId, msgType, msgFlags, msgBody, msgRetryState, internalTs} + in PendingMsgData {msgId, msgType, msgFlags, msgBody, pqEncryption, msgRetryState, internalTs} markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId) getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a)) @@ -1108,7 +1109,7 @@ getRcvMsg db connId agentMsgId = [sql| SELECT r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash, - m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack + m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack FROM rcv_messages r JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id @@ -1124,7 +1125,7 @@ getLastMsg db connId msgId = [sql| SELECT r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash, - m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack + m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack FROM rcv_messages r JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id JOIN connections c ON r.conn_id = c.conn_id AND c.last_internal_msg_id = r.internal_id @@ -1133,9 +1134,9 @@ getLastMsg db connId msgId = |] (connId, msgId) -toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs, AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg -toRcvMsg (agentMsgId, internalTs, brokerId, brokerTs, sndMsgId, integrity, internalHash, msgType, msgBody, rcptInternalId_, rcptStatus_, userAck) = - let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity} +toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, CR.PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg +toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) = + let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption} msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck} @@ -1195,34 +1196,34 @@ deleteSndMsgsExpired db ttl = do "DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL" (Only cutoffTs) -createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO () -createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 = - DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2) +createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () +createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = + DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem) VALUES (?, ?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2, pqPrivKem) -getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448)) +getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams)) getRatchetX3dhKeys db connId = - fmap hasKeys $ - firstRow id SEX3dhKeysNotFound $ - DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2 FROM ratchets WHERE conn_id = ?" (Only connId) + firstRow' keys SEX3dhKeysNotFound $ + DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem FROM ratchets WHERE conn_id = ?" (Only connId) where - hasKeys = \case - Right (Just k1, Just k2) -> Right (k1, k2) + keys = \case + (Just k1, Just k2, pKem) -> Right (k1, k2, pKem) _ -> Left SEX3dhKeysNotFound -- used to remember new keys when starting ratchet re-synchronization -- TODO remove the columns for public keys in v5.7. -- Currently, the keys are not used but still stored to support app downgrade to the previous version. -setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO () -setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 = +setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () +setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = DB.execute db [sql| UPDATE ratchets - SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ? + SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ? WHERE conn_id = ? |] - (x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, connId) + (x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, pqPrivKem, connId) +-- TODO remove the columns for public keys in v5.7. createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO () createRatchet db connId rc = DB.executeNamed @@ -1233,7 +1234,10 @@ createRatchet db connId rc = ON CONFLICT (conn_id) DO UPDATE SET ratchet_state = :ratchet_state, x3dh_priv_key_1 = NULL, - x3dh_priv_key_2 = NULL + x3dh_priv_key_2 = NULL, + x3dh_pub_key_1 = NULL, + x3dh_pub_key_2 = NULL, + pq_priv_kem = NULL |] [":conn_id" := connId, ":ratchet_state" := rc] @@ -1772,6 +1776,10 @@ instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEn instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 +instance ToField CR.PQEncryption where toField (CR.PQEncryption pqEnc) = toField pqEnc + +instance FromField CR.PQEncryption where fromField f = CR.PQEncryption <$> fromField f + listToEither :: e -> [a] -> Either e a listToEither _ (x : _) = Right x listToEither e _ = Left e @@ -1923,14 +1931,14 @@ getConnData db connId' = [sql| SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, - last_external_snd_msg_id, deleted, ratchet_sync_state + last_external_snd_msg_id, deleted, ratchet_sync_state, pq_encryption FROM connections WHERE conn_id = ? |] (Only connId') where - cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState) = - (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState}, cMode) + cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqEncryption) = + (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqEncryption}, cMode) setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () setConnDeleted db waitDelivery connId @@ -2089,23 +2097,15 @@ updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId = insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, internalRcvId} = do - let MsgMeta {recipient = (internalId, internalTs)} = msgMeta - DB.executeNamed + let MsgMeta {recipient = (internalId, internalTs), pqEncryption} = msgMeta + DB.execute dbConn [sql| INSERT INTO messages - ( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body) - VALUES - (:conn_id,:internal_id,:internal_ts,:internal_rcv_id, NULL,:msg_type,:msg_flags,:msg_body); + (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) + VALUES (?,?,?,?,?,?,?,?,?); |] - [ ":conn_id" := connId, - ":internal_id" := internalId, - ":internal_ts" := internalTs, - ":internal_rcv_id" := internalRcvId, - ":msg_type" := msgType, - ":msg_flags" := msgFlags, - ":msg_body" := msgBody - ] + (connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, msgBody, pqEncryption) insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do @@ -2186,23 +2186,16 @@ updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId = -- * createSndMsg helpers insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO () -insertSndMsgBase_ dbConn connId SndMsgData {..} = do - DB.executeNamed - dbConn +insertSndMsgBase_ db connId SndMsgData {internalId, internalTs, internalSndId, msgType, msgFlags, msgBody, pqEncryption} = do + DB.execute + db [sql| INSERT INTO messages - ( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body) + (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) VALUES - (:conn_id,:internal_id,:internal_ts, NULL,:internal_snd_id,:msg_type,:msg_flags,:msg_body); + (?,?,?,?,?,?,?,?,?); |] - [ ":conn_id" := connId, - ":internal_id" := internalId, - ":internal_ts" := internalTs, - ":internal_snd_id" := internalSndId, - ":msg_type" := msgType, - ":msg_flags" := msgFlags, - ":msg_body" := msgBody - ] + (connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, msgBody, pqEncryption) insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO () insertSndMsgDetails_ dbConn connId SndMsgData {..} = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 2ed79afa3..344a3f9ce 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -70,6 +70,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_ite import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport.Client (TransportHost) @@ -108,7 +109,8 @@ schemaMigrations = ("m20231225_failed_work_items", m20231225_failed_work_items, Just down_m20231225_failed_work_items), ("m20240121_message_delivery_indexes", m20240121_message_delivery_indexes, Just down_m20240121_message_delivery_indexes), ("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect), - ("m20240223_connections_wait_delivery", m20240223_connections_wait_delivery, Just down_m20240223_connections_wait_delivery) + ("m20240223_connections_wait_delivery", m20240223_connections_wait_delivery, Just down_m20240223_connections_wait_delivery), + ("m20240225_ratchet_kem", m20240225_ratchet_kem, Just down_m20240225_ratchet_kem) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs new file mode 100644 index 000000000..07ba0f135 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs @@ -0,0 +1,22 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20240225_ratchet_kem :: Query +m20240225_ratchet_kem = + [sql| +ALTER TABLE ratchets ADD COLUMN pq_priv_kem BLOB; +ALTER TABLE connections ADD COLUMN pq_encryption INTEGER NOT NULL DEFAULT 0; +ALTER TABLE messages ADD COLUMN pq_encryption INTEGER NOT NULL DEFAULT 0; +|] + +down_m20240225_ratchet_kem :: Query +down_m20240225_ratchet_kem = + [sql| +ALTER TABLE ratchets DROP COLUMN pq_priv_kem; +ALTER TABLE connections DROP COLUMN pq_encryption; +ALTER TABLE messages DROP COLUMN pq_encryption; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index 35459042d..850199cbb 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -27,7 +27,8 @@ CREATE TABLE connections( user_id INTEGER CHECK(user_id NOT NULL) REFERENCES users ON DELETE CASCADE, ratchet_sync_state TEXT NOT NULL DEFAULT 'ok', - deleted_at_wait_delivery TEXT + deleted_at_wait_delivery TEXT, + pq_encryption INTEGER NOT NULL DEFAULT 0 ) WITHOUT ROWID; CREATE TABLE rcv_queues( host TEXT NOT NULL, @@ -90,6 +91,7 @@ CREATE TABLE messages( msg_type BLOB NOT NULL, --(H)ELLO,(R)EPLY,(D)ELETE. Should SMP confirmation be saved too? msg_body BLOB NOT NULL DEFAULT x'', msg_flags TEXT NULL, + pq_encryption INTEGER NOT NULL DEFAULT 0, PRIMARY KEY(conn_id, internal_id), FOREIGN KEY(conn_id, internal_rcv_id) REFERENCES rcv_messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, @@ -160,7 +162,8 @@ CREATE TABLE ratchets( e2e_version INTEGER NOT NULL DEFAULT 1 , x3dh_pub_key_1 BLOB, - x3dh_pub_key_2 BLOB + x3dh_pub_key_2 BLOB, + pq_priv_kem BLOB ) WITHOUT ROWID; CREATE TABLE skipped_messages( skipped_message_id INTEGER PRIMARY KEY, diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 9a775faa3..28183a1fc 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -101,6 +101,7 @@ module Simplex.Messaging.Crypto verify, verify', validSignatureSize, + checkAlgorithm, -- * crypto_box authenticator, as discussed in https://groups.google.com/g/sci.crypt/c/73yb5a9pz2Y/m/LNgRO7IYXOwJ CbAuthenticator (..), @@ -243,8 +244,6 @@ data SAlgorithm :: Algorithm -> Type where SX25519 :: SAlgorithm X25519 SX448 :: SAlgorithm X448 -deriving instance Eq (SAlgorithm a) - deriving instance Show (SAlgorithm a) data Alg = forall a. AlgorithmI a => Alg (SAlgorithm a) @@ -297,11 +296,6 @@ data APublicKey AlgorithmI a => APublicKey (SAlgorithm a) (PublicKey a) -instance Eq APublicKey where - APublicKey a k == APublicKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - instance Encoding APublicKey where smpEncode = smpEncode . encodePubKey {-# INLINE smpEncode #-} @@ -342,11 +336,6 @@ data APrivateKey AlgorithmI a => APrivateKey (SAlgorithm a) (PrivateKey a) -instance Eq APrivateKey where - APrivateKey a k == APrivateKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APrivateKey type PrivateKeyEd25519 = PrivateKey Ed25519 @@ -372,11 +361,6 @@ data APrivateSignKey (AlgorithmI a, SignatureAlgorithm a) => APrivateSignKey (SAlgorithm a) (PrivateKey a) -instance Eq APrivateSignKey where - APrivateSignKey a k == APrivateSignKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APrivateSignKey instance Encoding APrivateSignKey where @@ -396,11 +380,6 @@ data APublicVerifyKey (AlgorithmI a, SignatureAlgorithm a) => APublicVerifyKey (SAlgorithm a) (PublicKey a) -instance Eq APublicVerifyKey where - APublicVerifyKey a k == APublicVerifyKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APublicVerifyKey data APrivateDhKey @@ -408,11 +387,6 @@ data APrivateDhKey (AlgorithmI a, DhAlgorithm a) => APrivateDhKey (SAlgorithm a) (PrivateKey a) -instance Eq APrivateDhKey where - APrivateDhKey a k == APrivateDhKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APrivateDhKey data APublicDhKey @@ -420,11 +394,6 @@ data APublicDhKey (AlgorithmI a, DhAlgorithm a) => APublicDhKey (SAlgorithm a) (PublicKey a) -instance Eq APublicDhKey where - APublicDhKey a k == APublicDhKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APublicDhKey data DhSecret (a :: Algorithm) where @@ -787,8 +756,6 @@ data Signature (a :: Algorithm) where SignatureEd25519 :: Ed25519.Signature -> Signature Ed25519 SignatureEd448 :: Ed448.Signature -> Signature Ed448 -deriving instance Eq (Signature a) - deriving instance Show (Signature a) data ASignature @@ -796,11 +763,6 @@ data ASignature (AlgorithmI a, SignatureAlgorithm a) => ASignature (SAlgorithm a) (Signature a) -instance Eq ASignature where - ASignature a s == ASignature a' s' = case testEquality a a' of - Just Refl -> s == s' - _ -> False - deriving instance Show ASignature class CryptoSignature s where @@ -885,6 +847,8 @@ data CryptoError CryptoHeaderError String | -- | no sending chain key in ratchet state CERatchetState + | -- | no decapsulation key in ratchet state + CERatchetKEMState | -- | header decryption error (could indicate that another key should be tried) CERatchetHeader | -- | too many skipped messages diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 0afa06db3..345119fca 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -5,16 +5,23 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StrictData #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +-- {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} module Simplex.Messaging.Crypto.Ratchet where +import Control.Applicative ((<|>)) import Control.Monad.Except +import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import Crypto.Hash (SHA512) @@ -23,22 +30,30 @@ import Crypto.Random (ChaChaDRG) import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson as J import qualified Data.Aeson.TH as JQ +import Data.Attoparsec.ByteString (Parser) +import qualified Data.Attoparsec.ByteString.Char8 as A +import qualified Data.ByteArray as BA import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as LB +import Data.Composition ((.:), (.:.)) +import Data.Functor (($>)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, isJust) +import Data.Type.Equality import Data.Typeable (Typeable) import Data.Word (Word32) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.QueryString import Simplex.Messaging.Crypto +import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE') +import Simplex.Messaging.Util ((<$?>), ($>>=)) import Simplex.Messaging.Version import UnliftIO.STM @@ -49,74 +64,319 @@ import UnliftIO.STM kdfX3DHE2EEncryptVersion :: Version kdfX3DHE2EEncryptVersion = 2 +pqRatchetVersion :: Version +pqRatchetVersion = 3 + currentE2EEncryptVersion :: Version -currentE2EEncryptVersion = 2 +currentE2EEncryptVersion = 3 supportedE2EEncryptVRange :: VersionRange supportedE2EEncryptVRange = mkVersionRange kdfX3DHE2EEncryptVersion currentE2EEncryptVersion -data E2ERatchetParams (a :: Algorithm) - = E2ERatchetParams Version (PublicKey a) (PublicKey a) - deriving (Eq, Show) +data RatchetKEMState + = RKSProposed -- only KEM encapsulation key + | RKSAccepted -- KEM ciphertext and the next encapsulation key -instance AlgorithmI a => Encoding (E2ERatchetParams a) where - smpEncode (E2ERatchetParams v k1 k2) = smpEncode (v, k1, k2) - smpP = E2ERatchetParams <$> smpP <*> smpP <*> smpP +data SRatchetKEMState (s :: RatchetKEMState) where + SRKSProposed :: SRatchetKEMState 'RKSProposed + SRKSAccepted :: SRatchetKEMState 'RKSAccepted -instance VersionI (E2ERatchetParams a) where - type VersionRangeT (E2ERatchetParams a) = E2ERatchetParamsUri a - version (E2ERatchetParams v _ _) = v - toVersionRangeT (E2ERatchetParams _ k1 k2) vr = E2ERatchetParamsUri vr k1 k2 +deriving instance Show (SRatchetKEMState s) -instance VersionRangeI (E2ERatchetParamsUri a) where - type VersionT (E2ERatchetParamsUri a) = (E2ERatchetParams a) - versionRange (E2ERatchetParamsUri vr _ _) = vr - toVersionT (E2ERatchetParamsUri _ k1 k2) v = E2ERatchetParams v k1 k2 +instance TestEquality SRatchetKEMState where + testEquality SRKSProposed SRKSProposed = Just Refl + testEquality SRKSAccepted SRKSAccepted = Just Refl + testEquality _ _ = Nothing -data E2ERatchetParamsUri (a :: Algorithm) - = E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) - deriving (Eq, Show) +class RatchetKEMStateI (s :: RatchetKEMState) where sRatchetKEMState :: SRatchetKEMState s -instance AlgorithmI a => StrEncoding (E2ERatchetParamsUri a) where - strEncode (E2ERatchetParamsUri vs key1 key2) = - strEncode $ - QSP QNoEscaping [("v", strEncode vs), ("x3dh", strEncodeList [key1, key2])] +instance RatchetKEMStateI RKSProposed where sRatchetKEMState = SRKSProposed + +instance RatchetKEMStateI RKSAccepted where sRatchetKEMState = SRKSAccepted + +checkRatchetKEMState :: forall t s s' a. (RatchetKEMStateI s, RatchetKEMStateI s') => t s' a -> Either String (t s a) +checkRatchetKEMState x = case testEquality (sRatchetKEMState @s) (sRatchetKEMState @s') of + Just Refl -> Right x + Nothing -> Left "bad ratchet KEM state" + +checkRatchetKEMState' :: forall t s s'. (RatchetKEMStateI s, RatchetKEMStateI s') => t s' -> Either String (t s) +checkRatchetKEMState' x = case testEquality (sRatchetKEMState @s) (sRatchetKEMState @s') of + Just Refl -> Right x + Nothing -> Left "bad ratchet KEM state" + +data RKEMParams (s :: RatchetKEMState) where + RKParamsProposed :: KEMPublicKey -> RKEMParams 'RKSProposed + RKParamsAccepted :: KEMCiphertext -> KEMPublicKey -> RKEMParams 'RKSAccepted + +deriving instance Show (RKEMParams s) + +data ARKEMParams = forall s. RatchetKEMStateI s => ARKP (SRatchetKEMState s) (RKEMParams s) + +deriving instance Show ARKEMParams + +instance RatchetKEMStateI s => Encoding (RKEMParams s) where + smpEncode = \case + RKParamsProposed k -> smpEncode ('P', k) + RKParamsAccepted ct k -> smpEncode ('A', ct, k) + smpP = (\(ARKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP + +instance Encoding (ARKEMParams) where + smpEncode (ARKP _ ps) = smpEncode ps + smpP = + smpP >>= \case + 'P' -> ARKP SRKSProposed . RKParamsProposed <$> smpP + 'A' -> ARKP SRKSAccepted .: RKParamsAccepted <$> smpP <*> smpP + _ -> fail "bad ratchet KEM params" + +data E2ERatchetParams (s :: RatchetKEMState) (a :: Algorithm) + = E2ERatchetParams Version (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) + deriving (Show) + +data AE2ERatchetParams (a :: Algorithm) + = forall s. + RatchetKEMStateI s => + AE2ERatchetParams (SRatchetKEMState s) (E2ERatchetParams s a) + +deriving instance Show (AE2ERatchetParams a) + +data AnyE2ERatchetParams + = forall s a. + (RatchetKEMStateI s, DhAlgorithm a, AlgorithmI a) => + AnyE2ERatchetParams (SRatchetKEMState s) (SAlgorithm a) (E2ERatchetParams s a) + +deriving instance Show AnyE2ERatchetParams + +instance (RatchetKEMStateI s, AlgorithmI a) => Encoding (E2ERatchetParams s a) where + smpEncode (E2ERatchetParams v k1 k2 kem_) + | v >= pqRatchetVersion = smpEncode (v, k1, k2, kem_) + | otherwise = smpEncode (v, k1, k2) + smpP = toParams <$?> smpP + where + toParams :: AE2ERatchetParams a -> Either String (E2ERatchetParams s a) + toParams = \case + AE2ERatchetParams _ (E2ERatchetParams v k1 k2 Nothing) -> Right $ E2ERatchetParams v k1 k2 Nothing + AE2ERatchetParams _ ps -> checkRatchetKEMState ps + +instance AlgorithmI a => Encoding (AE2ERatchetParams a) where + smpEncode (AE2ERatchetParams _ ps) = smpEncode ps + smpP = (\(AnyE2ERatchetParams s _ ps) -> (AE2ERatchetParams s) <$> checkAlgorithm ps) <$?> smpP + +instance Encoding AnyE2ERatchetParams where + smpEncode (AnyE2ERatchetParams _ _ ps) = smpEncode ps + smpP = do + v :: Version <- smpP + APublicDhKey a k1 <- smpP + APublicDhKey a' k2 <- smpP + case testEquality a a' of + Nothing -> fail "bad e2e params: different key algorithms" + Just Refl -> + kemP v >>= \case + Just (ARKP s kem) -> pure $ AnyE2ERatchetParams s a $ E2ERatchetParams v k1 k2 (Just kem) + Nothing -> pure $ AnyE2ERatchetParams SRKSProposed a $ E2ERatchetParams v k1 k2 Nothing + where + kemP :: Version -> Parser (Maybe (ARKEMParams)) + kemP v + | v >= pqRatchetVersion = smpP + | otherwise = pure Nothing + +instance VersionI (E2ERatchetParams s a) where + type VersionRangeT (E2ERatchetParams s a) = E2ERatchetParamsUri s a + version (E2ERatchetParams v _ _ _) = v + toVersionRangeT (E2ERatchetParams _ k1 k2 kem_) vr = E2ERatchetParamsUri vr k1 k2 kem_ + +instance VersionRangeI (E2ERatchetParamsUri s a) where + type VersionT (E2ERatchetParamsUri s a) = (E2ERatchetParams s a) + versionRange (E2ERatchetParamsUri vr _ _ _) = vr + toVersionT (E2ERatchetParamsUri _ k1 k2 kem_) v = E2ERatchetParams v k1 k2 kem_ + +type RcvE2ERatchetParamsUri a = E2ERatchetParamsUri 'RKSProposed a + +data E2ERatchetParamsUri (s :: RatchetKEMState) (a :: Algorithm) + = E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) + deriving (Show) + +data AE2ERatchetParamsUri (a :: Algorithm) + = forall s. + RatchetKEMStateI s => + AE2ERatchetParamsUri (SRatchetKEMState s) (E2ERatchetParamsUri s a) + +deriving instance Show (AE2ERatchetParamsUri a) + +data AnyE2ERatchetParamsUri + = forall s a. + (RatchetKEMStateI s, DhAlgorithm a, AlgorithmI a) => + AnyE2ERatchetParamsUri (SRatchetKEMState s) (SAlgorithm a) (E2ERatchetParamsUri s a) + +deriving instance Show AnyE2ERatchetParamsUri + +instance (RatchetKEMStateI s, AlgorithmI a) => StrEncoding (E2ERatchetParamsUri s a) where + strEncode (E2ERatchetParamsUri vs key1 key2 kem_) = + strEncode . QSP QNoEscaping $ + [("v", strEncode vs), ("x3dh", strEncodeList [key1, key2])] + <> maybe [] encodeKem kem_ + where + encodeKem kem + | maxVersion vs < pqRatchetVersion = [] + | otherwise = case kem of + RKParamsProposed k -> [("kem_key", strEncode k)] + RKParamsAccepted ct k -> [("kem_ct", strEncode ct), ("kem_key", strEncode k)] + strP = toParamsURI <$?> strP + where + toParamsURI = \case + AE2ERatchetParamsUri _ (E2ERatchetParamsUri vr k1 k2 Nothing) -> Right $ E2ERatchetParamsUri vr k1 k2 Nothing + AE2ERatchetParamsUri _ ps -> checkRatchetKEMState ps + +instance AlgorithmI a => StrEncoding (AE2ERatchetParamsUri a) where + strEncode (AE2ERatchetParamsUri _ ps) = strEncode ps + strP = (\(AnyE2ERatchetParamsUri s _ ps) -> (AE2ERatchetParamsUri s) <$> checkAlgorithm ps) <$?> strP + +instance StrEncoding AnyE2ERatchetParamsUri where + strEncode (AnyE2ERatchetParamsUri _ _ ps) = strEncode ps strP = do query <- strP - vs <- queryParam "v" query + vr :: VersionRange <- queryParam "v" query keys <- L.toList <$> queryParam "x3dh" query case keys of - [key1, key2] -> pure $ E2ERatchetParamsUri vs key1 key2 + [APublicDhKey a k1, APublicDhKey a' k2] -> case testEquality a a' of + Nothing -> fail "bad e2e params: different key algorithms" + Just Refl -> + kemP vr query >>= \case + Just (ARKP s kem) -> pure $ AnyE2ERatchetParamsUri s a $ E2ERatchetParamsUri vr k1 k2 (Just kem) + Nothing -> pure $ AnyE2ERatchetParamsUri SRKSProposed a $ E2ERatchetParamsUri vr k1 k2 Nothing _ -> fail "bad e2e params" + where + kemP vr query + | maxVersion vr >= pqRatchetVersion = + queryParam_ "kem_key" query + $>>= \k -> (Just . kemParams k <$> queryParam_ "kem_ct" query) + | otherwise = pure Nothing + kemParams k = \case + Nothing -> ARKP SRKSProposed $ RKParamsProposed k + Just ct -> ARKP SRKSAccepted $ RKParamsAccepted ct k -generateE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> STM (PrivateKey a, PrivateKey a, E2ERatchetParams a) -generateE2EParams g v = do - (k1, pk1) <- generateKeyPair g - (k2, pk2) <- generateKeyPair g - pure (pk1, pk2, E2ERatchetParams v k1 k2) +type RcvE2ERatchetParams a = E2ERatchetParams 'RKSProposed a + +type SndE2ERatchetParams a = AE2ERatchetParams a + +data PrivRKEMParams (s :: RatchetKEMState) where + PrivateRKParamsProposed :: KEMKeyPair -> PrivRKEMParams 'RKSProposed + PrivateRKParamsAccepted :: KEMCiphertext -> KEMSharedKey -> KEMKeyPair -> PrivRKEMParams 'RKSAccepted + +data APrivRKEMParams = forall s. RatchetKEMStateI s => APRKP (SRatchetKEMState s) (PrivRKEMParams s) + +type RcvPrivRKEMParams = PrivRKEMParams 'RKSProposed + +instance RatchetKEMStateI s => Encoding (PrivRKEMParams s) where + smpEncode = \case + PrivateRKParamsProposed k -> smpEncode ('P', k) + PrivateRKParamsAccepted ct shared k -> smpEncode ('A', ct, shared, k) + smpP = (\(APRKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP + +instance Encoding (APrivRKEMParams) where + smpEncode (APRKP _ ps) = smpEncode ps + smpP = + smpP >>= \case + 'P' -> APRKP SRKSProposed . PrivateRKParamsProposed <$> smpP + 'A' -> APRKP SRKSAccepted .:. PrivateRKParamsAccepted <$> smpP <*> smpP <*> smpP + _ -> fail "bad APrivRKEMParams" + +instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . smpEncode + +instance (Typeable s, RatchetKEMStateI s) => FromField (PrivRKEMParams s) where fromField = blobFieldDecoder smpDecode + +data UseKEM (s :: RatchetKEMState) where + ProposeKEM :: UseKEM 'RKSProposed + AcceptKEM :: KEMPublicKey -> UseKEM 'RKSAccepted + +data AUseKEM = forall s. RatchetKEMStateI s => AUseKEM (SRatchetKEMState s) (UseKEM s) + +generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a) +generateE2EParams g v useKEM_ = do + (k1, pk1) <- atomically $ generateKeyPair g + (k2, pk2) <- atomically $ generateKeyPair g + kems <- kemParams + pure (pk1, pk2, snd <$> kems, E2ERatchetParams v k1 k2 (fst <$> kems)) + where + kemParams :: IO (Maybe (RKEMParams s, PrivRKEMParams s)) + kemParams = case useKEM_ of + Just useKem | v >= pqRatchetVersion -> Just <$> do + ks@(k, _) <- sntrup761Keypair g + case useKem of + ProposeKEM -> pure (RKParamsProposed k, PrivateRKParamsProposed ks) + AcceptKEM k' -> do + (ct, shared) <- sntrup761Enc g k' + pure (RKParamsAccepted ct k, PrivateRKParamsAccepted ct shared ks) + _ -> pure Nothing + +-- used by party initiating connection, Bob in double-ratchet spec +generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a) +generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_ + where + proposeKEM_ :: PQEncryption -> Maybe (UseKEM 'RKSProposed) + proposeKEM_ = \case + PQEncOn -> Just ProposeKEM + PQEncOff -> Nothing + + +-- used by party accepting connection, Alice in double-ratchet spec +generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a) +generateSndE2EParams g v = \case + Nothing -> do + (pk1, pk2, _, e2eParams) <- generateE2EParams g v Nothing + pure (pk1, pk2, Nothing, AE2ERatchetParams SRKSProposed e2eParams) + Just (AUseKEM s useKEM) -> do + (pk1, pk2, pKem, e2eParams) <- generateE2EParams g v (Just useKEM) + pure (pk1, pk2, APRKP s <$> pKem, AE2ERatchetParams s e2eParams) data RatchetInitParams = RatchetInitParams { assocData :: Str, ratchetKey :: RatchetKey, sndHK :: HeaderKey, - rcvNextHK :: HeaderKey + rcvNextHK :: HeaderKey, + kemAccepted :: Maybe RatchetKEMAccepted } - deriving (Eq, Show) + deriving (Show) -x3dhSnd :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> E2ERatchetParams a -> RatchetInitParams -x3dhSnd spk1 spk2 (E2ERatchetParams _ rk1 rk2) = - x3dh (publicKey spk1, rk1) (dh' rk1 spk2) (dh' rk2 spk1) (dh' rk2 spk2) +-- this is used by the peer joining the connection +pqX3dhSnd :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> Maybe APrivRKEMParams -> E2ERatchetParams 'RKSProposed a -> Either CryptoError (RatchetInitParams, Maybe KEMKeyPair) +-- 3. replied 2. received +pqX3dhSnd spk1 spk2 spKem_ (E2ERatchetParams v rk1 rk2 rKem_) = do + (ks_, kem_) <- sndPq + let initParams = pqX3dh (publicKey spk1, rk1) (dh' rk1 spk2) (dh' rk2 spk1) (dh' rk2 spk2) kem_ + pure (initParams, ks_) + where + sndPq :: Either CryptoError (Maybe KEMKeyPair, Maybe RatchetKEMAccepted) + sndPq = case spKem_ of + Just (APRKP _ ps) | v >= pqRatchetVersion -> case (ps, rKem_) of + (PrivateRKParamsAccepted ct shared ks, Just (RKParamsProposed k)) -> Right (Just ks, Just $ RatchetKEMAccepted k shared ct) + (PrivateRKParamsProposed ks, _) -> Right (Just ks, Nothing) -- both parties can send "proposal" in case of ratchet renegotiation + _ -> Left CERatchetKEMState + _ -> Right (Nothing, Nothing) -x3dhRcv :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> E2ERatchetParams a -> RatchetInitParams -x3dhRcv rpk1 rpk2 (E2ERatchetParams _ sk1 sk2) = - x3dh (sk1, publicKey rpk1) (dh' sk2 rpk1) (dh' sk1 rpk2) (dh' sk2 rpk2) +-- this is used by the peer that created new connection, after receiving the reply +pqX3dhRcv :: forall s a. (RatchetKEMStateI s, DhAlgorithm a) => PrivateKey a -> PrivateKey a -> Maybe (PrivRKEMParams 'RKSProposed) -> E2ERatchetParams s a -> ExceptT CryptoError IO (RatchetInitParams, Maybe KEMKeyPair) +-- 1. sent 4. received in reply +pqX3dhRcv rpk1 rpk2 rpKem_ (E2ERatchetParams v sk1 sk2 sKem_) = do + kem_ <- rcvPq + let initParams = pqX3dh (sk1, publicKey rpk1) (dh' sk2 rpk1) (dh' sk1 rpk2) (dh' sk2 rpk2) (snd <$> kem_) + pure (initParams, fst <$> kem_) + where + rcvPq :: ExceptT CryptoError IO (Maybe (KEMKeyPair, RatchetKEMAccepted)) + rcvPq = case sKem_ of + Just (RKParamsAccepted ct k') | v >= pqRatchetVersion -> case rpKem_ of + Just (PrivateRKParamsProposed ks@(_, pk)) -> do + shared <- liftIO $ sntrup761Dec ct pk + pure $ Just (ks, RatchetKEMAccepted k' shared ct) + Nothing -> throwError CERatchetKEMState + _ -> pure Nothing -- both parties can send "proposal" in case of ratchet renegotiation -x3dh :: DhAlgorithm a => (PublicKey a, PublicKey a) -> DhSecret a -> DhSecret a -> DhSecret a -> RatchetInitParams -x3dh (sk1, rk1) dh1 dh2 dh3 = - RatchetInitParams {assocData, ratchetKey = RatchetKey sk, sndHK = Key hk, rcvNextHK = Key nhk} +pqX3dh :: DhAlgorithm a => (PublicKey a, PublicKey a) -> DhSecret a -> DhSecret a -> DhSecret a -> Maybe RatchetKEMAccepted -> RatchetInitParams +pqX3dh (sk1, rk1) dh1 dh2 dh3 kemAccepted = + RatchetInitParams {assocData, ratchetKey = RatchetKey sk, sndHK = Key hk, rcvNextHK = Key nhk, kemAccepted} where assocData = Str $ pubKeyBytes sk1 <> pubKeyBytes rk1 - dhs = dhBytes' dh1 <> dhBytes' dh2 <> dhBytes' dh3 + dhs = dhBytes' dh1 <> dhBytes' dh2 <> dhBytes' dh3 <> pq + pq = maybe "" (\RatchetKEMAccepted {rcPQRss = KEMSharedKey ss} -> BA.convert ss) kemAccepted (hk, nhk, sk) = let salt = B.replicate 64 '\0' in hkdf3 salt dhs "SimpleXX3DH" @@ -129,6 +389,11 @@ data Ratchet a = Ratchet -- associated data - must be the same in both parties ratchets rcAD :: Str, rcDHRs :: PrivateKey a, + rcKEM :: Maybe RatchetKEM, + -- TODO PQ make them optional via JSON parser for PQEncryption + rcEnableKEM :: PQEncryption, -- will enable KEM on the next ratchet step + rcSndKEM :: PQEncryption, -- used KEM hybrid secret for sending ratchet + rcRcvKEM :: PQEncryption, -- used KEM hybrid secret for receiving ratchet rcRK :: RatchetKey, rcSnd :: Maybe (SndRatchet a), rcRcv :: Maybe RcvRatchet, @@ -138,20 +403,33 @@ data Ratchet a = Ratchet rcNHKs :: HeaderKey, rcNHKr :: HeaderKey } - deriving (Eq, Show) + deriving (Show) data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, rcCKs :: RatchetKey, rcHKs :: HeaderKey } - deriving (Eq, Show) + deriving (Show) data RcvRatchet = RcvRatchet { rcCKr :: RatchetKey, rcHKr :: HeaderKey } - deriving (Eq, Show) + deriving (Show) + +data RatchetKEM = RatchetKEM + { rcPQRs :: KEMKeyPair, + rcKEMs :: Maybe RatchetKEMAccepted + } + deriving (Show) + +data RatchetKEMAccepted = RatchetKEMAccepted + { rcPQRr :: KEMPublicKey, -- received key + rcPQRss :: KEMSharedKey, -- computed shared secret + rcPQRct :: KEMCiphertext -- sent encaps(rcPQRr, rcPQRss) + } + deriving (Show) type SkippedMsgKeys = Map HeaderKey SkippedHdrMsgKeys @@ -189,7 +467,7 @@ instance Encoding MessageKey where -- | Input key material for double ratchet HKDF functions newtype RatchetKey = RatchetKey ByteString - deriving (Eq, Show) + deriving (Show) instance ToJSON RatchetKey where toJSON (RatchetKey k) = strToJSON k @@ -202,19 +480,32 @@ instance ToField MessageKey where toField = toField . smpEncode instance FromField MessageKey where fromField = blobFieldDecoder smpDecode --- | Sending ratchet initialization, equivalent to RatchetInitAliceHE in double ratchet spec +-- | Sending ratchet initialization -- -- Please note that sPKey is not stored, and its public part together with random salt -- is sent to the recipient. +-- @ +-- RatchetInitAlicePQ2HE(state, SK, bob_dh_public_key, shared_hka, shared_nhkb, bob_pq_kem_encapsulation_key) +-- // below added for post-quantum KEM +-- state.PQRs = GENERATE_PQKEM() +-- state.PQRr = bob_pq_kem_encapsulation_key +-- state.PQRss = random // shared secret for KEM +-- state.PQRct = PQKEM-ENC(state.PQRr, state.PQRss) // encapsulated additional shared secret +-- // above added for KEM +-- @ initSndRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PublicKey a -> PrivateKey a -> RatchetInitParams -> Ratchet a -initSndRatchet rcVersion rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = do - -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr)) - let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs + forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a +initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do + -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) + let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted) in Ratchet { rcVersion, rcAD = assocData, rcDHRs, + rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, + rcEnableKEM = PQEncryption $ isJust rcPQRs_, + rcSndKEM = PQEncryption $ isJust kemAccepted, + rcRcvKEM = PQEncOff, rcRK, rcSnd = Just SndRatchet {rcDHRr, rcCKs, rcHKs = sndHK}, rcRcv = Nothing, @@ -225,17 +516,28 @@ initSndRatchet rcVersion rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, rcNHKr = rcvNextHK } --- | Receiving ratchet initialization, equivalent to RatchetInitBobHE in double ratchet spec +-- | Receiving ratchet initialization, equivalent to RatchetInitBobPQ2HE in double ratchet spec +-- +-- def RatchetInitBobPQ2HE(state, SK, bob_dh_key_pair, shared_hka, shared_nhkb, bob_pq_kem_key_pair) -- -- Please note that the public part of rcDHRs was sent to the sender -- as part of the connection request and random salt was received from the sender. initRcvRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PrivateKey a -> RatchetInitParams -> Ratchet a -initRcvRatchet rcVersion rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = + forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a +initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM = Ratchet { rcVersion, rcAD = assocData, rcDHRs, + -- rcKEM: + -- state.PQRs = bob_pq_kem_key_pair + -- state.PQRr = None + -- state.PQRss = None + -- state.PQRct = None + rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, + rcEnableKEM, + rcSndKEM = PQEncOff, + rcRcvKEM = PQEncOff, rcRK = ratchetKey, rcSnd = Nothing, rcRcv = Nothing, @@ -246,14 +548,17 @@ initRcvRatchet rcVersion rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcNHKr = sndHK } +-- encaps = state.PQRs.encaps, // added for KEM #2 +-- ct = state.PQRct // added for KEM #1 data MsgHeader a = MsgHeader { -- | max supported ratchet version msgMaxVersion :: Version, msgDHRs :: PublicKey a, + msgKEM :: Maybe ARKEMParams, msgPN :: Word32, msgNs :: Word32 } - deriving (Eq, Show) + deriving (Show) data AMsgHeader = forall a. @@ -262,8 +567,10 @@ data AMsgHeader -- to allow extension without increasing the size, the actual header length is: -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 +-- TODO PQ this must be version-dependent +-- TODO this is the exact size, some reserve should be added paddedHeaderLen :: Int -paddedHeaderLen = 88 +paddedHeaderLen = 2284 -- only used in tests to validate correct padding -- (2 bytes - version size, 1 byte - header size, not to have it fixed or version-dependent) @@ -271,14 +578,16 @@ fullHeaderLen :: Int fullHeaderLen = 2 + 1 + paddedHeaderLen + authTagSize + ivSize @AES256 instance AlgorithmI a => Encoding (MsgHeader a) where - smpEncode MsgHeader {msgMaxVersion, msgDHRs, msgPN, msgNs} = - smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) + smpEncode MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} + | msgMaxVersion >= pqRatchetVersion = smpEncode (msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs) + | otherwise = smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) smpP = do msgMaxVersion <- smpP msgDHRs <- smpP + msgKEM <- if msgMaxVersion >= pqRatchetVersion then smpP else pure Nothing msgPN <- smpP msgNs <- smpP - pure MsgHeader {msgMaxVersion, msgDHRs, msgPN, msgNs} + pure MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} data EncMessageHeader = EncMessageHeader { ehVersion :: Version, @@ -288,10 +597,12 @@ data EncMessageHeader = EncMessageHeader } instance Encoding EncMessageHeader where - smpEncode EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} = - smpEncode (ehVersion, ehIV, ehAuthTag, ehBody) + smpEncode EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} + | ehVersion >= pqRatchetVersion = smpEncode (ehVersion, ehIV, ehAuthTag, Large ehBody) + | otherwise = smpEncode (ehVersion, ehIV, ehAuthTag, ehBody) smpP = do - (ehVersion, ehIV, ehAuthTag, ehBody) <- smpP + (ehVersion, ehIV, ehAuthTag) <- smpP + ehBody <- if ehVersion >= pqRatchetVersion then unLarge <$> smpP else smpP pure EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} data EncRatchetMessage = EncRatchetMessage @@ -300,37 +611,123 @@ data EncRatchetMessage = EncRatchetMessage emBody :: ByteString } -instance Encoding EncRatchetMessage where - smpEncode EncRatchetMessage {emHeader, emBody, emAuthTag} = - smpEncode (emHeader, emAuthTag, Tail emBody) - smpP = do - (emHeader, emAuthTag, Tail emBody) <- smpP - pure EncRatchetMessage {emHeader, emBody, emAuthTag} +encodeEncRatchetMessage :: Version -> EncRatchetMessage -> ByteString +encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} + | v >= pqRatchetVersion = smpEncode (Large emHeader, emAuthTag, Tail emBody) + | otherwise = smpEncode (emHeader, emAuthTag, Tail emBody) -rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> ExceptT CryptoError IO (ByteString, Ratchet a) -rcEncrypt Ratchet {rcSnd = Nothing} _ _ = throwE CERatchetState -rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcNs, rcPN, rcAD = Str rcAD, rcVersion} paddedMsgLen msg = do +encRatchetMessageP :: Version -> Parser EncRatchetMessage +encRatchetMessageP v = do + emHeader <- if v >= pqRatchetVersion then unLarge <$> smpP else smpP + (emAuthTag, Tail emBody) <- smpP + pure EncRatchetMessage {emHeader, emBody, emAuthTag} + +newtype PQEncryption = PQEncryption {enablePQ :: Bool} + deriving (Eq, Show) + +pattern PQEncOn :: PQEncryption +pattern PQEncOn = PQEncryption True + +pattern PQEncOff :: PQEncryption +pattern PQEncOff = PQEncryption False + +{-# COMPLETE PQEncOn, PQEncOff #-} + +instance ToJSON PQEncryption where + toEncoding (PQEncryption pq) = toEncoding pq + toJSON (PQEncryption pq) = toJSON pq + +instance FromJSON PQEncryption where + parseJSON v = PQEncryption <$> parseJSON v + +replyKEM_ :: PQEncryption -> Maybe (RKEMParams 'RKSProposed) -> Maybe AUseKEM +replyKEM_ pqEnc kem_ = case pqEnc of + PQEncOn -> Just $ case kem_ of + Just (RKParamsProposed k) -> AUseKEM SRKSAccepted $ AcceptKEM k + Nothing -> AUseKEM SRKSProposed ProposeKEM + PQEncOff -> Nothing + +instance StrEncoding PQEncryption where + strEncode pqMode + | enablePQ pqMode = "pq=enable" + | otherwise = "pq=disable" + strP = + A.takeTill (== ' ') >>= \case + "pq=enable" -> pq True + "pq=disable" -> pq False + _ -> fail "bad PQEncryption" + where + pq = pure . PQEncryption + +data InitialKeys = IKUsePQ | IKNoPQ PQEncryption + deriving (Eq, Show) + +instance StrEncoding InitialKeys where + strEncode = \case + IKUsePQ -> "pq=invitation" + IKNoPQ pq -> strEncode pq + strP = IKNoPQ <$> strP <|> "pq=invitation" $> IKUsePQ + +-- determines whether PQ key should be included in invitation link +initialPQEncryption :: InitialKeys -> PQEncryption +initialPQEncryption = \case + IKUsePQ -> PQEncOn + IKNoPQ _ -> PQEncOff -- default + +-- determines whether PQ encryption should be used in connection +connPQEncryption :: InitialKeys -> PQEncryption +connPQEncryption = \case + IKUsePQ -> PQEncOn + IKNoPQ pq -> pq -- default for creating connection is IKNoPQ PQEncOn + +-- determines whether PQ key should be included in invitation link sent to contact address +joinContactInitialKeys :: PQEncryption -> InitialKeys +joinContactInitialKeys = \case + PQEncOn -> IKUsePQ -- default + PQEncOff -> IKNoPQ PQEncOff + +rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> ExceptT CryptoError IO (ByteString, Ratchet a) +rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ = throwE CERatchetState +rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcVersion} paddedMsgLen msg pqMode_ = do -- state.CKs, mk = KDF_CK(state.CKs) let (ck', mk, iv, ehIV) = chainKdf rcCKs -- enc_header = HENCRYPT(state.HKs, header) (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV paddedHeaderLen rcAD msgHeader -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) - let emHeader = smpEncode EncMessageHeader {ehVersion = minVersion rcVersion, ehBody, ehAuthTag, ehIV} + -- TODO PQ versioning in Ratchet should change somehow + let emHeader = smpEncode EncMessageHeader {ehVersion = maxVersion rcVersion, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg - let msg' = smpEncode EncRatchetMessage {emHeader, emBody, emAuthTag} + let msg' = encodeEncRatchetMessage (maxVersion rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag} -- state.Ns += 1 rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1} - pure (msg', rc') + rc'' = case pqMode_ of + Nothing -> rc' + Just rcEnableKEM + | enablePQ rcEnableKEM -> rc' {rcEnableKEM} + | otherwise -> + let rcKEM' = (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM + in rc' {rcEnableKEM, rcKEM = rcKEM'} + pure (msg', rc'') where - -- header = HEADER(state.DHRs, state.PN, state.Ns) + -- header = HEADER_PQ2( + -- dh = state.DHRs.public, + -- kem = state.PQRs.public, // added for KEM #2 + -- ct = state.PQRct, // added for KEM #1 + -- pn = state.PN, + -- n = state.Ns + -- ) msgHeader = smpEncode MsgHeader { msgMaxVersion = maxVersion rcVersion, msgDHRs = publicKey rcDHRs, + msgKEM = msgKEMParams <$> rcKEM, msgPN = rcPN, msgNs = rcNs } + msgKEMParams RatchetKEM {rcPQRs = (k, _), rcKEMs} = case rcKEMs of + Nothing -> ARKP SRKSProposed $ RKParamsProposed k + Just RatchetKEMAccepted {rcPQRct} -> ARKP SRKSAccepted $ RKParamsAccepted rcPQRct k data SkippedMessage a = SMMessage (DecryptResult a) @@ -338,7 +735,7 @@ data SkippedMessage a | SMNone data RatchetStep = AdvanceRatchet | SameRatchet - deriving (Eq) + deriving (Eq, Show) type DecryptResult a = (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff) @@ -353,8 +750,9 @@ rcDecrypt :: SkippedMsgKeys -> ByteString -> ExceptT CryptoError IO (DecryptResult a) -rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do - encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError smpP msg' +rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do + -- TODO PQ versioning should change + encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxVersion rcVersion) msg' encHdr <- parseE CryptoHeaderError smpP emHeader -- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD) decryptSkipped encHdr encMsg >>= \case @@ -368,7 +766,7 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do SMMessage r -> pure r where decryptRcMessage :: RatchetStep -> MsgHeader a -> EncRatchetMessage -> ExceptT CryptoError IO (DecryptResult a) - decryptRcMessage rcStep MsgHeader {msgDHRs, msgPN, msgNs} encMsg = do + decryptRcMessage rcStep MsgHeader {msgDHRs, msgKEM, msgPN, msgNs} encMsg = do -- if dh_ratchet: (rc', smks1) <- ratchetStep rcStep case skipMessageKeys msgNs rc' of @@ -392,15 +790,23 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do case skipMessageKeys msgPN rc of Left e -> throwE e Right (rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr}, hmks) -> do - -- DHRatchetHE(state, header) + -- DHRatchetPQ2HE(state, header) + (kemSS, kemSS', rcKEM') <- pqRatchetStep rc' msgKEM + -- state.DHRs = GENERATE_DH() (_, rcDHRs') <- atomically $ generateKeyPair @a g - -- state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr)) - let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs rcDHRs - -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr)) - (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' + -- state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || ss) + let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs rcDHRs kemSS + -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || state.PQRss) + (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' kemSS' + sndKEM = isJust kemSS' + rcvKEM = isJust kemSS rc'' = rc' { rcDHRs = rcDHRs', + rcKEM = rcKEM', + rcEnableKEM = PQEncryption $ sndKEM || rcvKEM, + rcSndKEM = PQEncryption sndKEM, + rcRcvKEM = PQEncryption rcvKEM, rcRK = rcRK'', rcSnd = Just SndRatchet {rcDHRr = msgDHRs, rcCKs = rcCKs', rcHKs = rcNHKs}, rcRcv = Just RcvRatchet {rcCKr = rcCKr', rcHKr = rcNHKr}, @@ -411,6 +817,39 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do rcNHKr = rcNHKr' } pure (rc'', hmks) + pqRatchetStep :: Ratchet a -> Maybe ARKEMParams -> ExceptT CryptoError IO (Maybe KEMSharedKey, Maybe KEMSharedKey, Maybe RatchetKEM) + pqRatchetStep Ratchet {rcKEM, rcEnableKEM = PQEncryption pqEnc} = \case + -- received message does not have KEM in header, + -- but the user enabled KEM when sending previous message + Nothing -> case rcKEM of + Nothing | pqEnc -> do + rcPQRs <- liftIO $ sntrup761Keypair g + pure (Nothing, Nothing, Just RatchetKEM {rcPQRs, rcKEMs = Nothing}) + _ -> pure (Nothing, Nothing, Nothing) + -- received message has KEM in header. + Just (ARKP _ ps) + | pqEnc -> do + -- state.PQRr = header.kem + (ss, rcPQRr) <- sharedSecret + -- state.PQRct = PQKEM-ENC(state.PQRr, state.PQRss) // encapsulated additional shared secret KEM #1 + (rcPQRct, rcPQRss) <- liftIO $ sntrup761Enc g rcPQRr + -- state.PQRs = GENERATE_PQKEM() + rcPQRs <- liftIO $ sntrup761Keypair g + let kem' = RatchetKEM {rcPQRs, rcKEMs = Just RatchetKEMAccepted {rcPQRr, rcPQRss, rcPQRct}} + pure (ss, Just rcPQRss, Just kem') + | otherwise -> do + -- state.PQRr = header.kem + (ss, _) <- sharedSecret + pure (ss, Nothing, Nothing) + where + sharedSecret = case ps of + RKParamsProposed k -> pure (Nothing, k) + RKParamsAccepted ct k -> case rcKEM of + Nothing -> throwE CERatchetKEMState + -- ss = PQKEM-DEC(state.PQRs.private, header.ct) + Just RatchetKEM {rcPQRs} -> do + ss <- liftIO $ sntrup761Dec ct (snd rcPQRs) + pure (Just ss, k) skipMessageKeys :: Word32 -> Ratchet a -> Either CryptoError (Ratchet a, SkippedMsgKeys) skipMessageKeys _ r@Ratchet {rcRcv = Nothing} = Right (r, M.empty) skipMessageKeys untilN r@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr, rcHKr}, rcNr} @@ -465,10 +904,13 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do -- DECRYPT(mk, cipher-text, CONCAT(AD, enc_header)) tryE $ decryptAEAD mk iv (rcAD <> emHeader) emBody emAuthTag -rootKdf :: (AlgorithmI a, DhAlgorithm a) => RatchetKey -> PublicKey a -> PrivateKey a -> (RatchetKey, RatchetKey, Key) -rootKdf (RatchetKey rk) k pk = - let dhOut = dhBytes' $ dh' k pk - (rk', ck, nhk) = hkdf3 rk dhOut "SimpleXRootRatchet" +rootKdf :: (AlgorithmI a, DhAlgorithm a) => RatchetKey -> PublicKey a -> PrivateKey a -> Maybe KEMSharedKey -> (RatchetKey, RatchetKey, Key) +rootKdf (RatchetKey rk) k pk kemSecret_ = + let dhOut = dhBytes' (dh' k pk) + ss = case kemSecret_ of + Just (KEMSharedKey s) -> dhOut <> BA.convert s + Nothing -> dhOut + (rk', ck, nhk) = hkdf3 rk ss "SimpleXRootRatchet" in (RatchetKey rk', RatchetKey ck, Key nhk) chainKdf :: RatchetKey -> (RatchetKey, Key, IV, IV) @@ -487,6 +929,10 @@ hkdf3 salt ikm info = (s1, s2, s3) $(JQ.deriveJSON defaultJSON ''RcvRatchet) +$(JQ.deriveJSON defaultJSON ''RatchetKEMAccepted) + +$(JQ.deriveJSON defaultJSON ''RatchetKEM) + instance AlgorithmI a => ToJSON (SndRatchet a) where toEncoding = $(JQ.mkToEncoding defaultJSON ''SndRatchet) toJSON = $(JQ.mkToJSON defaultJSON ''SndRatchet) diff --git a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs index 0940c53ba..3b2238086 100644 --- a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs +++ b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs @@ -19,16 +19,20 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String newtype KEMPublicKey = KEMPublicKey ByteString - deriving (Show) + deriving (Eq, Show) newtype KEMSecretKey = KEMSecretKey ScrubbedBytes - deriving (Show) + deriving (Eq, Show) newtype KEMCiphertext = KEMCiphertext ByteString - deriving (Show) + deriving (Eq, Show) newtype KEMSharedKey = KEMSharedKey ScrubbedBytes - deriving (Show) + deriving (Eq, Show) + +unsafeRevealKEMSharedKey :: KEMSharedKey -> String +unsafeRevealKEMSharedKey (KEMSharedKey scrubbed) = show (BA.convert scrubbed :: ByteString) +{-# DEPRECATED unsafeRevealKEMSharedKey "unsafeRevealKEMSharedKey left in code" #-} type KEMKeyPair = (KEMPublicKey, KEMSecretKey) @@ -60,6 +64,18 @@ sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) = KEMSharedKey <$> BA.alloc c_SNTRUP761_SIZE (\kPtr -> c_sntrup761_dec kPtr cPtr skPtr) +instance Encoding KEMSecretKey where + smpEncode (KEMSecretKey c) = smpEncode . Large $ BA.convert c + smpP = KEMSecretKey . BA.convert . unLarge <$> smpP + +instance StrEncoding KEMSecretKey where + strEncode (KEMSecretKey pk) = strEncode (BA.convert pk :: ByteString) + strP = KEMSecretKey . BA.convert <$> strP @ByteString + +instance Encoding KEMPublicKey where + smpEncode (KEMPublicKey pk) = smpEncode . Large $ BA.convert pk + smpP = KEMPublicKey . BA.convert . unLarge <$> smpP + instance StrEncoding KEMPublicKey where strEncode (KEMPublicKey pk) = strEncode (BA.convert pk :: ByteString) strP = KEMPublicKey . BA.convert <$> strP @ByteString @@ -68,6 +84,25 @@ instance Encoding KEMCiphertext where smpEncode (KEMCiphertext c) = smpEncode . Large $ BA.convert c smpP = KEMCiphertext . BA.convert . unLarge <$> smpP +instance Encoding KEMSharedKey where + smpEncode (KEMSharedKey c) = smpEncode (BA.convert c :: ByteString) + smpP = KEMSharedKey . BA.convert <$> smpP @ByteString + +instance StrEncoding KEMCiphertext where + strEncode (KEMCiphertext pk) = strEncode (BA.convert pk :: ByteString) + strP = KEMCiphertext . BA.convert <$> strP @ByteString + +instance StrEncoding KEMSharedKey where + strEncode (KEMSharedKey pk) = strEncode (BA.convert pk :: ByteString) + strP = KEMSharedKey . BA.convert <$> strP @ByteString + +instance ToJSON KEMSecretKey where + toJSON = strToJSON + toEncoding = strToJEncoding + +instance FromJSON KEMSecretKey where + parseJSON = strParseJSON "KEMSecretKey" + instance ToJSON KEMPublicKey where toJSON = strToJSON toEncoding = strToJEncoding @@ -75,8 +110,22 @@ instance ToJSON KEMPublicKey where instance FromJSON KEMPublicKey where parseJSON = strParseJSON "KEMPublicKey" +instance ToJSON KEMCiphertext where + toJSON = strToJSON + toEncoding = strToJEncoding + +instance FromJSON KEMCiphertext where + parseJSON = strParseJSON "KEMCiphertext" + instance ToField KEMSharedKey where toField (KEMSharedKey k) = toField (BA.convert k :: ByteString) instance FromField KEMSharedKey where fromField f = KEMSharedKey . BA.convert @ByteString <$> fromField f + +instance ToJSON KEMSharedKey where + toJSON = strToJSON + toEncoding = strToJEncoding + +instance FromJSON KEMSharedKey where + parseJSON = strParseJSON "KEMSharedKey" diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index e81b0da89..fcefdc73d 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -179,6 +179,12 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin strP = (,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP {-# INLINE strP #-} +instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncoding e, StrEncoding f) => StrEncoding (a, b, c, d, e, f) where + strEncode (a, b, c, d, e, f) = B.unwords [strEncode a, strEncode b, strEncode c, strEncode d, strEncode e, strEncode f] + {-# INLINE strEncode #-} + strP = (,,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP + {-# INLINE strP #-} + strP_ :: StrEncoding a => Parser a strP_ = strP <* A.space diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 315a4e5a3..2c7685ab6 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -175,7 +175,7 @@ import Data.Functor (($>)) import Data.Kind import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Maybe (isJust, isNothing) +import Data.Maybe (isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality @@ -272,7 +272,7 @@ data RawTransmission = RawTransmission data TransmissionAuth = TASignature C.ASignature | TAAuthenticator C.CbAuthenticator - deriving (Eq, Show) + deriving (Show) -- this encoding is backwards compatible with v6 that used Maybe C.ASignature instead of TAuthorization tAuthBytes :: Maybe TransmissionAuth -> ByteString @@ -338,8 +338,6 @@ data Command (p :: Party) where deriving instance Show (Command p) -deriving instance Eq (Command p) - data SubscriptionMode = SMSubscribe | SMOnlyCreate deriving (Eq, Show) @@ -746,9 +744,6 @@ data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p deriving instance Show AProtocolType -instance Eq AProtocolType where - AProtocolType p == AProtocolType p' = isJust $ testEquality p p' - instance TestEquality SProtocolType where testEquality SPSMP SPSMP = Just Refl testEquality SPNTF SPNTF = Just Refl diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index 56ce9b679..cd1b94215 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -17,14 +17,14 @@ data QueueRec = QueueRec notifier :: !(Maybe NtfCreds), status :: !ServerQueueStatus } - deriving (Eq, Show) + deriving (Show) data NtfCreds = NtfCreds { notifierId :: !NotifierId, notifierKey :: !NtfPublicAuthKey, rcvNtfDhSecret :: !RcvNtfDhSecret } - deriving (Eq, Show) + deriving (Show) instance StrEncoding NtfCreds where strEncode NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} = strEncode (notifierId, notifierKey, rcvNtfDhSecret) diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index f0078ae24..09e9e5002 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -4,6 +4,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PostfixOperators #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} @@ -12,12 +13,12 @@ module AgentTests (agentTests) where import AgentTests.ConnectionRequestTests import AgentTests.DoubleRatchetTests (doubleRatchetTests) -import AgentTests.FunctionalAPITests (functionalAPITests) +import AgentTests.FunctionalAPITests (functionalAPITests, pattern Msg, pattern Msg') import AgentTests.MigrationTests (migrationTests) import AgentTests.NotificationTests (notificationTests) import AgentTests.SQLiteTests (storeTests) import Control.Concurrent -import Control.Monad (forM_) +import Control.Monad (forM_, when) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Maybe (fromJust) @@ -26,15 +27,18 @@ import GHC.Stack (withFrozenCallStack) import Network.HTTP.Types (urlEncode) import SMPAgentClient import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) -import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.Protocol hiding (MID) import qualified Simplex.Messaging.Agent.Protocol as A +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOn, pattern PQEncOff) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) +import Simplex.Messaging.Protocol (ErrorType (..)) import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..)) import Simplex.Messaging.Util (bshow) import System.Directory (removeFile) import System.Timeout import Test.Hspec +import Util agentTests :: ATransport -> Spec agentTests (ATransport t) = do @@ -46,24 +50,25 @@ agentTests (ATransport t) = do describe "Migration tests" migrationTests describe "SMP agent protocol syntax" $ syntaxTests t describe "Establishing duplex connection (via agent protocol)" $ do - -- These tests are disabled because the agent does not work correctly with multiple connected TCP clients - xit "should connect via one server and one agent" $ do - smpAgentTest2_1_1 $ testDuplexConnection t - xit "should connect via one server and one agent (random IDs)" $ do - smpAgentTest2_1_1 $ testDuplexConnRandomIds t + skip "These tests are disabled because the agent does not work correctly with multiple connected TCP clients" $ + describe "one agent" $ do + it "should connect via one server and one agent" $ do + smpAgentTest2_1_1 $ testDuplexConnection t + it "should connect via one server and one agent (random IDs)" $ do + smpAgentTest2_1_1 $ testDuplexConnRandomIds t it "should connect via one server and 2 agents" $ do smpAgentTest2_2_1 $ testDuplexConnection t it "should connect via one server and 2 agents (random IDs)" $ do smpAgentTest2_2_1 $ testDuplexConnRandomIds t - it "should connect via 2 servers and 2 agents" $ do - smpAgentTest2_2_2 $ testDuplexConnection t - it "should connect via 2 servers and 2 agents (random IDs)" $ do - smpAgentTest2_2_2 $ testDuplexConnRandomIds t + describe "should connect via 2 servers and 2 agents" $ do + pqMatrix2 t smpAgentTest2_2_2 testDuplexConnection' + describe "should connect via 2 servers and 2 agents (random IDs)" $ do + pqMatrix2 t smpAgentTest2_2_2 testDuplexConnRandomIds' describe "Establishing connections via `contact connection`" $ do - it "should connect via contact connection with one server and 3 agents" $ do - smpAgentTest3 $ testContactConnection t - it "should connect via contact connection with one server and 2 agents (random IDs)" $ do - smpAgentTest2_2_1 $ testContactConnRandomIds t + describe "should connect via contact connection with one server and 3 agents" $ do + pqMatrix3 t smpAgentTest3 testContactConnection + describe "should connect via contact connection with one server and 2 agents (random IDs)" $ do + pqMatrix2NoInv t smpAgentTest2_2_1 testContactConnRandomIds it "should support rejecting contact request" $ do smpAgentTest2_2_1 $ testRejectContactRequest t describe "Connection subscriptions" $ do @@ -72,8 +77,8 @@ agentTests (ATransport t) = do it "should send notifications to client when server disconnects" $ do smpAgentServerTest $ testSubscrNotification t describe "Message delivery and server reconnection" $ do - it "should deliver messages after losing server connection and re-connecting" $ do - smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t + describe "should deliver messages after losing server connection and re-connecting" $ + pqMatrix2 t smpAgentTest2_2_2_needs_server testMsgDeliveryServerRestart it "should connect to the server when server goes up if it initially was down" $ do smpAgentTestN [] $ testServerConnectionAfterError t it "should deliver pending messages after agent restarting" $ do @@ -133,6 +138,9 @@ action #> (corrId, connId, cmd) = withFrozenCallStack $ action `shouldReturn` (c (=#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation action =#> p = withFrozenCallStack $ action >>= (`shouldSatisfy` p . correctTransmission) +pattern MID :: AgentMsgId -> ACommand 'Agent 'AEConn +pattern MID msgId = A.MID msgId PQEncOn + correctTransmission :: (ACorrId, ConnId, Either AgentErrorType cmd) -> (ACorrId, ConnId, cmd) correctTransmission (corrId, connId, cmdOrErr) = case cmdOrErr of Right cmd -> (corrId, connId, cmd) @@ -161,130 +169,175 @@ h #:# err = tryGet `shouldReturn` () Just _ -> error err _ -> return () -pattern Msg :: MsgBody -> ACommand 'Agent e -pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody +type PQMatrix2 c = + HasCallStack => + TProxy c -> + (HasCallStack => (c -> c -> IO ()) -> Expectation) -> + (HasCallStack => (c, InitialKeys) -> (c, PQEncryption) -> IO ()) -> + Spec -pattern Msg' :: AgentMsgId -> MsgBody -> ACommand 'Agent e -pattern Msg' aMsgId msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _)} _ msgBody +pqMatrix2 :: PQMatrix2 c +pqMatrix2 = pqMatrix2_ True + +pqMatrix2NoInv :: PQMatrix2 c +pqMatrix2NoInv = pqMatrix2_ False + +pqMatrix2_ :: Bool -> PQMatrix2 c +pqMatrix2_ pqInv _ smpTest test = do + it "dh/dh handshake" $ smpTest $ \a b -> test (a, ikPQOff) (b, PQEncOff) + it "dh/pq handshake" $ smpTest $ \a b -> test (a, ikPQOff) (b, PQEncOn) + it "pq/dh handshake" $ smpTest $ \a b -> test (a, ikPQOn) (b, PQEncOff) + it "pq/pq handshake" $ smpTest $ \a b -> test (a, ikPQOn) (b, PQEncOn) + when pqInv $ do + it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQEncOff) + it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQEncOn) + +pqMatrix3 :: + HasCallStack => + TProxy c -> + (HasCallStack => (c -> c -> c -> IO ()) -> Expectation) -> + (HasCallStack => (c, InitialKeys) -> (c, PQEncryption) -> (c, PQEncryption) -> IO ()) -> + Spec +pqMatrix3 _ smpTest test = do + it "dh" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOff) (c, PQEncOff) + it "dh/dh/pq" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOff) (c, PQEncOn) + it "dh/pq/dh" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOn) (c, PQEncOff) + it "dh/pq/pq" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOn) (c, PQEncOn) + it "pq/dh/dh" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOff) (c, PQEncOff) + it "pq/dh/pq" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOff) (c, PQEncOn) + it "pq/pq/dh" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOn) (c, PQEncOff) + it "pq" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOn) (c, PQEncOn) testDuplexConnection :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () -testDuplexConnection _ alice bob = do - ("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV subscribe") +testDuplexConnection _ alice bob = testDuplexConnection' (alice, ikPQOn) (bob, PQEncOn) + +testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testDuplexConnection' (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ + ("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq - bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) + bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) ("", "bob", Right (CONF confId _ "bob's connInfo")) <- (alice <#:) alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK) bob <# ("", "alice", INFO "alice's connInfo") - bob <# ("", "alice", CON) - alice <# ("", "bob", CON) + bob <# ("", "alice", CON pq) + alice <# ("", "bob", CON pq) -- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4 - alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", MID 4) + alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", A.MID 4 pq) alice <# ("", "bob", SENT 4) - bob <#= \case ("", "alice", Msg' 4 "hello") -> True; _ -> False + bob <#= \case ("", "alice", Msg' 4 pq' "hello") -> pq == pq'; _ -> False bob #: ("12", "alice", "ACK 4") #> ("12", "alice", OK) - alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", MID 5) + alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", A.MID 5 pq) alice <# ("", "bob", SENT 5) - bob <#= \case ("", "alice", Msg' 5 "how are you?") -> True; _ -> False + bob <#= \case ("", "alice", Msg' 5 pq' "how are you?") -> pq == pq'; _ -> False bob #: ("13", "alice", "ACK 5") #> ("13", "alice", OK) - bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", MID 6) + bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", A.MID 6 pq) bob <# ("", "alice", SENT 6) - alice <#= \case ("", "bob", Msg' 6 "hello too") -> True; _ -> False + alice <#= \case ("", "bob", Msg' 6 pq' "hello too") -> pq == pq'; _ -> False alice #: ("3a", "bob", "ACK 6") #> ("3a", "bob", OK) - bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", MID 7) + bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", A.MID 7 pq) bob <# ("", "alice", SENT 7) - alice <#= \case ("", "bob", Msg' 7 "message 1") -> True; _ -> False + alice <#= \case ("", "bob", Msg' 7 pq' "message 1") -> pq == pq'; _ -> False alice #: ("4a", "bob", "ACK 7") #> ("4a", "bob", OK) alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) - bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", MID 8) + bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", A.MID 8 pq) bob <# ("", "alice", MERR 8 (SMP AUTH)) alice #: ("6", "bob", "DEL") #> ("6", "bob", OK) alice #:# "nothing else should be delivered to alice" -testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO () -testDuplexConnRandomIds _ alice bob = do - ("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV subscribe") +testDuplexConnRandomIds :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () +testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, ikPQOn) (bob, PQEncOn) + +testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ + ("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq - ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") + ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") ("", bobConn', Right (CONF confId _ "bob's connInfo")) <- (alice <#:) bobConn' `shouldBe` bobConn alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False bob <# ("", aliceConn, INFO "alice's connInfo") - bob <# ("", aliceConn, CON) - alice <# ("", bobConn, CON) - alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, MID 4) + bob <# ("", aliceConn, CON pq) + alice <# ("", bobConn, CON pq) + alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, A.MID 4 pq) alice <# ("", bobConn, SENT 4) - bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False + bob <#= \case ("", c, Msg' 4 pq' "hello") -> c == aliceConn && pq == pq'; _ -> False bob #: ("12", aliceConn, "ACK 4") #> ("12", aliceConn, OK) - alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, MID 5) + alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, A.MID 5 pq) alice <# ("", bobConn, SENT 5) - bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False + bob <#= \case ("", c, Msg' 5 pq' "how are you?") -> c == aliceConn && pq == pq'; _ -> False bob #: ("13", aliceConn, "ACK 5") #> ("13", aliceConn, OK) - bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, MID 6) + bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, A.MID 6 pq) bob <# ("", aliceConn, SENT 6) - alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False + alice <#= \case ("", c, Msg' 6 pq' "hello too") -> c == bobConn && pq == pq'; _ -> False alice #: ("3a", bobConn, "ACK 6") #> ("3a", bobConn, OK) - bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, MID 7) + bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, A.MID 7 pq) bob <# ("", aliceConn, SENT 7) - alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False + alice <#= \case ("", c, Msg' 7 pq' "message 1") -> c == bobConn && pq == pq'; _ -> False alice #: ("4a", bobConn, "ACK 7") #> ("4a", bobConn, OK) alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK) - bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, MID 8) + bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, A.MID 8 pq) bob <# ("", aliceConn, MERR 8 (SMP AUTH)) alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK) alice #:# "nothing else should be delivered to alice" -testContactConnection :: Transport c => TProxy c -> c -> c -> c -> IO () -testContactConnection _ alice bob tom = do - ("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON subscribe") +testContactConnection :: Transport c => (c, InitialKeys) -> (c, PQEncryption) -> (c, PQEncryption) -> IO () +testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do + ("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq + abPQ = pqConnectionMode aPQ bPQ + aPQMode = CR.connPQEncryption aPQ - bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) + bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) ("", "alice_contact", Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:) - alice #: ("2", "bob", "ACPT " <> aInvId <> " 16\nalice's connInfo") #> ("2", "bob", OK) + alice #: ("2", "bob", "ACPT " <> aInvId <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("2", "bob", OK) ("", "alice", Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:) bob #: ("12", "alice", "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", "alice", OK) alice <# ("", "bob", INFO "bob's connInfo 2") - alice <# ("", "bob", CON) - bob <# ("", "alice", CON) - alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", MID 4) + alice <# ("", "bob", CON abPQ) + bob <# ("", "alice", CON abPQ) + alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", A.MID 4 abPQ) alice <# ("", "bob", SENT 4) - bob <#= \case ("", "alice", Msg "hi") -> True; _ -> False + bob <#= \case ("", "alice", Msg' 4 pq' "hi") -> pq' == abPQ; _ -> False bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK) - tom #: ("21", "alice", "JOIN T " <> cReq' <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK) + let atPQ = pqConnectionMode aPQ tPQ + tom #: ("21", "alice", "JOIN T " <> cReq' <> enableKEMStr tPQ <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK) ("", "alice_contact", Right (REQ aInvId' _ "tom's connInfo")) <- (alice <#:) - alice #: ("4", "tom", "ACPT " <> aInvId' <> " 16\nalice's connInfo") #> ("4", "tom", OK) + alice #: ("4", "tom", "ACPT " <> aInvId' <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("4", "tom", OK) ("", "alice", Right (CONF tConfId _ "alice's connInfo")) <- (tom <#:) tom #: ("22", "alice", "LET " <> tConfId <> " 16\ntom's connInfo 2") #> ("22", "alice", OK) alice <# ("", "tom", INFO "tom's connInfo 2") - alice <# ("", "tom", CON) - tom <# ("", "alice", CON) - alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", MID 4) + alice <# ("", "tom", CON atPQ) + tom <# ("", "alice", CON atPQ) + alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", A.MID 4 atPQ) alice <# ("", "tom", SENT 4) - tom <#= \case ("", "alice", Msg "hi there") -> True; _ -> False + tom <#= \case ("", "alice", Msg' 4 pq' "hi there") -> pq' == atPQ; _ -> False tom #: ("23", "alice", "ACK 4") #> ("23", "alice", OK) -testContactConnRandomIds :: Transport c => TProxy c -> c -> c -> IO () -testContactConnRandomIds _ alice bob = do - ("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON subscribe") +testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testContactConnRandomIds (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ + ("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq - ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") + ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") ("", aliceContact', Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:) aliceContact' `shouldBe` aliceContact - ("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> " 16\nalice's connInfo") + ("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> enableKEMStr (CR.connPQEncryption aPQ) <> " 16\nalice's connInfo") ("", aliceConn', Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:) aliceConn' `shouldBe` aliceConn bob #: ("12", aliceConn, "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", aliceConn, OK) alice <# ("", bobConn, INFO "bob's connInfo 2") - alice <# ("", bobConn, CON) - bob <# ("", aliceConn, CON) + alice <# ("", bobConn, CON pq) + bob <# ("", aliceConn, CON pq) - alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, MID 4) + alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, A.MID 4 pq) alice <# ("", bobConn, SENT 4) - bob <#= \case ("", c, Msg "hi") -> c == aliceConn; _ -> False + bob <#= \case ("", c, Msg' 4 pq' "hi") -> c == aliceConn && pq == pq'; _ -> False bob #: ("13", aliceConn, "ACK 4") #> ("13", aliceConn, OK) testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO () @@ -327,31 +380,32 @@ testSubscrNotification t (server, _) client = do withSmpServer (ATransport t) $ client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue -testMsgDeliveryServerRestart :: Transport c => TProxy c -> c -> c -> IO () -testMsgDeliveryServerRestart t alice bob = do +testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testMsgDeliveryServerRestart (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ withServer $ do - connect (alice, "alice") (bob, "bob") - bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", MID 4) + connect' (alice, "alice", aPQ) (bob, "bob", bPQ) + bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", A.MID 4 pq) bob <# ("", "alice", SENT 4) - alice <#= \case ("", "bob", Msg "hi") -> True; _ -> False + alice <#= \case ("", "bob", Msg' _ pq' "hi") -> pq == pq'; _ -> False alice #: ("11", "bob", "ACK 4") #> ("11", "bob", OK) alice #:# "nothing else delivered before the server is killed" let server = SMPServer "localhost" testPort2 testKeyHash alice <#. ("", "", DOWN server ["bob"]) - bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", MID 5) + bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", A.MID 5 pq) bob #:# "nothing else delivered before the server is restarted" alice #:# "nothing else delivered before the server is restarted" withServer $ do bob <# ("", "alice", SENT 5) alice <#. ("", "", UP server ["bob"]) - alice <#= \case ("", "bob", Msg "hello again") -> True; _ -> False + alice <#= \case ("", "bob", Msg' _ pq' "hello again") -> pq == pq'; _ -> False alice #: ("12", "bob", "ACK 5") #> ("12", "bob", OK) removeFile testStoreLogFile where - withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` () + withServer test' = withSmpServerStoreLogOn (transport @c) testPort2 (const test') `shouldReturn` () testServerConnectionAfterError :: forall c. Transport c => TProxy c -> [c] -> IO () testServerConnectionAfterError t _ = do @@ -492,16 +546,37 @@ testResumeDeliveryQuotaExceeded _ alice bob = do -- message 8 is skipped because of alice agent sending "QCONT" message bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK) -connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO () -connect (h1, name1) (h2, name2) = do - ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV subscribe") +ikPQOn :: InitialKeys +ikPQOn = IKNoPQ PQEncOn + +ikPQOff :: InitialKeys +ikPQOff = IKNoPQ PQEncOff + +connect :: Transport c => (c, ByteString) -> (c, ByteString) -> IO () +connect (h1, name1) (h2, name2) = connect' (h1, name1, ikPQOn) (h2, name2, PQEncOn) + +connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQEncryption) -> IO () +connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do + ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV" <> pqConnModeStr pqMode1 <> " subscribe") let cReq' = strEncode cReq - h2 #: ("c2", name1, "JOIN T " <> cReq' <> " subscribe 5\ninfo2") #> ("c2", name1, OK) + h2 #: ("c2", name1, "JOIN T " <> cReq' <> enableKEMStr pqMode2 <> " subscribe 5\ninfo2") #> ("c2", name1, OK) ("", _, Right (CONF connId _ "info2")) <- (h1 <#:) h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK) h2 <# ("", name1, INFO "info1") - h2 <# ("", name1, CON) - h1 <# ("", name2, CON) + let pq = pqConnectionMode pqMode1 pqMode2 + h2 <# ("", name1, CON pq) + h1 <# ("", name2, CON pq) + +pqConnectionMode :: InitialKeys -> PQEncryption -> PQEncryption +pqConnectionMode pqMode1 pqMode2 = PQEncryption $ enablePQ (CR.connPQEncryption pqMode1) && enablePQ pqMode2 + +enableKEMStr :: PQEncryption -> ByteString +enableKEMStr PQEncOn = " " <> strEncode PQEncOn +enableKEMStr _ = "" + +pqConnModeStr :: InitialKeys -> ByteString +pqConnModeStr (IKNoPQ PQEncOff) = "" +pqConnModeStr pq = " " <> strEncode pq sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO () sendMessage (h1, name1) (h2, name2) msg = do diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 83548182a..eae87651e 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -1,12 +1,16 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.ConnectionRequestTests where import Data.ByteString (ByteString) +import Data.Type.Equality import Network.HTTP.Types (urlEncode) import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C @@ -17,6 +21,17 @@ import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.Version import Test.Hspec +deriving instance Eq (ConnectionRequestUri m) + +deriving instance Eq (E2ERatchetParamsUri s a) + +deriving instance Eq (RKEMParams s) + +instance Eq AConnectionRequestUri where + ACR m cr == ACR m' cr' = case testEquality m m' of + Just Refl -> cr == cr' + _ -> False + uri :: String uri = "smp.simplex.im" @@ -61,11 +76,11 @@ connReqData = testDhPubKey :: C.PublicKeyX448 testDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U=" -testE2ERatchetParams :: E2ERatchetParamsUri 'C.X448 -testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange 1 1) testDhPubKey testDhPubKey +testE2ERatchetParams :: RcvE2ERatchetParamsUri 'C.X448 +testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange 1 1) testDhPubKey testDhPubKey Nothing -testE2ERatchetParams12 :: E2ERatchetParamsUri 'C.X448 -testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey +testE2ERatchetParams12 :: RcvE2ERatchetParamsUri 'C.X448 +testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey Nothing connectionRequest :: AConnectionRequestUri connectionRequest = @@ -123,7 +138,7 @@ connectionRequestTests = <> urlEncode True testDhKeyStrUri <> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> urlEncode True testDhKeyStrUri - <> "&e2e=v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> "&e2e=v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" strEncode connectionRequestClientDataEmpty `shouldBe` "simplex:/invitation#/?v=2&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> urlEncode True testDhKeyStrUri @@ -167,7 +182,7 @@ connectionRequestTests = <> testDhKeyStrUri <> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> testDhKeyStrUri - <> "&e2e=extra_key%3Dnew%26v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> "&e2e=extra_key%3Dnew%26v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" <> "&some_new_param=abc" <> "&v=2-4" ) diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 95e23b333..a0d2deb5f 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -1,75 +1,175 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} module AgentTests.DoubleRatchetTests where import Control.Concurrent.STM +import Control.Monad (when) import Control.Monad.Except +import Control.Monad.IO.Class import Crypto.Random (ChaChaDRG) import Data.Aeson (FromJSON, ToJSON) import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.Map.Strict as M +import Data.Type.Equality import Simplex.Messaging.Crypto (Algorithm (..), AlgorithmI, CryptoError, DhAlgorithm) import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util ((<$$>)) +import Simplex.Messaging.Version import Test.Hspec doubleRatchetTests :: Spec doubleRatchetTests = do describe "double-ratchet encryption/decryption" $ do - it "should serialize and parse message header" testMessageHeader - it "should encrypt and decrypt messages" $ do - withRatchets @X25519 testEncryptDecrypt - withRatchets @X448 testEncryptDecrypt - it "should encrypt and decrypt skipped messages" $ do - withRatchets @X25519 testSkippedMessages - withRatchets @X448 testSkippedMessages - it "should encrypt and decrypt many messages" $ do - withRatchets @X25519 testManyMessages - it "should allow skipped after ratchet advance" $ do - withRatchets @X25519 testSkippedAfterRatchetAdvance + it "should serialize and parse message header" $ do + testAlgs $ testMessageHeader kdfX3DHE2EEncryptVersion + testAlgs $ testMessageHeader $ max pqRatchetVersion currentE2EEncryptVersion + describe "message tests" $ runMessageTests initRatchets False it "should encode/decode ratchet as JSON" $ do - testKeyJSON C.SX25519 - testKeyJSON C.SX448 - testRatchetJSON C.SX25519 - testRatchetJSON C.SX448 - it "should agree the same ratchet parameters" $ do - testX3dh C.SX25519 - testX3dh C.SX448 - it "should agree the same ratchet parameters with version 1" $ do - testX3dhV1 C.SX25519 - testX3dhV1 C.SX448 + testAlgs testKeyJSON + testAlgs testRatchetJSON + it "should agree the same ratchet parameters" $ testAlgs testX3dh + it "should agree the same ratchet parameters with version 1" $ testAlgs testX3dhV1 + describe "post-quantum hybrid KEM double-ratchet algorithm" $ do + describe "hybrid KEM key agreement" $ do + it "should propose KEM during agreement, but no shared secret" $ testAlgs testPqX3dhProposeInReply + it "should agree shared secret using KEM" $ testAlgs testPqX3dhProposeAccept + it "should reject proposed KEM in reply" $ testAlgs testPqX3dhProposeReject + it "should allow second proposal in reply" $ testAlgs testPqX3dhProposeAgain + describe "hybrid KEM key agreement errors" $ do + it "should fail if reply contains acceptance without proposal" $ testAlgs testPqX3dhAcceptWithoutProposalError + describe "ratchet encryption/decryption" $ do + it "should serialize and parse public KEM params" testKEMParams + it "should serialize and parse message header" $ testAlgs testMessageHeaderKEM + describe "message tests, KEM proposed" $ runMessageTests initRatchetsKEMProposed True + describe "message tests, KEM accepted" $ runMessageTests initRatchetsKEMAccepted False + describe "message tests, KEM proposed again in reply" $ runMessageTests initRatchetsKEMProposedAgain True + it "should disable and re-enable KEM" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEM + it "should disable and re-enable KEM (always set PQEncryption)" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEMStrict + it "should enable KEM when it was not enabled in handshake" $ withRatchets_ @X25519 initRatchets testEnableKEM + it "should enable KEM when it was not enabled in handshake (always set PQEncryption)" $ withRatchets_ @X25519 initRatchets testEnableKEMStrict + +runMessageTests :: + (forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)) -> + Bool -> + Spec +runMessageTests initRatchets_ agreeRatchetKEMs = do + it "should encrypt and decrypt messages" $ run $ testEncryptDecrypt agreeRatchetKEMs + it "should encrypt and decrypt skipped messages" $ run $ testSkippedMessages agreeRatchetKEMs + it "should encrypt and decrypt many messages" $ run $ testManyMessages agreeRatchetKEMs + it "should allow skipped after ratchet advance" $ run $ testSkippedAfterRatchetAdvance agreeRatchetKEMs + where + run :: (forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a) -> IO () + run test = do + withRatchets_ @X25519 initRatchets_ test + withRatchets_ @X448 initRatchets_ test + + +testAlgs :: (forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()) -> IO () +testAlgs test = test C.SX25519 >> test C.SX448 paddedMsgLen :: Int paddedMsgLen = 100 -fullMsgLen :: Int -fullMsgLen = 1 + fullHeaderLen + C.authTagSize + paddedMsgLen +fullMsgLen :: Version -> Int +fullMsgLen v = headerLenLength + fullHeaderLen + C.authTagSize + paddedMsgLen + where + headerLenLength = if v < pqRatchetVersion then 1 else 3 -- two bytes are added because of two Large used in new encoding -testMessageHeader :: Expectation -testMessageHeader = do - (k, _) <- atomically . C.generateKeyPair @X25519 =<< C.newRandom - let hdr = MsgHeader {msgMaxVersion = currentE2EEncryptVersion, msgDHRs = k, msgPN = 0, msgNs = 0} - parseAll (smpP @(MsgHeader 'X25519)) (smpEncode hdr) `shouldBe` Right hdr +testMessageHeader :: forall a. AlgorithmI a => Version -> C.SAlgorithm a -> Expectation +testMessageHeader v _ = do + (k, _) <- atomically . C.generateKeyPair @a =<< C.newRandom + let hdr = MsgHeader {msgMaxVersion = v, msgDHRs = k, msgKEM = Nothing, msgPN = 0, msgNs = 0} + parseAll (smpP @(MsgHeader a)) (smpEncode hdr) `shouldBe` Right hdr + +testKEMParams :: Expectation +testKEMParams = do + g <- C.newRandom + (kem, _) <- sntrup761Keypair g + let kemParams = ARKP SRKSProposed $ RKParamsProposed kem + parseAll (smpP @ARKEMParams) (smpEncode kemParams) `shouldBe` Right kemParams + (kem', _) <- sntrup761Keypair g + (ct, _) <- sntrup761Enc g kem + let kemParams' = ARKP SRKSAccepted $ RKParamsAccepted ct kem' + parseAll (smpP @ARKEMParams) (smpEncode kemParams') `shouldBe` Right kemParams' + +testMessageHeaderKEM :: forall a. AlgorithmI a => C.SAlgorithm a -> Expectation +testMessageHeaderKEM _ = do + g <- C.newRandom + (k, _) <- atomically $ C.generateKeyPair @a g + (kem, _) <- sntrup761Keypair g + let msgMaxVersion = max pqRatchetVersion currentE2EEncryptVersion + msgKEM = Just . ARKP SRKSProposed $ RKParamsProposed kem + hdr = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM, msgPN = 0, msgNs = 0} + parseAll (smpP @(MsgHeader a)) (smpEncode hdr) `shouldBe` Right hdr + (kem', _) <- sntrup761Keypair g + (ct, _) <- sntrup761Enc g kem + let msgKEM' = Just . ARKP SRKSAccepted $ RKParamsAccepted ct kem' + hdr' = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM = msgKEM', msgPN = 0, msgNs = 0} + parseAll (smpP @(MsgHeader a)) (smpEncode hdr') `shouldBe` Right hdr' pattern Decrypted :: ByteString -> Either CryptoError (Either CryptoError ByteString) pattern Decrypted msg <- Right (Right msg) -type TestRatchets a = (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO () +type Encrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString) -testEncryptDecrypt :: TestRatchets a -testEncryptDecrypt alice bob = do +type Decrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) + +type EncryptDecryptSpec a = (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation + +type TestRatchets a = + TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> + TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> + Encrypt a -> + Decrypt a -> + EncryptDecryptSpec a -> + IO () + +deriving instance Eq (Ratchet a) + +deriving instance Eq (SndRatchet a) + +deriving instance Eq RcvRatchet + +deriving instance Eq RatchetKEM + +deriving instance Eq RatchetKEMAccepted + +deriving instance Eq RatchetInitParams + +deriving instance Eq RatchetKey + +deriving instance Eq (RKEMParams s) + +instance Eq ARKEMParams where + (ARKP s ps) == (ARKP s' ps') = case testEquality s s' of + Just Refl -> ps == ps' + Nothing -> False + +deriving instance Eq (MsgHeader a) + +initRatchetKEM :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO () +initRatchetKEM s r = encryptDecrypt (Just $ PQEncOn) (const ()) (const ()) (s, "initialising ratchet") r + +testEncryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob (bob, "hello alice") #> alice (alice, "hello bob") #> bob Right b1 <- encrypt bob "how are you, alice?" @@ -88,8 +188,9 @@ testEncryptDecrypt alice bob = do (alice, "I'm here too, same") #> bob pure () -testSkippedMessages :: TestRatchets a -testSkippedMessages alice bob = do +testSkippedMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob Right msg1 <- encrypt bob "hello alice" Right msg2 <- encrypt bob "hello there again" Right msg3 <- encrypt bob "are you there?" @@ -99,8 +200,9 @@ testSkippedMessages alice bob = do Decrypted "hello alice" <- decrypt alice msg1 pure () -testManyMessages :: TestRatchets a -testManyMessages alice bob = do +testManyMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob (bob, "b1") #> alice (bob, "b2") #> alice (bob, "b3") #> alice @@ -117,8 +219,9 @@ testManyMessages alice bob = do (bob, "b15") #> alice (bob, "b16") #> alice -testSkippedAfterRatchetAdvance :: TestRatchets a -testSkippedAfterRatchetAdvance alice bob = do +testSkippedAfterRatchetAdvance :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testSkippedAfterRatchetAdvance agreeRatchetKEMs alice bob encrypt decrypt (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob (bob, "b1") #> alice Right b2 <- encrypt bob "b2" Right b3 <- encrypt bob "b3" @@ -152,6 +255,74 @@ testSkippedAfterRatchetAdvance alice bob = do Decrypted "b11" <- decrypt alice b11 pure () +testDisableEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testDisableEnableKEM alice bob _ _ _ = do + (bob, "hello alice") !#> alice + (alice, "hello bob") !#> bob + (bob, "disabling KEM") !#>\ alice + (alice, "still disabling KEM") !#> bob + (bob, "now KEM is disabled") \#> alice + (alice, "KEM is disabled for both sides") \#> bob + (bob, "trying to enable KEM") \#>! alice + (alice, "but unless alice enables it too it won't enable") \#> bob + (bob, "KEM is disabled") \#> alice + (alice, "KEM is disabled for both sides") \#> bob + (bob, "enabling KEM again") \#>! alice + (alice, "and alice accepts it this time") \#>! bob + (bob, "still enabling KEM") \#>! alice + (alice, "now KEM is enabled") !#> bob + (bob, "KEM is enabled for both sides") !#> alice + +testDisableEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testDisableEnableKEMStrict alice bob _ _ _ = do + (bob, "hello alice") !#>! alice + (alice, "hello bob") !#>! bob + (bob, "disabling KEM") !#>\ alice + (alice, "still disabling KEM") !#>! bob + (bob, "now KEM is disabled") \#>\ alice + (alice, "KEM is disabled for both sides") \#>\ bob + (bob, "trying to enable KEM") \#>! alice + (alice, "but unless alice enables it too it won't enable") \#>\ bob + (bob, "KEM is disabled") \#>! alice + (alice, "KEM is disabled for both sides") \#>\ bob + (bob, "enabling KEM again") \#>! alice + (alice, "and alice accepts it this time") \#>! bob + (bob, "still enabling KEM") \#>! alice + (alice, "now KEM is enabled") !#>! bob + (bob, "KEM is enabled for both sides") !#>! alice + +testEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testEnableKEM alice bob _ _ _ = do + (bob, "hello alice") \#> alice + (alice, "hello bob") \#> bob + (bob, "enabling KEM") \#>! alice + (bob, "KEM not enabled yet") \#>! alice + (alice, "accepting KEM") \#>! bob + (alice, "KEM not enabled yet here too") \#>! bob + (bob, "KEM is still not enabled") \#>! alice + (alice, "now KEM is enabled") !#> bob + (bob, "now KEM is enabled for both sides") !#> alice + (alice, "disabling KEM") !#>\ bob + (bob, "KEM not disabled yet") !#> alice + (alice, "KEM disabled") \#> bob + (bob, "KEM disabled on both sides") \#> alice + +testEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testEnableKEMStrict alice bob _ _ _ = do + (bob, "hello alice") \#>\ alice + (alice, "hello bob") \#>\ bob + (bob, "enabling KEM") \#>! alice + (bob, "KEM not enabled yet") \#>! alice + (alice, "accepting KEM") \#>! bob + (alice, "KEM not enabled yet here too") \#>! bob + (bob, "KEM is still not enabled") \#>! alice + (alice, "now KEM is enabled") !#>! bob + (bob, "now KEM is enabled for both sides") !#>! alice + (alice, "disabling KEM") !#>\ bob + (bob, "KEM not disabled yet") !#>! alice + (alice, "KEM disabled") \#>\ bob + (bob, "KEM disabled on both sides") \#>! alice + testKeyJSON :: forall a. AlgorithmI a => C.SAlgorithm a -> IO () testKeyJSON _ = do (k, pk) <- atomically . C.generateKeyPair @a =<< C.newRandom @@ -160,7 +331,7 @@ testKeyJSON _ = do testRatchetJSON :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testRatchetJSON _ = do - (alice, bob) <- initRatchets @a + (alice, bob, _, _, _) <- initRatchets @a testEncodeDecode alice testEncodeDecode bob @@ -173,77 +344,246 @@ testEncodeDecode x = do testX3dh :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dh _ = do g <- C.newRandom - (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion - (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion - let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice - paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + let v = max pqRatchetVersion currentE2EEncryptVersion + (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff + let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob paramsAlice `shouldBe` paramsBob testX3dhV1 :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dhV1 _ = do g <- C.newRandom - (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g 1 - (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g 1 - let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice - paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g 1 Nothing + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g 1 PQEncOff + let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob paramsAlice `shouldBe` paramsBob -(#>) :: (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation -(alice, msg) #> bob = do - Right msg' <- encrypt alice msg - Decrypted msg'' <- decrypt bob msg' +testPqX3dhProposeInReply :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeInReply _ = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (no KEM) + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff + -- propose KEM in reply + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob + paramsAlice `compatibleRatchets` paramsBob + +testPqX3dhProposeAccept :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeAccept _ = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice + -- accept KEM + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM aliceKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob + paramsAlice `compatibleRatchets` paramsBob + +testPqX3dhProposeReject :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeReject _ = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice + -- reject KEM + (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob + paramsAlice `compatibleRatchets` paramsBob + +testPqX3dhAcceptWithoutProposalError :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhAcceptWithoutProposalError _ = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (no KEM) + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff + E2ERatchetParams _ _ _ Nothing <- pure e2eAlice + -- incorrectly accept KEM + -- we don't have key in proposal, so we just generate it + (k, _) <- sntrup761Keypair g + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM k) + pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice `shouldBe` Left C.CERatchetKEMState + runExceptT (pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob) `shouldReturn` Left C.CERatchetKEMState + +testPqX3dhProposeAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeAgain _ = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice + -- propose KEM again in reply - this is not an error + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob + paramsAlice `compatibleRatchets` paramsBob + +compatibleRatchets :: (RatchetInitParams, x) -> (RatchetInitParams, x) -> Expectation +compatibleRatchets + (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, _) + (RatchetInitParams {assocData = ad, ratchetKey = rk, sndHK = shk, rcvNextHK = rnhk, kemAccepted = ka}, _) = do + assocData == ad && ratchetKey == rk && sndHK == shk && rcvNextHK == rnhk `shouldBe` True + case (kemAccepted, ka) of + (Just RatchetKEMAccepted {rcPQRr, rcPQRss, rcPQRct}, Just RatchetKEMAccepted {rcPQRr = pqk, rcPQRss = pqss, rcPQRct = pqct}) -> + pqk /= rcPQRr && pqss == rcPQRss && pqct == rcPQRct `shouldBe` True + (Nothing, Nothing) -> pure () + _ -> expectationFailure "RatchetInitParams params are not compatible" + +encryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Maybe PQEncryption -> (Ratchet a -> ()) -> (Ratchet a -> ()) -> EncryptDecryptSpec a +encryptDecrypt pqEnc invalidSnd invalidRcv (alice, msg) bob = do + Right msg' <- withTVar (encrypt_ pqEnc) invalidSnd alice msg + Decrypted msg'' <- decrypt' invalidRcv bob msg' msg'' `shouldBe` msg -withRatchets :: forall a. (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()) -> Expectation -withRatchets test = do +-- enable KEM (currently disabled) +(\#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) \#>! r = encryptDecrypt (Just PQEncOn) noSndKEM noRcvKEM (s, msg) r + +-- enable KEM (currently enabled) +(!#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) !#>! r = encryptDecrypt (Just PQEncOn) hasSndKEM hasRcvKEM (s, msg) r + +-- KEM enabled (no user preference) +(!#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) !#> r = encryptDecrypt Nothing hasSndKEM hasRcvKEM (s, msg) r + +-- disable KEM (currently enabled) +(!#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) !#>\ r = encryptDecrypt (Just PQEncOff) hasSndKEM hasRcvKEM (s, msg) r + +-- disable KEM (currently disabled) +(\#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) \#>\ r = encryptDecrypt (Just PQEncOff) noSndKEM noSndKEM (s, msg) r + +-- KEM disabled (no user preference) +(\#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) \#> r = encryptDecrypt Nothing noSndKEM noSndKEM (s, msg) r + +withRatchets_ :: IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) -> TestRatchets a -> Expectation +withRatchets_ initRatchets_ test = do ga <- C.newRandom gb <- C.newRandom - (a, b) <- initRatchets @a + (a, b, encrypt, decrypt, (#>)) <- initRatchets_ alice <- newTVarIO (ga, a, M.empty) bob <- newTVarIO (gb, b, M.empty) - test alice bob `shouldReturn` () + test alice bob encrypt decrypt (#>) `shouldReturn` () -initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a) +initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) initRatchets = do g <- C.newRandom - (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams g currentE2EEncryptVersion - (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams g currentE2EEncryptVersion - let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice - paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + let v = max pqRatchetVersion currentE2EEncryptVersion + (pkBob1, pkBob2, _pKemParams@Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v Nothing + (pkAlice1, pkAlice2, _pKem@Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOff + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice - pure (alice, bob) + alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOff + pure (alice, bob, encrypt' noSndKEM, decrypt' noRcvKEM, (\#>)) -encrypt_ :: AlgorithmI a => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) -encrypt_ (_, rc, _) msg = - runExceptT (rcEncrypt rc paddedMsgLen msg) +initRatchetsKEMProposed :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) +initRatchetsKEMProposed = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (no KEM) + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOff + -- propose KEM in reply + let useKem = AUseKEM SRKSProposed ProposeKEM + (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob + (_, pkBob3) <- atomically $ C.generateKeyPair g + let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOn + pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) + +initRatchetsKEMAccepted :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) +initRatchetsKEMAccepted = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (propose) + (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice + -- accept + let useKem = AUseKEM SRKSAccepted (AcceptKEM aliceKem) + (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob + (_, pkBob3) <- atomically $ C.generateKeyPair g + let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOn + pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) + +initRatchetsKEMProposedAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) +initRatchetsKEMProposedAgain = do + g <- C.newRandom + let v = max pqRatchetVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOn + -- propose KEM again in reply + let useKem = AUseKEM SRKSProposed ProposeKEM + (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob + (_, pkBob3) <- atomically $ C.generateKeyPair g + let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOn + pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) + +encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) +encrypt_ enableKem (_, rc, _) msg = + -- print msg >> + runExceptT (rcEncrypt rc paddedMsgLen msg enableKem) >>= either (pure . Left) checkLength where checkLength (msg', rc') = do - B.length msg' `shouldBe` fullMsgLen + B.length msg' `shouldBe` fullMsgLen (maxVersion $ rcVersion rc) pure $ Right (msg', rc', SMDNoChange) decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff)) decrypt_ (g, rc, smks) msg = runExceptT $ rcDecrypt g rc smks msg -encrypt :: AlgorithmI a => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString) -encrypt = withTVar encrypt_ +encrypt' :: AlgorithmI a => (Ratchet a -> ()) -> Encrypt a +encrypt' = withTVar $ encrypt_ Nothing -decrypt :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) -decrypt = withTVar decrypt_ +decrypt' :: (AlgorithmI a, DhAlgorithm a) => (Ratchet a -> ()) -> Decrypt a +decrypt' = withTVar decrypt_ + +noSndKEM :: Ratchet a -> () +noSndKEM Ratchet {rcSndKEM = PQEncOn} = error "snd ratchet has KEM" +noSndKEM _ = () + +noRcvKEM :: Ratchet a -> () +noRcvKEM Ratchet {rcRcvKEM = PQEncOn} = error "rcv ratchet has KEM" +noRcvKEM _ = () + +hasSndKEM :: Ratchet a -> () +hasSndKEM Ratchet {rcSndKEM = PQEncOn} = () +hasSndKEM _ = error "snd ratchet has no KEM" + +hasRcvKEM :: Ratchet a -> () +hasRcvKEM Ratchet {rcRcvKEM = PQEncOn} = () +hasRcvKEM _ = error "rcv ratchet has no KEM" withTVar :: AlgorithmI a => ((TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e (r, Ratchet a, SkippedMsgDiff))) -> + (Ratchet a -> ()) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e r) -withTVar op rcVar msg = do +withTVar op valid rcVar msg = do (g, rc, smks) <- readTVarIO rcVar applyDiff smks <$$> (testEncodeDecode rc >> op (g, rc, smks) msg) >>= \case - Right (res, rc', smks') -> atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res) + Right (res, rc', smks') -> valid rc' `seq` atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res) Left e -> pure $ Left e where applyDiff smks (res, rc', smDiff) = (res, rc', applySMDiff smks smDiff) diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs new file mode 100644 index 000000000..aaaa2de51 --- /dev/null +++ b/tests/AgentTests/EqInstances.hs @@ -0,0 +1,25 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -Wno-orphans #-} + +module AgentTests.EqInstances where + +import Data.Type.Equality +import Simplex.Messaging.Agent.Store + +instance Eq SomeConn where + SomeConn d c == SomeConn d' c' = case testEquality d d' of + Just Refl -> c == c' + _ -> False + +deriving instance Eq (Connection d) + +deriving instance Eq (SConnType d) + +deriving instance Eq (StoredRcvQueue q) + +deriving instance Eq (StoredSndQueue q) + +deriving instance Eq (DBQueueId q) + +deriving instance Eq ClientNtfCreds diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 5870266a7..a9a261711 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -9,7 +10,9 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} module AgentTests.FunctionalAPITests @@ -20,6 +23,9 @@ module AgentTests.FunctionalAPITests makeConnection, exchangeGreetingsMsgId, switchComplete, + createConnection, + joinConnection, + sendMessage, runRight, runRight_, get, @@ -29,7 +35,9 @@ module AgentTests.FunctionalAPITests nGet, (##>), (=##>), + pattern CON, pattern Msg, + pattern Msg', agentCfgV7, ) where @@ -45,7 +53,7 @@ import Data.Either (isRight) import Data.Int (Int64) import Data.List (nub) import qualified Data.Map as M -import Data.Maybe (isNothing) +import Data.Maybe (isJust, isNothing) import qualified Data.Set as S import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) @@ -53,17 +61,21 @@ import Data.Type.Equality import qualified Database.SQLite.Simple as SQL import SMPAgentClient import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, withSmpServerV7) -import Simplex.Messaging.Agent +import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) +import qualified Simplex.Messaging.Agent as Agent import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore) -import Simplex.Messaging.Agent.Protocol as Agent +import Simplex.Messaging.Agent.Protocol hiding (CON) +import qualified Simplex.Messaging.Agent.Protocol as Agent import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteStore (dbNew)) import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultSMPClientConfig) import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOn, pattern PQEncOff) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Transport (authBatchCmdsNTFVersion) -import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), supportedSMPClientVRange) +import Simplex.Messaging.Protocol (AProtocolType (..), BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), supportedSMPClientVRange) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Server.Expiration @@ -72,10 +84,21 @@ import Simplex.Messaging.Version import System.Directory (copyFile, renameFile) import Test.Hspec import UnliftIO +import Util import XFTPClient (testXFTPServer) type AEntityTransmission e = (ACorrId, ConnId, ACommand 'Agent e) +deriving instance Eq (ACommand p e) + +instance Eq AConnectionMode where + ACM m == ACM m' = isJust $ testEquality m m' + +instance Eq AProtocolType where + AProtocolType p == AProtocolType p' = isJust $ testEquality p p' + +-- deriving instance Eq (ValidFileDescription p) + (##>) :: (HasCallStack, MonadUnliftIO m) => m (AEntityTransmission e) -> AEntityTransmission e -> m () a ##> t = withTimeout a (`shouldBe` t) @@ -118,13 +141,19 @@ pGet c = do DISCONNECT {} -> pGet c _ -> pure t -pattern Msg :: MsgBody -> ACommand 'Agent e -pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody +pattern CON :: ACommand 'Agent 'AEConn +pattern CON = Agent.CON PQEncOn -pattern MsgErr :: AgentMsgId -> MsgErrorType -> MsgBody -> ACommand 'Agent e +pattern Msg :: MsgBody -> ACommand 'Agent e +pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk, pqEncryption = PQEncOn} _ msgBody + +pattern Msg' :: AgentMsgId -> PQEncryption -> MsgBody -> ACommand 'Agent e +pattern Msg' aMsgId pqEncryption msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _), pqEncryption} _ msgBody + +pattern MsgErr :: AgentMsgId -> MsgErrorType -> MsgBody -> ACommand 'Agent 'AEConn pattern MsgErr msgId err msgBody <- MSG MsgMeta {recipient = (msgId, _), integrity = MsgError err} _ msgBody -pattern Rcvd :: AgentMsgId -> ACommand 'Agent e +pattern Rcvd :: AgentMsgId -> ACommand 'Agent 'AEConn pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] smpCfgVPrev :: ProtocolClientConfig @@ -184,6 +213,18 @@ inAnyOrder g rs = do expected :: a -> (a -> Bool) -> Bool expected r rp = rp r +createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +createConnection c userId enableNtfs cMode clientData = Agent.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQEncOn) + +joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId +joinConnection c userId enableNtfs cReq connInfo = Agent.joinConnection c userId enableNtfs cReq connInfo PQEncOn + +sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> m AgentMsgId +sendMessage c connId msgFlags msgBody = do + (msgId, pqEnc) <- Agent.sendMessage c connId PQEncOn msgFlags msgBody + liftIO $ pqEnc `shouldBe` PQEncOn + pure msgId + functionalAPITests :: ATransport -> Spec functionalAPITests t = do describe "Establishing duplex connection" $ do @@ -259,9 +300,9 @@ functionalAPITests t = do describe "Batching SMP commands" $ do it "should subscribe to multiple (200) subscriptions with batching" $ testBatchedSubscriptions 200 10 t - -- 200 subscriptions gets very slow with test coverage, use below test instead - xit "should subscribe to multiple (6) subscriptions with batching" $ - testBatchedSubscriptions 6 3 t + skip "faster version of the previous test (200 subscriptions gets very slow with test coverage)" $ + it "should subscribe to multiple (6) subscriptions with batching" $ + testBatchedSubscriptions 6 3 t describe "Async agent commands" $ do it "should connect using async agent commands" $ withSmpServer t testAsyncCommands @@ -381,20 +422,22 @@ testMatrix2 t runTest = do it "current to v7" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 runTest it "current with v7 server" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfg 3 runTest it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 runTest - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 runTest - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 runTest - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 runTest + skip "TODO PQ versioning" $ describe "TODO fails with previous version" $ do + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 runTest + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 runTest + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 runTest testRatchetMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testRatchetMatrix2 t runTest = do it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 runTest - pendingV "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 runTest - pendingV "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 runTest - pendingV "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 runTest + skip "TODO PQ versioning" $ describe "TODO fails with previous version" $ do + pendingV "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 runTest + pendingV "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 runTest + pendingV "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 runTest where - pendingV = + pendingV d = let vr = e2eEncryptVRange agentCfg - in if minVersion vr == maxVersion vr then xit else it + in if minVersion vr == maxVersion vr then skip "previous version is not supported" . it d else it d testServerMatrix2 :: ATransport -> (InitialAgentServers -> IO ()) -> Spec testServerMatrix2 t runTest = do @@ -483,7 +526,7 @@ runAgentClientContactTest alice bob baseId = (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe ("", _, REQ invId _ "bob's connInfo") <- get alice - bobId <- acceptContact alice True invId "alice's connInfo" SMSubscribe + bobId <- acceptContact alice True invId "alice's connInfo" PQEncOn SMSubscribe ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" get alice ##> ("", bobId, INFO "bob's connInfo") @@ -979,7 +1022,7 @@ testRatchetSync t = withAgentClients2 $ \alice bob -> withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId, bob2) <- setupDesynchronizedRatchet alice bob runRight $ do - ConnectionStats {ratchetSyncState} <- synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- synchronizeRatchet bob2 aliceId PQEncOn False liftIO $ ratchetSyncState `shouldBe` RSStarted get alice =##> ratchetSyncP bobId RSAgreed get bob2 =##> ratchetSyncP aliceId RSAgreed @@ -1023,7 +1066,7 @@ setupDesynchronizedRatchet alice bob = do runRight_ $ do subscribeConnection bob2 aliceId - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId False + Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQEncOn False 8 <- sendMessage alice bobId SMP.noMsgFlags "hello 5" get alice ##> ("", bobId, SENT 8) @@ -1054,7 +1097,7 @@ testRatchetSyncServerOffline t = withAgentClients2 $ \alice bob -> do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False liftIO $ ratchetSyncState `shouldBe` RSStarted withSmpServerStoreMsgLogOn t testPort $ \_ -> do @@ -1084,7 +1127,7 @@ testRatchetSyncClientRestart t = do setupDesynchronizedRatchet alice bob ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False liftIO $ ratchetSyncState `shouldBe` RSStarted disconnectAgentClient bob2 bob3 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB2 @@ -1111,7 +1154,7 @@ testRatchetSyncSuspendForeground t = do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False liftIO $ ratchetSyncState `shouldBe` RSStarted suspendAgent bob2 0 @@ -1145,10 +1188,10 @@ testRatchetSyncSimultaneous t = do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState = bRSS} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState = bRSS} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False liftIO $ bRSS `shouldBe` RSStarted - ConnectionStats {ratchetSyncState = aRSS} <- runRight $ synchronizeRatchet alice bobId True + ConnectionStats {ratchetSyncState = aRSS} <- runRight $ synchronizeRatchet alice bobId PQEncOn True liftIO $ aRSS `shouldBe` RSStarted withSmpServerStoreMsgLogOn t testPort $ \_ -> do @@ -1203,17 +1246,23 @@ testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do pure r makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection alice bob = makeConnectionForUsers alice 1 bob 1 +makeConnection = makeConnection_ PQEncOn + +makeConnection_ :: PQEncryption -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnection_ pqEnc alice bob = makeConnectionForUsers_ pqEnc alice 1 bob 1 makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers alice aliceUserId bob bobUserId = do - (bobId, qInfo) <- createConnection alice aliceUserId True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob bobUserId True qInfo "bob's connInfo" SMSubscribe +makeConnectionForUsers = makeConnectionForUsers_ PQEncOn + +makeConnectionForUsers_ :: PQEncryption -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers_ pqEnc alice aliceUserId bob bobUserId = do + (bobId, qInfo) <- Agent.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqEnc) SMSubscribe + aliceId <- Agent.joinConnection bob bobUserId True qInfo "bob's connInfo" pqEnc SMSubscribe ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" - get alice ##> ("", bobId, CON) + get alice ##> ("", bobId, Agent.CON pqEnc) get bob ##> ("", aliceId, INFO "alice's connInfo") - get bob ##> ("", aliceId, CON) + get bob ##> ("", aliceId, Agent.CON pqEnc) pure (aliceId, bobId) testInactiveNoSubs :: ATransport -> IO () @@ -1336,8 +1385,8 @@ testBatchedSubscriptions nCreate nDel t = do a <- getSMPAgentClient' 1 agentCfg initAgentServers2 testDB b <- getSMPAgentClient' 2 agentCfg initAgentServers2 testDB2 conns <- runServers $ do - conns <- replicateM (nCreate :: Int) $ makeConnection a b - forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId + conns <- replicateM (nCreate :: Int) $ makeConnection_ PQEncOff a b + forM_ conns $ \(aId, bId) -> exchangeGreetings_ PQEncOff a bId b aId let (aIds', bIds') = unzip $ take nDel conns delete a bIds' delete b aIds' @@ -1358,10 +1407,10 @@ testBatchedSubscriptions nCreate nDel t = do (aIds', bIds') = unzip conns' subscribe a bIds subscribe b aIds - forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId + forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId_ PQEncOff 6 a bId b aId void $ resubscribeConnections a bIds void $ resubscribeConnections b aIds - forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId 8 a bId b aId + forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId_ PQEncOff 8 a bId b aId delete a bIds' delete b aIds' deleteFail a bIds' @@ -1400,10 +1449,10 @@ testBatchedSubscriptions nCreate nDel t = do testAsyncCommands :: IO () testAsyncCommands = withAgentClients2 $ \alice bob -> runRight_ $ do - bobId <- createConnectionAsync alice 1 "1" True SCMInvitation SMSubscribe + bobId <- createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQEncOn) SMSubscribe ("1", bobId', INV (ACR _ qInfo)) <- get alice liftIO $ bobId' `shouldBe` bobId - aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" SMSubscribe + aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" PQEncOn SMSubscribe ("2", aliceId', OK) <- get bob liftIO $ aliceId' `shouldBe` aliceId ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -1450,7 +1499,7 @@ testAsyncCommands = testAsyncCommandsRestore :: ATransport -> IO () testAsyncCommandsRestore t = do alice <- getSMPAgentClient' 1 agentCfg initAgentServers testDB - bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation SMSubscribe + bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQEncOn) SMSubscribe liftIO $ noMessages alice "alice doesn't receive INV because server is down" disconnectAgentClient alice alice' <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB @@ -1467,7 +1516,7 @@ testAcceptContactAsync = (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe ("", _, REQ invId _ "bob's connInfo") <- get alice - bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" SMSubscribe + bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" PQEncOn SMSubscribe get alice =##> \case ("1", c, OK) -> c == bobId; _ -> False ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" @@ -1685,6 +1734,8 @@ testWaitDeliveryTimeout t = do liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" + liftIO $ threadDelay 100000 + withSmpServerStoreLogOn t testPort $ \_ -> do nGet bob =##> \case ("", "", UP _ [cId]) -> cId == aliceId; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" @@ -1751,10 +1802,10 @@ testJoinConnectionAsyncReplyError t = do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB b <- getSMPAgentClient' 2 agentCfg initAgentServersSrv2 testDB2 (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do - bId <- createConnectionAsync a 1 "1" True SCMInvitation SMSubscribe + bId <- createConnectionAsync a 1 "1" True SCMInvitation (IKNoPQ PQEncOn) SMSubscribe ("1", bId', INV (ACR _ qInfo)) <- get a liftIO $ bId' `shouldBe` bId - aId <- joinConnectionAsync b 1 "2" True qInfo "bob's connInfo" SMSubscribe + aId <- joinConnectionAsync b 1 "2" True qInfo "bob's connInfo" PQEncOn SMSubscribe liftIO $ threadDelay 500000 ConnectionStats {rcvQueuesInfo = [], sndQueuesInfo = [SndQueueInfo {}]} <- getConnectionServers b aId pure (aId, bId) @@ -2344,7 +2395,7 @@ testDeliveryReceiptsConcurrent t = t1 <- liftIO getCurrentTime concurrently_ (runClient "a" a bId) (runClient "b" b aId) t2 <- liftIO getCurrentTime - diffUTCTime t2 t1 `shouldSatisfy` (< 15) + diffUTCTime t2 t1 `shouldSatisfy` (< 60) liftIO $ noMessages a "nothing else should be delivered to alice" liftIO $ noMessages b "nothing else should be delivered to bob" where @@ -2355,7 +2406,6 @@ testDeliveryReceiptsConcurrent t = numMsgs = 100 send = runRight_ $ replicateM_ numMsgs $ do - -- liftIO $ print $ cName <> ": sendMessage" void $ sendMessage client connId SMP.noMsgFlags "hello" receive = runRight_ $ @@ -2383,7 +2433,7 @@ testDeliveryReceiptsConcurrent t = receiveLoop (n - 1) getWithTimeout :: ExceptT AgentErrorType IO (AEntityTransmission 'AEConn) getWithTimeout = do - 1000000 `timeout` get client >>= \case + 3000000 `timeout` get client >>= \case Just r -> pure r _ -> error "timeout" @@ -2493,20 +2543,26 @@ testServerMultipleIdentities = testE2ERatchetParams12 exchangeGreetings :: HasCallStack => AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () -exchangeGreetings = exchangeGreetingsMsgId 4 +exchangeGreetings = exchangeGreetings_ PQEncOn + +exchangeGreetings_ :: HasCallStack => PQEncryption -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetings_ pqEnc = exchangeGreetingsMsgId_ pqEnc 4 exchangeGreetingsMsgId :: HasCallStack => Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () -exchangeGreetingsMsgId msgId alice bobId bob aliceId = do - msgId1 <- sendMessage alice bobId SMP.noMsgFlags "hello" - liftIO $ msgId1 `shouldBe` msgId +exchangeGreetingsMsgId = exchangeGreetingsMsgId_ PQEncOn + +exchangeGreetingsMsgId_ :: HasCallStack => PQEncryption -> Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetingsMsgId_ pqEnc msgId alice bobId bob aliceId = do + msgId1 <- Agent.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" + liftIO $ msgId1 `shouldBe` (msgId, pqEnc) get alice ##> ("", bobId, SENT msgId) - get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' mId pq "hello") -> c == aliceId && mId == msgId && pq == pqEnc; _ -> False ackMessage bob aliceId msgId Nothing - msgId2 <- sendMessage bob aliceId SMP.noMsgFlags "hello too" + msgId2 <- Agent.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" let msgId' = msgId + 1 - liftIO $ msgId2 `shouldBe` msgId' + liftIO $ msgId2 `shouldBe` (msgId', pqEnc) get bob ##> ("", aliceId, SENT msgId') - get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' mId pq "hello too") -> c == bobId && mId == msgId' && pq == pqEnc; _ -> False ackMessage alice bobId msgId' Nothing exchangeGreetingsMsgIds :: HasCallStack => AgentClient -> ConnId -> Int64 -> AgentClient -> ConnId -> Int64 -> ExceptT AgentErrorType IO () diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index bb1e687b3..f815fb808 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -12,7 +12,26 @@ module AgentTests.NotificationTests where -- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging) -import AgentTests.FunctionalAPITests (agentCfgV7, exchangeGreetingsMsgId, get, getSMPAgentClient', makeConnection, nGet, runRight, runRight_, switchComplete, testServerMatrix2, withAgentClientsCfg2, (##>), (=##>), pattern Msg) +import AgentTests.FunctionalAPITests + ( agentCfgV7, + createConnection, + exchangeGreetingsMsgId, + get, + getSMPAgentClient', + joinConnection, + makeConnection, + nGet, + runRight, + runRight_, + sendMessage, + switchComplete, + testServerMatrix2, + withAgentClientsCfg2, + (##>), + (=##>), + pattern CON, + pattern Msg, + ) import Control.Concurrent (ThreadId, killThread, threadDelay) import Control.Monad import Control.Monad.Except @@ -28,10 +47,10 @@ import Data.Text.Encoding (encodeUtf8) import NtfClient import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, testDB2, testDB3, testNtfServer, testNtfServer2) import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn) -import Simplex.Messaging.Agent +import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) -import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.Protocol hiding (CON) import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String @@ -46,6 +65,7 @@ import Simplex.Messaging.Transport (ATransport) import System.Directory (doesFileExist, removeFile) import Test.Hspec import UnliftIO +import Util removeFileIfExists :: FilePath -> IO () removeFileIfExists filePath = do @@ -125,8 +145,8 @@ testNtfMatrix t runTest = do it "next servers: SMP v7, NTF v2; next clients: v7/v2" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfgV7 agentCfgV7 runTest it "next servers: SMP v7, NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfg agentCfg runTest it "curr servers: SMP v6, NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfg agentCfg agentCfg runTest - -- this case will cannot be supported - see RFC - xit "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest + skip "this case cannot be supported - see RFC" $ + it "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest -- servers can be migrated in any order it "servers: next SMP v7, curr NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfg agentCfg agentCfg runTest it "servers: curr SMP v6, next NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfgV2 agentCfg agentCfg runTest diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 714b7e15e..9665e3833 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -1,16 +1,21 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.SQLiteTests (storeTests) where +import AgentTests.EqInstances () import Control.Concurrent.Async (concurrently_) import Control.Concurrent.STM import Control.Exception (SomeException) @@ -40,6 +45,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQEncOn) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Encoding.String (StrEncoding (..)) import Simplex.Messaging.Protocol (SubscriptionMode (..)) @@ -174,7 +181,17 @@ testForeignKeysEnabled = `shouldThrow` (\e -> SQL.sqlError e == SQL.ErrorConstraint) cData1 :: ConnData -cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} +cData1 = + ConnData + { userId = 1, + connId = "conn1", + connAgentVersion = 1, + enableNtfs = True, + lastExternalSndId = 0, + deleted = False, + ratchetSyncState = RSOk, + pqEncryption = CR.PQEncOn + } testPrivateAuthKey :: C.APrivateAuthKey testPrivateAuthKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe" @@ -467,7 +484,8 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash = { integrity = MsgOk, recipient = (unId internalId, ts), sndMsgId = externalSndId, - broker = (brokerId, ts) + broker = (brokerId, ts), + pqEncryption = CR.PQEncOn }, msgType = AM_A_MSG_, msgFlags = SMP.noMsgFlags, @@ -505,6 +523,7 @@ mkSndMsgData internalId internalSndId internalHash = msgType = AM_A_MSG_, msgFlags = SMP.noMsgFlags, msgBody = hw, + pqEncryption = CR.PQEncOn, internalHash, prevMsgHash = internalHash } @@ -643,7 +662,7 @@ testGetPendingServerCommand st = do Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db (Just smpServer1) corrId' `shouldBe` "4" where - command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) SMSubscribe + command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) (IKNoPQ PQEncOn) SMSubscribe corruptCmd :: DB.Connection -> ByteString -> ConnId -> IO () corruptCmd db corrId connId = DB.execute db "UPDATE commands SET command = cast('bad' as blob) WHERE conn_id = ? AND corr_id = ?" (connId, corrId) diff --git a/tests/CoreTests/CryptoTests.hs b/tests/CoreTests/CryptoTests.hs index 39bc17c4b..35e82d6d2 100644 --- a/tests/CoreTests/CryptoTests.hs +++ b/tests/CoreTests/CryptoTests.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -Wno-orphans #-} module CoreTests.CryptoTests (cryptoTests) where @@ -13,6 +15,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.Lazy as LT import qualified Data.Text.Lazy.Encoding as LE +import Data.Type.Equality import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC import Simplex.Messaging.Crypto.SNTRUP761.Bindings @@ -91,6 +94,16 @@ cryptoTests = do describe "sntrup761" $ it "should enc/dec key" testSNTRUP761 +instance Eq C.APublicKey where + C.APublicKey a k == C.APublicKey a' k' = case testEquality a a' of + Just Refl -> k == k' + Nothing -> False + +instance Eq C.APrivateKey where + C.APrivateKey a k == C.APrivateKey a' k' = case testEquality a a' of + Just Refl -> k == k' + Nothing -> False + testPadUnpadFile :: IO () testPadUnpadFile = do let f = "tests/tmp/testpad" diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 70f2d93ab..64181a179 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -4,6 +4,7 @@ module CoreTests.TRcvQueuesTests where +import AgentTests.EqInstances () import qualified Data.List.NonEmpty as L import qualified Data.Map as M import qualified Data.Set as S diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index f1ed84d68..88de2c8b8 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -34,6 +34,7 @@ import UnliftIO.Concurrent import qualified UnliftIO.Exception as E import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar) import UnliftIO.Timeout (timeout) +import Util testHost :: NonEmpty TransportHost testHost = "localhost" @@ -60,12 +61,12 @@ testServerStatsBackupFile :: FilePath testServerStatsBackupFile = "tests/tmp/smp-server-stats.log" xit' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a) -xit' = if os == "linux" then xit else it +xit' d = if os == "linux" then skip "skipped on Linux" . it d else it d xit'' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a) xit'' d t = do ci <- runIO $ lookupEnv "CI" - (if ci == Just "true" then xit else it) d t + (if ci == Just "true" then skip "skipped on CI" . it d else it d) t testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a testSMPClient = testSMPClientVR supportedClientSMPRelayVRange diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index d6938fa0f..4065c7e19 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -8,7 +8,9 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-orphans #-} module ServerTests where @@ -23,6 +25,7 @@ import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.Set as S +import Data.Type.Equality import GHC.Stack (withFrozenCallStack) import SMPClient import qualified Simplex.Messaging.Crypto as C @@ -919,6 +922,15 @@ sampleSig = Just $ TASignature "e8JK+8V3fq6kOLqco/SaKlpNaQ7i1gfOrXoqekEl42u4mF8B noAuth :: (Char, Maybe BasicAuth) noAuth = ('A', Nothing) +deriving instance Eq TransmissionAuth + +instance Eq C.ASignature where + C.ASignature a s == C.ASignature a' s' = case testEquality a a' of + Just Refl -> s == s' + _ -> False + +deriving instance Eq (C.Signature a) + syntaxTests :: ATransport -> Spec syntaxTests (ATransport t) = do it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN) diff --git a/tests/Util.hs b/tests/Util.hs new file mode 100644 index 000000000..a52fee32c --- /dev/null +++ b/tests/Util.hs @@ -0,0 +1,6 @@ +module Util where + +import Test.Hspec + +skip :: String -> SpecWith a -> SpecWith a +skip = before_ . pendingWith From dd2bd11584d98360cadb599ab235566ee9e818f5 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Mon, 4 Mar 2024 19:06:51 +0000 Subject: [PATCH 06/30] parameterize version scopes with phantom types (#1026) * parameterize version scopes with phantom types * move Version to another module * parens --- simplexmq.cabal | 1 + src/Simplex/FileTransfer/Agent.hs | 2 +- src/Simplex/FileTransfer/Client.hs | 2 +- src/Simplex/FileTransfer/Protocol.hs | 111 +++------------- src/Simplex/FileTransfer/Server.hs | 2 +- src/Simplex/FileTransfer/Server/Store.hs | 3 +- src/Simplex/FileTransfer/Transport.hs | 122 +++++++++++++++++- src/Simplex/Messaging/Agent.hs | 20 +-- src/Simplex/Messaging/Agent/Client.hs | 62 ++++----- src/Simplex/Messaging/Agent/Env/SQLite.hs | 17 +-- src/Simplex/Messaging/Agent/Protocol.hs | 75 +++++++---- src/Simplex/Messaging/Agent/Store.hs | 8 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 22 ++-- src/Simplex/Messaging/Client.hs | 72 +++++------ src/Simplex/Messaging/Client/Agent.hs | 4 +- src/Simplex/Messaging/Crypto/Ratchet.hs | 121 ++++++++++++----- src/Simplex/Messaging/Notifications/Client.hs | 6 +- .../Messaging/Notifications/Protocol.hs | 14 +- src/Simplex/Messaging/Notifications/Server.hs | 6 +- .../Messaging/Notifications/Server/Env.hs | 8 +- .../Messaging/Notifications/Transport.hs | 63 +++++---- src/Simplex/Messaging/Protocol.hs | 89 ++++++++----- src/Simplex/Messaging/Server.hs | 8 +- src/Simplex/Messaging/Server/Env/STM.hs | 9 +- src/Simplex/Messaging/Transport.hs | 85 +++++++----- src/Simplex/Messaging/Version.hs | 71 +++++----- src/Simplex/Messaging/Version/Internal.hs | 25 ++++ src/Simplex/RemoteControl/Client.hs | 12 +- src/Simplex/RemoteControl/Invitation.hs | 4 +- src/Simplex/RemoteControl/Types.hs | 28 +++- tests/AgentTests/ConnectionRequestTests.hs | 11 +- tests/AgentTests/DoubleRatchetTests.hs | 22 +++- tests/AgentTests/FunctionalAPITests.hs | 47 ++++--- tests/AgentTests/SQLiteTests.hs | 13 +- tests/CoreTests/BatchingTests.hs | 33 +++-- tests/CoreTests/ProtocolErrorTests.hs | 2 +- tests/CoreTests/TRcvQueuesTests.hs | 5 +- tests/CoreTests/VersionRangeTests.hs | 36 +++--- tests/NtfClient.hs | 15 ++- tests/NtfServerTests.hs | 21 +-- tests/SMPClient.hs | 36 +++--- tests/ServerTests.hs | 41 +++--- tests/XFTPAgent.hs | 3 +- tests/XFTPServerTests.hs | 4 +- 44 files changed, 807 insertions(+), 554 deletions(-) create mode 100644 src/Simplex/Messaging/Version/Internal.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index f1d7c1bec..5c45e42c9 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -159,6 +159,7 @@ library Simplex.Messaging.Transport.WebSockets Simplex.Messaging.Util Simplex.Messaging.Version + Simplex.Messaging.Version.Internal Simplex.RemoteControl.Client Simplex.RemoteControl.Discovery Simplex.RemoteControl.Discovery.Multicast diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index f45389462..5666b63ff 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -52,8 +52,8 @@ import Simplex.FileTransfer.Client.Main import Simplex.FileTransfer.Crypto import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..)) -import qualified Simplex.FileTransfer.Protocol as XFTP import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import qualified Simplex.FileTransfer.Transport as XFTP import Simplex.FileTransfer.Types import Simplex.FileTransfer.Util (removePath, uniqueCombine) import Simplex.Messaging.Agent.Client diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 84c99eb48..d9c4c058a 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -57,7 +57,7 @@ import UnliftIO.Directory data XFTPClient = XFTPClient { http2Client :: HTTP2Client, transportSession :: TransportSession FileResponse, - thParams :: THandleParams, + thParams :: THandleParams XFTPVersion, config :: XFTPClientConfig } diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index a9de56ddb..3094aae9e 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -1,9 +1,11 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} @@ -14,9 +16,7 @@ module Simplex.FileTransfer.Protocol where -import Control.Applicative ((<|>)) import qualified Data.Aeson.TH as J -import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -25,11 +25,11 @@ import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe (isNothing) import Data.Type.Equality import Data.Word (Word32) +import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, pattern VersionXFTP, xftpClientHandshake) import Simplex.Messaging.Client (authTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (ntfClientHandshake) import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol ( BasicAuth, @@ -56,11 +56,10 @@ import Simplex.Messaging.Protocol tParse, ) import Simplex.Messaging.Transport (THandleParams (..), TransportError (..)) -import Simplex.Messaging.Util (bshow, (<$?>)) -import Simplex.Messaging.Version +import Simplex.Messaging.Util ((<$?>)) -currentXFTPVersion :: Version -currentXFTPVersion = 1 +currentXFTPVersion :: VersionXFTP +currentXFTPVersion = VersionXFTP 1 xftpBlockSize :: Int xftpBlockSize = 16384 @@ -142,10 +141,10 @@ instance ProtocolMsgTag FileCmdTag where instance FilePartyI p => ProtocolMsgTag (FileCommandTag p) where decodeTag s = decodeTag s >>= (\(FCT _ t) -> checkParty' t) -instance Protocol XFTPErrorType FileResponse where +instance Protocol XFTPVersion XFTPErrorType FileResponse where type ProtoCommand FileResponse = FileCmd type ProtoType FileResponse = 'PXFTP - protocolClientHandshake = ntfClientHandshake + protocolClientHandshake = xftpClientHandshake protocolPing = FileCmd SFRecipient PING protocolError = \case FRErr e -> Just e @@ -175,7 +174,7 @@ data FileInfo = FileInfo type XFTPFileId = ByteString -instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where +instance FilePartyI p => ProtocolEncoding XFTPVersion XFTPErrorType (FileCommand p) where type Tag (FileCommand p) = FileCommandTag p encodeProtocol _v = \case FNEW file rKeys auth_ -> e (FNEW_, ' ', file, rKeys, auth_) @@ -191,7 +190,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where protocolP v tag = (\(FileCmd _ c) -> checkParty c) <$?> protocolP v (FCT (sFileParty @p) tag) - fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse + fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse {-# INLINE fromProtocolError #-} checkCredentials (auth, _, fileId, _) cmd = case cmd of @@ -208,7 +207,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where | isNothing auth || B.null fileId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance ProtocolEncoding XFTPErrorType FileCmd where +instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where type Tag FileCmd = FileCmdTag encodeProtocol _v (FileCmd _ c) = encodeProtocol _v c @@ -225,7 +224,7 @@ instance ProtocolEncoding XFTPErrorType FileCmd where FACK_ -> pure FACK PING_ -> pure PING - fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse + fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse {-# INLINE fromProtocolError #-} checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c @@ -276,7 +275,7 @@ data FileResponse | FRPong deriving (Show) -instance ProtocolEncoding XFTPErrorType FileResponse where +instance ProtocolEncoding XFTPVersion XFTPErrorType FileResponse where type Tag FileResponse = FileResponseTag encodeProtocol _v = \case FRSndIds fId rIds -> e (FRSndIds_, ' ', fId, rIds) @@ -319,82 +318,6 @@ instance ProtocolEncoding XFTPErrorType FileResponse where | B.null entId = Right cmd | otherwise = Left $ CMD HAS_AUTH -data XFTPErrorType - = -- | incorrect block format, encoding or signature size - BLOCK - | -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929) - SESSION - | -- | SMP command is unknown or has invalid syntax - CMD {cmdErr :: CommandError} - | -- | command authorization error - bad signature or non-existing SMP queue - AUTH - | -- | incorrent file size - SIZE - | -- | storage quota exceeded - QUOTA - | -- | incorrent file digest - DIGEST - | -- | file encryption/decryption failed - CRYPTO - | -- | no expected file body in request/response or no file on the server - NO_FILE - | -- | unexpected file body - HAS_FILE - | -- | file IO error - FILE_IO - | -- | bad redirect data - REDIRECT {redirectError :: String} - | -- | internal server error - INTERNAL - | -- | used internally, never returned by the server (to be removed) - DUPLICATE_ -- not part of SMP protocol, used internally - deriving (Eq, Read, Show) - -instance StrEncoding XFTPErrorType where - strEncode = \case - CMD e -> "CMD " <> bshow e - REDIRECT e -> "REDIRECT " <> bshow e - e -> bshow e - strP = - "CMD " *> (CMD <$> parseRead1) - <|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString) - <|> parseRead1 - -instance Encoding XFTPErrorType where - smpEncode = \case - BLOCK -> "BLOCK" - SESSION -> "SESSION" - CMD err -> "CMD " <> smpEncode err - AUTH -> "AUTH" - SIZE -> "SIZE" - QUOTA -> "QUOTA" - DIGEST -> "DIGEST" - CRYPTO -> "CRYPTO" - NO_FILE -> "NO_FILE" - HAS_FILE -> "HAS_FILE" - FILE_IO -> "FILE_IO" - REDIRECT err -> "REDIRECT " <> smpEncode err - INTERNAL -> "INTERNAL" - DUPLICATE_ -> "DUPLICATE_" - - smpP = - A.takeTill (== ' ') >>= \case - "BLOCK" -> pure BLOCK - "SESSION" -> pure SESSION - "CMD" -> CMD <$> _smpP - "AUTH" -> pure AUTH - "SIZE" -> pure SIZE - "QUOTA" -> pure QUOTA - "DIGEST" -> pure DIGEST - "CRYPTO" -> pure CRYPTO - "NO_FILE" -> pure NO_FILE - "HAS_FILE" -> pure HAS_FILE - "FILE_IO" -> pure FILE_IO - "REDIRECT" -> REDIRECT <$> _smpP - "INTERNAL" -> pure INTERNAL - "DUPLICATE_" -> pure DUPLICATE_ - _ -> fail "bad error type" - checkParty :: forall t p p'. (FilePartyI p, FilePartyI p') => t p' -> Either String (t p) checkParty c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Right c @@ -405,12 +328,12 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Just c _ -> Nothing -xftpEncodeAuthTransmission :: ProtocolEncoding e c => THandleParams -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString +xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString xftpEncodeAuthTransmission thParams pKey (corrId, fId, msg) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg) xftpEncodeBatch1 . (,tToSend) =<< authTransmission Nothing (Just pKey) corrId tForAuth -xftpEncodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> Either TransportError ByteString +xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> Transmission c -> Either TransportError ByteString xftpEncodeTransmission thParams (corrId, fId, msg) = do let t = encodeTransmission thParams (corrId, fId, msg) xftpEncodeBatch1 (Nothing, t) @@ -419,7 +342,7 @@ xftpEncodeTransmission thParams (corrId, fId, msg) = do xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 t) xftpBlockSize -xftpDecodeTransmission :: ProtocolEncoding e c => THandleParams -> ByteString -> Either XFTPErrorType (SignedTransmission e c) +xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> ByteString -> Either XFTPErrorType (SignedTransmission e c) xftpDecodeTransmission thParams t = do t' <- first (const BLOCK) $ C.unPad t case tParse thParams t' of @@ -427,5 +350,3 @@ xftpDecodeTransmission thParams t = do _ -> Left BLOCK $(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty) - -$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType) diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 158429d79..ae202c2b0 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -69,7 +69,7 @@ type M a = ReaderT XFTPEnv IO a data XFTPTransportRequest = XFTPTransportRequest - { thParams :: THandleParams, + { thParams :: THandleParams XFTPVersion, reqBody :: HTTP2Body, request :: H.Request, sendResponse :: H.Response -> IO () diff --git a/src/Simplex/FileTransfer/Server/Store.hs b/src/Simplex/FileTransfer/Server/Store.hs index 031c46f5b..985fcb9f4 100644 --- a/src/Simplex/FileTransfer/Server/Store.hs +++ b/src/Simplex/FileTransfer/Server/Store.hs @@ -28,7 +28,8 @@ import Data.Int (Int64) import Data.Set (Set) import qualified Data.Set as S import Data.Time.Clock.System (SystemTime (..)) -import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPErrorType (..), XFTPFileId) +import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPFileId) +import Simplex.FileTransfer.Transport (XFTPErrorType (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (RcvPublicAuthKey, RecipientId, SenderId) diff --git a/src/Simplex/FileTransfer/Transport.hs b/src/Simplex/FileTransfer/Transport.hs index 90b1a8a44..464a75ac8 100644 --- a/src/Simplex/FileTransfer/Transport.hs +++ b/src/Simplex/FileTransfer/Transport.hs @@ -1,10 +1,19 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} module Simplex.FileTransfer.Transport ( supportedFileServerVRange, + xftpClientHandshake, -- stub + XFTPVersion, + VersionXFTP, + pattern VersionXFTP, + XFTPErrorType (..), XFTPRcvChunkSpec (..), ReceiveFileError (..), receiveFile, @@ -14,22 +23,31 @@ module Simplex.FileTransfer.Transport ) where +import Control.Applicative ((<|>)) import qualified Control.Exception as E import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class +import qualified Data.Aeson.TH as J +import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) import qualified Data.ByteArray as BA import Data.ByteString.Builder (Builder, byteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB -import Data.Word (Word32) -import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) +import Data.Word (Word16, Word32) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC +import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Parsers +import Simplex.Messaging.Protocol (CommandError) +import Simplex.Messaging.Transport (HandshakeError (..), THandle, TransportError (..)) import Simplex.Messaging.Transport.HTTP2.File +import Simplex.Messaging.Util (bshow) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import System.IO (Handle, IOMode (..), withFile) data XFTPRcvChunkSpec = XFTPRcvChunkSpec @@ -39,8 +57,26 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec } deriving (Show) -supportedFileServerVRange :: VersionRange -supportedFileServerVRange = mkVersionRange 1 1 +data XFTPVersion + +instance VersionScope XFTPVersion + +type VersionXFTP = Version XFTPVersion + +type VersionRangeXFTP = VersionRange XFTPVersion + +pattern VersionXFTP :: Word16 -> VersionXFTP +pattern VersionXFTP v = Version v + +initialXFTPVersion :: VersionXFTP +initialXFTPVersion = VersionXFTP 1 + +supportedFileServerVRange :: VersionRangeXFTP +supportedFileServerVRange = mkVersionRange initialXFTPVersion initialXFTPVersion + +-- XFTP protocol does not support handshake +xftpClientHandshake :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c) +xftpClientHandshake _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO () sendEncFile h send = go @@ -97,3 +133,81 @@ receiveFile_ receive XFTPRcvChunkSpec {filePath, chunkSize, chunkDigest} = do ExceptT $ withFile filePath WriteMode (`receive` chunkSize) digest' <- liftIO $ LC.sha256Hash <$> LB.readFile filePath when (digest' /= chunkDigest) $ throwError DIGEST + +data XFTPErrorType + = -- | incorrect block format, encoding or signature size + BLOCK + | -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929) + SESSION + | -- | SMP command is unknown or has invalid syntax + CMD {cmdErr :: CommandError} + | -- | command authorization error - bad signature or non-existing SMP queue + AUTH + | -- | incorrent file size + SIZE + | -- | storage quota exceeded + QUOTA + | -- | incorrent file digest + DIGEST + | -- | file encryption/decryption failed + CRYPTO + | -- | no expected file body in request/response or no file on the server + NO_FILE + | -- | unexpected file body + HAS_FILE + | -- | file IO error + FILE_IO + | -- | bad redirect data + REDIRECT {redirectError :: String} + | -- | internal server error + INTERNAL + | -- | used internally, never returned by the server (to be removed) + DUPLICATE_ -- not part of SMP protocol, used internally + deriving (Eq, Read, Show) + +instance StrEncoding XFTPErrorType where + strEncode = \case + CMD e -> "CMD " <> bshow e + REDIRECT e -> "REDIRECT " <> bshow e + e -> bshow e + strP = + "CMD " *> (CMD <$> parseRead1) + <|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString) + <|> parseRead1 + +instance Encoding XFTPErrorType where + smpEncode = \case + BLOCK -> "BLOCK" + SESSION -> "SESSION" + CMD err -> "CMD " <> smpEncode err + AUTH -> "AUTH" + SIZE -> "SIZE" + QUOTA -> "QUOTA" + DIGEST -> "DIGEST" + CRYPTO -> "CRYPTO" + NO_FILE -> "NO_FILE" + HAS_FILE -> "HAS_FILE" + FILE_IO -> "FILE_IO" + REDIRECT err -> "REDIRECT " <> smpEncode err + INTERNAL -> "INTERNAL" + DUPLICATE_ -> "DUPLICATE_" + + smpP = + A.takeTill (== ' ') >>= \case + "BLOCK" -> pure BLOCK + "SESSION" -> pure SESSION + "CMD" -> CMD <$> _smpP + "AUTH" -> pure AUTH + "SIZE" -> pure SIZE + "QUOTA" -> pure QUOTA + "DIGEST" -> pure DIGEST + "CRYPTO" -> pure CRYPTO + "NO_FILE" -> pure NO_FILE + "HAS_FILE" -> pure HAS_FILE + "FILE_IO" -> pure FILE_IO + "REDIRECT" -> REDIRECT <$> _smpP + "INTERNAL" -> pure INTERNAL + "DUPLICATE_" -> pure DUPLICATE_ + _ -> fail "bad error type" + +$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 4c2db7322..7c751337c 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -165,11 +165,11 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth) +import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, VersionSMPC, XFTPServerWithAuth) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (THandleParams (sessionId)) +import Simplex.Messaging.Transport (SMPVersion, THandleParams (sessionId)) import Simplex.Messaging.Util import Simplex.Messaging.Version import Simplex.RemoteControl.Client @@ -681,7 +681,7 @@ joinConn c userId connId enableNtfs cReq cInfo subMode = do _ -> getSMPServer c userId joinConnSrv c userId connId enableNtfs cReq cInfo subMode srv -startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> m (Compatible Version, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.E2ERatchetParams 'C.X448) +startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.E2ERatchetParams 'C.X448) startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config case ( qUri `compatibleVersion` smpClientVRange, @@ -1109,7 +1109,7 @@ enqueueMessageB c reqs = do let sqs' = filter isActiveSndQ sqs pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId)) + storeSentMsg :: DB.Connection -> VersionSMPA -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId)) storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId @@ -1917,7 +1917,7 @@ cleanupManager c@AgentClient {subQ} = do -- | make sure to ACK or throw in each message processing branch -- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL -processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m () +processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission SMPVersion BrokerMsg -> m () processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, rId, cmd) = do (rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId) processSMP rq conn $ toConnData conn @@ -2058,7 +2058,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> pure Nothing _ -> prohibited >> ack _ -> prohibited >> ack - updateConnVersion :: Connection c -> ConnData -> Version -> m (Connection c) + updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> m (Connection c) updateConnVersion conn' cData' msgAgentVersion = do aVRange <- asks $ smpAgentVRange . config let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion @@ -2121,7 +2121,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, parseMessage :: Encoding a => ByteString -> m a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) - smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.E2ERatchetParams 'C.X448) -> ByteString -> Version -> Version -> m () + smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.E2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> m () smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config @@ -2371,7 +2371,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, createRatchet db connId rc -- compare public keys `k1` in AgentRatchetKey messages sent by self and other party -- to determine ratchet initilization ordering - initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448) -> m () + initRatchet :: CR.VersionRangeE2E -> (C.PrivateKeyX448, C.PrivateKeyX448) -> m () initRatchet e2eEncryptVRange (pk1, pk2) | rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams @@ -2419,7 +2419,7 @@ confirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode submitPendingMsg c cData sq -confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m () +confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m () confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg @@ -2466,7 +2466,7 @@ enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do submitPendingMsg c cData sq pure $ unId msgId where - storeRatchetKey :: Version -> m InternalId + storeRatchetKey :: VersionSMPA -> m InternalId storeRatchetKey agentVersion = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 23caa2254..f60ddea26 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -167,8 +167,8 @@ import Network.Socket (HostName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError) import qualified Simplex.FileTransfer.Client as X import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb) -import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse, XFTPErrorType (DIGEST)) -import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse) +import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion) import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..)) import Simplex.FileTransfer.Util (uniqueCombine) import Simplex.Messaging.Agent.Env.SQLite @@ -187,6 +187,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse) import Simplex.Messaging.Protocol @@ -215,9 +216,12 @@ import Simplex.Messaging.Protocol UserProtocol, XFTPServer, XFTPServerWithAuth, + VersionSMPC, + VersionRangeSMPC, sameSrvAddr', ) import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport.Client (TransportHost) @@ -253,7 +257,7 @@ data AgentClient = AgentClient { active :: TVar Bool, rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), - msgQ :: TBQueue (ServerTransmission BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), smpServers :: TMap UserId (NonEmpty SMPServerWithAuth), smpClients :: TMap SMPTransportSession SMPClientVar, ntfServers :: TVar [NtfServer], @@ -467,7 +471,7 @@ agentClientStore AgentClient {agentEnv = Env {store}} = store agentDRG :: AgentClient -> TVar ChaChaDRG agentDRG AgentClient {agentEnv = Env {random}} = random -class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err where +class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where type Client msg = c | c -> msg getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg) clientProtocolError :: err -> AgentErrorType @@ -476,8 +480,8 @@ class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err wher clientTransportHost :: Client msg -> TransportHost clientSessionTs :: Client msg -> UTCTime -instance ProtocolServerClient ErrorType BrokerMsg where - type Client BrokerMsg = ProtocolClient ErrorType BrokerMsg +instance ProtocolServerClient SMPVersion ErrorType BrokerMsg where + type Client BrokerMsg = ProtocolClient SMPVersion ErrorType BrokerMsg getProtocolServerClient = getSMPServerClient clientProtocolError = SMP closeProtocolServerClient = closeProtocolClient @@ -485,8 +489,8 @@ instance ProtocolServerClient ErrorType BrokerMsg where clientTransportHost = transportHost' clientSessionTs = sessionTs -instance ProtocolServerClient ErrorType NtfResponse where - type Client NtfResponse = ProtocolClient ErrorType NtfResponse +instance ProtocolServerClient NTFVersion ErrorType NtfResponse where + type Client NtfResponse = ProtocolClient NTFVersion ErrorType NtfResponse getProtocolServerClient = getNtfServerClient clientProtocolError = NTF closeProtocolServerClient = closeProtocolClient @@ -494,7 +498,7 @@ instance ProtocolServerClient ErrorType NtfResponse where clientTransportHost = transportHost' clientSessionTs = sessionTs -instance ProtocolServerClient XFTPErrorType FileResponse where +instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where type Client FileResponse = XFTPClient getProtocolServerClient = getXFTPServerClient clientProtocolError = XFTP @@ -683,8 +687,8 @@ waitForProtocolClient c (_, srv, _) v = do -- clientConnected arg is only passed for SMP server newProtocolClient :: - forall err msg m. - (AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => + forall v err msg m. + (AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> TMap (TransportSession msg) (ClientVar msg) -> @@ -706,10 +710,10 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = putTMVar (sessionVar v) (Left e) throwError e -- signal error to caller -hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone +hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost -getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig +getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v) getClientConfig AgentClient {useNetworkConfig} cfgSel = do cfg <- asks $ cfgSel . config networkConfig <- readTVarIO useNetworkConfig @@ -754,19 +758,19 @@ throwWhenNoDelivery c sq = unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $ throwSTM ThreadKilled -closeProtocolServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = atomically (clientsSel c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient_ c) -reconnectServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +reconnectServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () reconnectServerClients c clientsSel = readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c) -closeClient :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO () +closeClient :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO () closeClient c clientSel tSess = atomically (TM.lookupDelete tSess $ clientSel c) >>= mapM_ (closeClient_ c) -closeClient_ :: ProtocolServerClient err msg => AgentClient -> ClientVar msg -> IO () +closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO () closeClient_ c v = do NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case @@ -798,7 +802,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure where newLock = createLock >>= \l -> TM.insert key l locks $> l -withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a +withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a withClient_ c tSess@(userId, srv, _) statCmd action = do cl <- getProtocolServerClient c tSess (action cl <* stat cl "OK") `catchAgentError` logServerError cl @@ -810,18 +814,18 @@ withClient_ c tSess@(userId, srv, _) statCmd action = do stat cl $ strEncode e throwError e -withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a +withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do logServer "-->" c srv entId cmdStr res <- withClient_ c tSess cmdStr action logServer "<--" c srv entId "OK" return res -withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a -withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client +withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client -withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a -withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client +withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a withSMPClient c q cmdStr action = do @@ -837,7 +841,7 @@ withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityI withNtfClient c srv = withLogClient c (0, srv, Nothing) withXFTPClient :: - (AgentMonad m, ProtocolServerClient err msg) => + (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> (UserId, ProtoServer msg, EntityId) -> ByteString -> @@ -1001,7 +1005,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig -newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri) +newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri) newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do C.AuthAlg a <- asks (rcvAuthAlg . config) g <- asks random @@ -1151,7 +1155,7 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database" -sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () +sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do tSess <- mkTransportSession c userId smpServer senderId withLogClient_ c tSess senderId "SEND " $ \smp -> do @@ -1334,7 +1338,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} -- add encoding as AgentInvitation'? -agentCbEncryptOnce :: AgentMonad m => Version -> C.PublicKeyX25519 -> ByteString -> m ByteString +agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString agentCbEncryptOnce clientVersion dhRcvPubKey msg = do g <- asks random (dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g @@ -1518,7 +1522,7 @@ incStat AgentClient {agentStats} n k = do Just v -> modifyTVar' v (+ n) _ -> newTVar n >>= \v -> TM.insert k v agentStats -incClientStat :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO () +incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO () incClientStat c userId pc = incClientStatN c userId pc 1 incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO () @@ -1528,7 +1532,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do where statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res} -incClientStatN :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO () +incClientStatN :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO () incClientStatN c userId pc n cmd res = do atomically $ incStat c n statsKey where diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index e603e50b8..7a879bb22 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -56,16 +56,17 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (supportedE2EEncryptVRange) +import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRange) import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig) +import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Protocol (NtfServer, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) +import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) +import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (TLS, Transport (..)) import Simplex.Messaging.Transport.Client (defaultSMPPort) import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors) -import Simplex.Messaging.Version import System.Random (StdGen, newStdGen) import UnliftIO (Async, SomeException) import UnliftIO.STM @@ -87,8 +88,8 @@ data AgentConfig = AgentConfig sndAuthAlg :: C.AuthAlg, connIdBytes :: Int, tbqSize :: Natural, - smpCfg :: ProtocolClientConfig, - ntfCfg :: ProtocolClientConfig, + smpCfg :: ProtocolClientConfig SMPVersion, + ntfCfg :: ProtocolClientConfig NTFVersion, xftpCfg :: XFTPClientConfig, reconnectInterval :: RetryInterval, messageRetryInterval :: RetryInterval2, @@ -116,9 +117,9 @@ data AgentConfig = AgentConfig caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, - e2eEncryptVRange :: VersionRange, - smpAgentVRange :: VersionRange, - smpClientVRange :: VersionRange + e2eEncryptVRange :: VersionRangeE2E, + smpAgentVRange :: VersionRangeSMPA, + smpClientVRange :: VersionRangeSMPC } defaultReconnectInterval :: RetryInterval diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 6129b8503..b3dc929d8 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -4,6 +4,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} @@ -33,6 +34,9 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md module Simplex.Messaging.Agent.Protocol ( -- * Protocol parameters + VersionSMPA, + VersionRangeSMPA, + pattern VersionSMPA, ratchetSyncSMPAgentVersion, deliveryRcptsSMPAgentVersion, supportedSMPAgentVRange, @@ -175,11 +179,12 @@ import Data.Time.Clock.System (SystemTime) import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () -import Data.Word (Word32) +import Data.Word (Word16, Word32) import Database.SQLite.Simple.FromField import Database.SQLite.Simple.ToField import Simplex.FileTransfer.Description -import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType) +import Simplex.FileTransfer.Protocol (FileParty (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.Messaging.Agent.QueryString import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (E2ERatchetParams, E2ERatchetParamsUri) @@ -200,6 +205,10 @@ import Simplex.Messaging.Protocol SMPServerWithAuth, SndPublicAuthKey, SubscriptionMode, + SMPClientVersion, + VersionSMPC, + VersionRangeSMPC, + initialSMPClientVersion, legacyEncodeServer, legacyServerP, legacyStrEncodeServer, @@ -215,6 +224,7 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Simplex.RemoteControl.Types import Text.Read import UnliftIO.Exception (Exception) @@ -225,19 +235,30 @@ import UnliftIO.Exception (Exception) -- 3 - support ratchet renegotiation (6/30/2023) -- 4 - delivery receipts (7/13/2023) -duplexHandshakeSMPAgentVersion :: Version -duplexHandshakeSMPAgentVersion = 2 +data SMPAgentVersion -ratchetSyncSMPAgentVersion :: Version -ratchetSyncSMPAgentVersion = 3 +instance VersionScope SMPAgentVersion -deliveryRcptsSMPAgentVersion :: Version -deliveryRcptsSMPAgentVersion = 4 +type VersionSMPA = Version SMPAgentVersion -currentSMPAgentVersion :: Version -currentSMPAgentVersion = 4 +type VersionRangeSMPA = VersionRange SMPAgentVersion -supportedSMPAgentVRange :: VersionRange +pattern VersionSMPA :: Word16 -> VersionSMPA +pattern VersionSMPA v = Version v + +duplexHandshakeSMPAgentVersion :: VersionSMPA +duplexHandshakeSMPAgentVersion = VersionSMPA 2 + +ratchetSyncSMPAgentVersion :: VersionSMPA +ratchetSyncSMPAgentVersion = VersionSMPA 3 + +deliveryRcptsSMPAgentVersion :: VersionSMPA +deliveryRcptsSMPAgentVersion = VersionSMPA 4 + +currentSMPAgentVersion :: VersionSMPA +currentSMPAgentVersion = VersionSMPA 4 + +supportedSMPAgentVRange :: VersionRangeSMPA supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentSMPAgentVersion -- it is shorter to allow all handshake headers, @@ -665,7 +686,7 @@ instance StrEncoding SndQueueInfo where pure SndQueueInfo {sndServer, sndSwitchStatus} data ConnectionStats = ConnectionStats - { connAgentVersion :: Version, + { connAgentVersion :: VersionSMPA, rcvQueuesInfo :: [RcvQueueInfo], sndQueuesInfo :: [SndQueueInfo], ratchetSyncState :: RatchetSyncState, @@ -802,27 +823,27 @@ data SMPConfirmation = SMPConfirmation -- | optional reply queues included in confirmation (added in agent protocol v2) smpReplyQueues :: [SMPQueueInfo], -- | SMP client version - smpClientVersion :: Version + smpClientVersion :: VersionSMPC } deriving (Show) data AgentMsgEnvelope = AgentConfirmation - { agentVersion :: Version, + { agentVersion :: VersionSMPA, e2eEncryption_ :: Maybe (E2ERatchetParams 'C.X448), encConnInfo :: ByteString } | AgentMsgEnvelope - { agentVersion :: Version, + { agentVersion :: VersionSMPA, encAgentMessage :: ByteString } | AgentInvitation -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet, - { agentVersion :: Version, + { agentVersion :: VersionSMPA, connReq :: ConnectionRequestUri 'CMInvitation, connInfo :: ByteString -- this message is only encrypted with per-queue E2E, not with double ratchet, } | AgentRatchetKey - { agentVersion :: Version, + { agentVersion :: VersionSMPA, e2eEncryption :: E2ERatchetParams 'C.X448, info :: ByteString } @@ -1228,16 +1249,16 @@ sameQueue :: SMPQueue q => (SMPServer, SMP.QueueId) -> q -> Bool sameQueue addr q = sameQAddress addr (qAddress q) {-# INLINE sameQueue #-} -data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress} +data SMPQueueInfo = SMPQueueInfo {clientVersion :: VersionSMPC, queueAddress :: SMPQueueAddress} deriving (Eq, Show) instance Encoding SMPQueueInfo where smpEncode (SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey}) - | clientVersion > 1 = smpEncode (clientVersion, smpServer, senderId, dhPublicKey) + | clientVersion > initialSMPClientVersion = smpEncode (clientVersion, smpServer, senderId, dhPublicKey) | otherwise = smpEncode clientVersion <> legacyEncodeServer smpServer <> smpEncode (senderId, dhPublicKey) smpP = do clientVersion <- smpP - smpServer <- if clientVersion > 1 then smpP else updateSMPServerHosts <$> legacyServerP + smpServer <- if clientVersion > initialSMPClientVersion then smpP else updateSMPServerHosts <$> legacyServerP (senderId, dhPublicKey) <- smpP pure $ SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey} @@ -1245,20 +1266,20 @@ instance Encoding SMPQueueInfo where -- But this is created to allow backward and forward compatibility where SMPQueueUri -- could have more fields to convert to different versions of SMPQueueInfo in a different way, -- and this instance would become non-trivial. -instance VersionI SMPQueueInfo where - type VersionRangeT SMPQueueInfo = SMPQueueUri +instance VersionI SMPClientVersion SMPQueueInfo where + type VersionRangeT SMPClientVersion SMPQueueInfo = SMPQueueUri version = clientVersion toVersionRangeT (SMPQueueInfo _v addr) vr = SMPQueueUri vr addr -instance VersionRangeI SMPQueueUri where - type VersionT SMPQueueUri = SMPQueueInfo +instance VersionRangeI SMPClientVersion SMPQueueUri where + type VersionT SMPClientVersion SMPQueueUri = SMPQueueInfo versionRange = clientVRange toVersionT (SMPQueueUri _vr addr) v = SMPQueueInfo v addr -- | SMP queue information sent out-of-band. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#out-of-band-messages -data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRange, queueAddress :: SMPQueueAddress} +data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRangeSMPC, queueAddress :: SMPQueueAddress} deriving (Eq, Show) data SMPQueueAddress = SMPQueueAddress @@ -1307,7 +1328,7 @@ instance StrEncoding SMPQueueUri where smpServer = if maxVersion vr < srvHostnamesSMPClientVersion then updateSMPServerHosts srv' else srv' pure $ SMPQueueUri vr SMPQueueAddress {smpServer, senderId, dhPublicKey} where - unversioned = (versionToRange 1,[],) <$> strP <* A.endOfInput + unversioned = (versionToRange initialSMPClientVersion,[],) <$> strP <* A.endOfInput versioned = do dhKey_ <- optional strP query <- optional (A.char '/') *> A.char '?' *> strP @@ -1344,7 +1365,7 @@ deriving instance Show AConnectionRequestUri data ConnReqUriData = ConnReqUriData { crScheme :: ServiceScheme, - crAgentVRange :: VersionRange, + crAgentVRange :: VersionRangeSMPA, crSmpQueues :: NonEmpty SMPQueueUri, crClientData :: Maybe CRClientData } diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 8f67c74c2..5a91d0df4 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -44,10 +44,10 @@ import Simplex.Messaging.Protocol RcvPrivateAuthKey, SndPrivateAuthKey, SndPublicAuthKey, + VersionSMPC, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util ((<$?>)) -import Simplex.Messaging.Version -- * Queue types @@ -96,7 +96,7 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue dbReplaceQueueId :: Maybe Int64, rcvSwchStatus :: Maybe RcvSwitchStatus, -- | SMP client version - smpClientVersion :: Version, + smpClientVersion :: VersionSMPC, -- | credentials used in context of notifications clientNtfCreds :: Maybe ClientNtfCreds, deleteErrors :: Int @@ -159,7 +159,7 @@ data StoredSndQueue (q :: QueueStored) = SndQueue dbReplaceQueueId :: Maybe Int64, sndSwchStatus :: Maybe SndSwitchStatus, -- | SMP client version - smpClientVersion :: Version + smpClientVersion :: VersionSMPC } deriving (Eq, Show) @@ -315,7 +315,7 @@ deriving instance Show SomeConn data ConnData = ConnData { connId :: ConnId, userId :: UserId, - connAgentVersion :: Version, + connAgentVersion :: VersionSMPA, enableNtfs :: Bool, lastExternalSndId :: PrevExternalSndId, deleted :: Bool, diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 647e8bd03..0974e99a4 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -278,7 +278,7 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, ($>>=), (<$$>)) -import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.Exit (exitFailure) import System.FilePath (takeDirectory) @@ -701,7 +701,7 @@ setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = d |] (host, port, rcvId) -setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> Version -> IO () +setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO () setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion = DB.executeNamed db @@ -802,7 +802,7 @@ setRcvQueueNtfCreds db connId clientNtfCreds = Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) Nothing -> (Nothing, Nothing, Nothing, Nothing) -type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version) +type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC) smpConfirmation :: SMPConfirmationRow -> SMPConfirmation smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) = @@ -811,7 +811,7 @@ smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersi e2ePubKey, connInfo, smpReplyQueues = fromMaybe [] smpReplyQueues_, - smpClientVersion = fromMaybe 1 smpClientVersion_ + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ } createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId) @@ -888,7 +888,7 @@ removeConfirmations db connId = |] [":conn_id" := connId] -setConnectionVersion :: DB.Connection -> ConnId -> Version -> IO () +setConnectionVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnectionVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) @@ -1772,6 +1772,10 @@ instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEn instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 +instance ToField (Version v) where toField (Version v) = toField v + +instance FromField (Version v) where fromField f = Version <$> fromField f + listToEither :: e -> [a] -> Either e a listToEither _ (x : _) = Right x listToEither e _ = Left e @@ -1940,7 +1944,7 @@ setConnDeleted db waitDelivery connId | otherwise = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) -setConnAgentVersion :: DB.Connection -> ConnId -> Version -> IO () +setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnAgentVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) @@ -1999,12 +2003,12 @@ rcvQueueQuery = toRcvQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe Version, Int) + :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> RcvQueue toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = let server = SMPServer host port keyHash - smpClientVersion = fromMaybe 1 smpClientVersion_ + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} _ -> Nothing @@ -2040,7 +2044,7 @@ sndQueueQuery = toSndQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId) :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, Version) -> + :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> SndQueue toSndQueue ( (userId, keyHash, connId, host, port, sndId) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 2cbdced35..d8b202761 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -76,7 +76,7 @@ module Simplex.Messaging.Client PCTransmission, mkTransmission, authTransmission, - clientStub, + smpClientStub, ) where @@ -117,14 +117,14 @@ import System.Timeout (timeout) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- -- Use 'getSMPClient' to connect to an SMP server and create a client handle. -data ProtocolClient err msg = ProtocolClient +data ProtocolClient v err msg = ProtocolClient { action :: Maybe (Async ()), - thParams :: THandleParams, + thParams :: THandleParams v, sessionTs :: UTCTime, - client_ :: PClient err msg + client_ :: PClient v err msg } -data PClient err msg = PClient +data PClient v err msg = PClient { connected :: TVar Bool, transportSession :: TransportSession msg, transportHost :: TransportHost, @@ -135,11 +135,11 @@ data PClient err msg = PClient sentCommands :: TMap CorrId (Request err msg), sndQ :: TBQueue ByteString, rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)), - msgQ :: Maybe (TBQueue (ServerTransmission msg)) + msgQ :: Maybe (TBQueue (ServerTransmission v msg)) } -clientStub :: TVar ChaChaDRG -> ByteString -> Version -> Maybe THandleAuth -> STM (ProtocolClient err msg) -clientStub g sessionId thVersion thAuth = do +smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg) +smpClientStub g sessionId thVersion thAuth = do connected <- newTVar False clientCorrId <- C.newRandomDRG g sentCommands <- TM.empty @@ -174,13 +174,13 @@ clientStub g sessionId thVersion thAuth = do } } -type SMPClient = ProtocolClient ErrorType BrokerMsg +type SMPClient = ProtocolClient SMPVersion ErrorType BrokerMsg -- | Type for client command data type ClientCommand msg = (Maybe C.APrivateAuthKey, EntityId, ProtoCommand msg) -- | Type synonym for transmission from some SPM server queue. -type ServerTransmission msg = (TransportSession msg, Version, SessionId, EntityId, msg) +type ServerTransmission v msg = (TransportSession msg, Version v, SessionId, EntityId, msg) data HostMode = -- | prefer (or require) onion hosts when connecting via SOCKS proxy @@ -241,7 +241,7 @@ transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} = TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing} -- | protocol client configuration. -data ProtocolClientConfig = ProtocolClientConfig +data ProtocolClientConfig v = ProtocolClientConfig { -- | size of TBQueue to use for server commands and responses qSize :: Natural, -- | default server port if port is not specified in ProtocolServer @@ -249,13 +249,13 @@ data ProtocolClientConfig = ProtocolClientConfig -- | network configuration networkConfig :: NetworkConfig, -- | client-server protocol version range - serverVRange :: VersionRange, + serverVRange :: VersionRange v, -- | delay between sending batches of commands (microseconds) batchDelay :: Maybe Int } -- | Default protocol client configuration. -defaultClientConfig :: VersionRange -> ProtocolClientConfig +defaultClientConfig :: VersionRange v -> ProtocolClientConfig v defaultClientConfig serverVRange = ProtocolClientConfig { qSize = 64, @@ -265,7 +265,7 @@ defaultClientConfig serverVRange = batchDelay = Nothing } -defaultSMPClientConfig :: ProtocolClientConfig +defaultSMPClientConfig :: ProtocolClientConfig SMPVersion defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange data Request err msg = Request @@ -292,15 +292,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts onionHost = find isOnionHost hosts publicHost = find (not . isOnionHost) hosts -protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String +protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err msg -> String protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_ where snd3 (_, s, _) = s -transportHost' :: ProtocolClient err msg -> TransportHost +transportHost' :: ProtocolClient v err msg -> TransportHost transportHost' = transportHost . client_ -transportSession' :: ProtocolClient err msg -> TransportSession msg +transportSession' :: ProtocolClient v err msg -> TransportSession msg transportSession' = transportSession . client_ type UserId = Int64 @@ -313,7 +313,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId) -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall err msg. Protocol err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) +getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, serverVRange, batchDelay} msgQ disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> @@ -322,7 +322,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig - mkProtocolClient :: TransportHost -> STM (PClient err msg) + mkProtocolClient :: TransportHost -> STM (PClient v err msg) mkProtocolClient transportHost = do connected <- newTVar False pingErrorCount <- newTVar 0 @@ -345,7 +345,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize msgQ } - runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) + runClient :: (ServiceName, ATransport) -> TransportHost -> PClient v err msg -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) runClient (port', ATransport t) useHost c = do cVar <- newEmptyTMVarIO let tcConfig = transportClientConfig networkConfig @@ -366,10 +366,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize "80" -> ("80", transport @WS) p -> (p, transport @TLS) - client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO () + client :: forall c. Transport c => TProxy c -> PClient v err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient v err msg)) -> c -> IO () client _ c cVar h = do ks <- atomically $ C.generateKeyPair g - runExceptT (protocolClientHandshake @err @msg h ks (keyHash srv) serverVRange) >>= \case + runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange) >>= \case Left e -> atomically . putTMVar cVar . Left $ PCETransportError e Right th@THandle {params} -> do sessionTs <- getCurrentTime @@ -380,16 +380,16 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0]) `finally` disconnected c' - send :: Transport c => ProtocolClient err msg -> THandle c -> IO () + send :: Transport c => ProtocolClient v err msg -> THandle v c -> IO () send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPutLog h - receive :: Transport c => ProtocolClient err msg -> THandle c -> IO () + receive :: Transport c => ProtocolClient v err msg -> THandle v c -> IO () receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ - ping :: ProtocolClient err msg -> IO () + ping :: ProtocolClient v err msg -> IO () ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do threadDelay' smpPingInterval - runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case + runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @v @err @msg) >>= \case Left PCEResponseTimeout -> do cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1) when (maxCnt == 0 || cnt < maxCnt) $ ping c @@ -397,10 +397,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize where maxCnt = smpPingCount networkConfig - process :: ProtocolClient err msg -> IO () + process :: ProtocolClient v err msg -> IO () process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c) - processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO () + processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO () processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) = if B.null $ bs corrId then sendMsg respOrErr @@ -428,7 +428,7 @@ proxyUsername :: TransportSession msg -> ByteString proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_ -- | Disconnects client from the server and terminates client threads. -closeProtocolClient :: ProtocolClient err msg -> IO () +closeProtocolClient :: ProtocolClient v err msg -> IO () closeProtocolClient = mapM_ uninterruptibleCancel . action -- | SMP client error type. @@ -517,7 +517,7 @@ processSUBResponse c (Response rId r) = case r of writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c) -serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg +serverTransmission :: ProtocolClient v err msg -> RecipientId -> msg -> ServerTransmission v msg serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} entityId message = (transportSession, thVersion, sessionId, entityId, message) @@ -635,7 +635,7 @@ sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg) -- | Send multiple commands with batching and collect responses -sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) +sendProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs validate . concat =<< mapM (sendBatch c) bs @@ -652,12 +652,12 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz where diff = L.length cs - length rs -streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () +streamProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs mapM_ (cb <=< sendBatch c) bs -sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg] +sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg] sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do case b of TBError e Request {entityId} -> do @@ -673,7 +673,7 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do (: []) <$> getResponse c r -- | Send Protocol command -sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg +sendProtocolCommand :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} pKey entId cmd = ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd) where @@ -690,7 +690,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THand | otherwise = tEncode t -- TODO switch to timeout or TimeManager that supports Int64 -getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg) +getResponse :: ProtocolClient v err msg -> Request err msg -> IO (Response err msg) getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Request {entityId, responseVar} = do response <- timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case @@ -698,7 +698,7 @@ getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Requ Nothing -> pure $ Left PCEResponseTimeout pure Response {entityId, response} -mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> IO (PCTransmission err msg) +mkTransmission :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg) mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entId, cmd) = do corrId <- atomically getNextCorrId let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entId, cmd) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 068a52782..73f47648b 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -65,7 +65,7 @@ type SMPSub = (SMPSubParty, QueueId) -- type SMPServerSub = (SMPServer, SMPSub) data SMPClientAgentConfig = SMPClientAgentConfig - { smpCfg :: ProtocolClientConfig, + { smpCfg :: ProtocolClientConfig SMPVersion, reconnectInterval :: RetryInterval, msgQSize :: Natural, agentQSize :: Natural, @@ -91,7 +91,7 @@ defaultSMPClientAgentConfig = data SMPClientAgent = SMPClientAgent { agentCfg :: SMPClientAgentConfig, - msgQ :: TBQueue (ServerTransmission BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), agentQ :: TBQueue SMPClientAgentEvent, randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 0afa06db3..11cd40571 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -3,8 +3,10 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} @@ -12,15 +14,42 @@ {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -module Simplex.Messaging.Crypto.Ratchet where +module Simplex.Messaging.Crypto.Ratchet + ( Ratchet (..), + RatchetX448, + SkippedMsgDiff (..), + SkippedMsgKeys, + E2ERatchetParamsUri (..), + E2ERatchetParams (..), + VersionE2E, + VersionRangeE2E, + pattern VersionE2E, + currentE2EEncryptVersion, + supportedE2EEncryptVRange, + generateE2EParams, + x3dhSnd, + x3dhRcv, + initSndRatchet, + initRcvRatchet, + rcEncrypt, + rcDecrypt, + -- used in tests + MsgHeader (..), + RatchetVersions, + ratchetVersions, + fullHeaderLen, + applySMDiff, + ) +where +import Control.Applicative ((<|>)) import Control.Monad.Except import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import Crypto.Hash (SHA512) import qualified Crypto.KDF.HKDF as H import Crypto.Random (ChaChaDRG) -import Data.Aeson (FromJSON (..), ToJSON (..)) +import Data.Aeson (FromJSON (..), ToJSON (..), (.:)) import qualified Data.Aeson as J import qualified Data.Aeson.TH as JQ import Data.ByteString.Char8 (ByteString) @@ -31,7 +60,7 @@ import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe) import Data.Typeable (Typeable) -import Data.Word (Word32) +import Data.Word (Word16, Word32) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.QueryString @@ -40,41 +69,53 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE') import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import UnliftIO.STM -- e2e encryption headers version history: -- 1 - binary protocol encoding (1/1/2022) -- 2 - use KDF in x3dh (10/20/2022) -kdfX3DHE2EEncryptVersion :: Version -kdfX3DHE2EEncryptVersion = 2 +data E2EVersion -currentE2EEncryptVersion :: Version -currentE2EEncryptVersion = 2 +instance VersionScope E2EVersion -supportedE2EEncryptVRange :: VersionRange +type VersionE2E = Version E2EVersion + +type VersionRangeE2E = VersionRange E2EVersion + +pattern VersionE2E :: Word16 -> VersionE2E +pattern VersionE2E v = Version v + +kdfX3DHE2EEncryptVersion :: VersionE2E +kdfX3DHE2EEncryptVersion = VersionE2E 2 + +currentE2EEncryptVersion :: VersionE2E +currentE2EEncryptVersion = VersionE2E 2 + +supportedE2EEncryptVRange :: VersionRangeE2E supportedE2EEncryptVRange = mkVersionRange kdfX3DHE2EEncryptVersion currentE2EEncryptVersion data E2ERatchetParams (a :: Algorithm) - = E2ERatchetParams Version (PublicKey a) (PublicKey a) + = E2ERatchetParams VersionE2E (PublicKey a) (PublicKey a) deriving (Eq, Show) instance AlgorithmI a => Encoding (E2ERatchetParams a) where smpEncode (E2ERatchetParams v k1 k2) = smpEncode (v, k1, k2) smpP = E2ERatchetParams <$> smpP <*> smpP <*> smpP -instance VersionI (E2ERatchetParams a) where - type VersionRangeT (E2ERatchetParams a) = E2ERatchetParamsUri a +instance VersionI E2EVersion (E2ERatchetParams a) where + type VersionRangeT E2EVersion (E2ERatchetParams a) = E2ERatchetParamsUri a version (E2ERatchetParams v _ _) = v toVersionRangeT (E2ERatchetParams _ k1 k2) vr = E2ERatchetParamsUri vr k1 k2 -instance VersionRangeI (E2ERatchetParamsUri a) where - type VersionT (E2ERatchetParamsUri a) = (E2ERatchetParams a) +instance VersionRangeI E2EVersion (E2ERatchetParamsUri a) where + type VersionT E2EVersion (E2ERatchetParamsUri a) = (E2ERatchetParams a) versionRange (E2ERatchetParamsUri vr _ _) = vr toVersionT (E2ERatchetParamsUri _ k1 k2) v = E2ERatchetParams v k1 k2 data E2ERatchetParamsUri (a :: Algorithm) - = E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) + = E2ERatchetParamsUri VersionRangeE2E (PublicKey a) (PublicKey a) deriving (Eq, Show) instance AlgorithmI a => StrEncoding (E2ERatchetParamsUri a) where @@ -89,7 +130,7 @@ instance AlgorithmI a => StrEncoding (E2ERatchetParamsUri a) where [key1, key2] -> pure $ E2ERatchetParamsUri vs key1 key2 _ -> fail "bad e2e params" -generateE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> STM (PrivateKey a, PrivateKey a, E2ERatchetParams a) +generateE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> STM (PrivateKey a, PrivateKey a, E2ERatchetParams a) generateE2EParams g v = do (k1, pk1) <- generateKeyPair g (k2, pk2) <- generateKeyPair g @@ -125,7 +166,7 @@ type RatchetX448 = Ratchet 'X448 data Ratchet a = Ratchet { -- ratchet version range sent in messages (current .. max supported ratchet version) - rcVersion :: VersionRange, + rcVersion :: RatchetVersions, -- associated data - must be the same in both parties ratchets rcAD :: Str, rcDHRs :: PrivateKey a, @@ -140,6 +181,29 @@ data Ratchet a = Ratchet } deriving (Eq, Show) +data RatchetVersions = RVersions + { current :: VersionE2E, + maxSupported :: VersionE2E + } + deriving (Eq, Show) + +instance ToJSON RatchetVersions where + -- TODO v5.7 or v5.8 change to the default record encoding + toJSON (RVersions v1 v2) = toJSON (v1, v2) + toEncoding (RVersions v1 v2) = toEncoding (v1, v2) + +instance FromJSON RatchetVersions where + -- TODO v6.0 replace with the default record parser + -- this parser supports JSON record encoding for forward compatibility + parseJSON v = (tupleP <|> recordP v) >>= toRV + where + tupleP = parseJSON v + recordP = J.withObject "RatchetVersions" $ \o -> ((,) <$> o .: "current" <*> o .: "maxSupported") + toRV (v1, v2) = maybe (fail "bad version range") (pure . ratchetVersions) $ safeVersionRange v1 v2 + +ratchetVersions :: VersionRangeE2E -> RatchetVersions +ratchetVersions (VersionRange v1 v2) = RVersions {current = v1, maxSupported = v2} + data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, rcCKs :: RatchetKey, @@ -207,12 +271,12 @@ instance FromField MessageKey where fromField = blobFieldDecoder smpDecode -- Please note that sPKey is not stored, and its public part together with random salt -- is sent to the recipient. initSndRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PublicKey a -> PrivateKey a -> RatchetInitParams -> Ratchet a -initSndRatchet rcVersion rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = do + forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PublicKey a -> PrivateKey a -> RatchetInitParams -> Ratchet a +initSndRatchet v rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = do -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr)) let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs in Ratchet - { rcVersion, + { rcVersion = ratchetVersions v, rcAD = assocData, rcDHRs, rcRK, @@ -230,10 +294,10 @@ initSndRatchet rcVersion rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, -- Please note that the public part of rcDHRs was sent to the sender -- as part of the connection request and random salt was received from the sender. initRcvRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PrivateKey a -> RatchetInitParams -> Ratchet a -initRcvRatchet rcVersion rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = + forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PrivateKey a -> RatchetInitParams -> Ratchet a +initRcvRatchet v rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = Ratchet - { rcVersion, + { rcVersion = ratchetVersions v, rcAD = assocData, rcDHRs, rcRK = ratchetKey, @@ -248,18 +312,13 @@ initRcvRatchet rcVersion rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, data MsgHeader a = MsgHeader { -- | max supported ratchet version - msgMaxVersion :: Version, + msgMaxVersion :: VersionE2E, msgDHRs :: PublicKey a, msgPN :: Word32, msgNs :: Word32 } deriving (Eq, Show) -data AMsgHeader - = forall a. - (AlgorithmI a, DhAlgorithm a) => - AMsgHeader (SAlgorithm a) (MsgHeader a) - -- to allow extension without increasing the size, the actual header length is: -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 paddedHeaderLen :: Int @@ -281,7 +340,7 @@ instance AlgorithmI a => Encoding (MsgHeader a) where pure MsgHeader {msgMaxVersion, msgDHRs, msgPN, msgNs} data EncMessageHeader = EncMessageHeader - { ehVersion :: Version, + { ehVersion :: VersionE2E, ehIV :: IV, ehAuthTag :: AuthTag, ehBody :: ByteString @@ -315,7 +374,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcNs, r -- enc_header = HENCRYPT(state.HKs, header) (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV paddedHeaderLen rcAD msgHeader -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) - let emHeader = smpEncode EncMessageHeader {ehVersion = minVersion rcVersion, ehBody, ehAuthTag, ehIV} + let emHeader = smpEncode EncMessageHeader {ehVersion = current rcVersion, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg let msg' = smpEncode EncRatchetMessage {emHeader, emBody, emAuthTag} -- state.Ns += 1 @@ -326,7 +385,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcNs, r msgHeader = smpEncode MsgHeader - { msgMaxVersion = maxVersion rcVersion, + { msgMaxVersion = maxSupported rcVersion, msgDHRs = publicKey rcDHRs, msgPN = rcPN, msgNs = rcNs diff --git a/src/Simplex/Messaging/Notifications/Client.hs b/src/Simplex/Messaging/Notifications/Client.hs index d69114b68..72a92c278 100644 --- a/src/Simplex/Messaging/Notifications/Client.hs +++ b/src/Simplex/Messaging/Notifications/Client.hs @@ -10,15 +10,15 @@ import Data.Word (Word16) import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Notifications.Transport (supportedClientNTFVRange) +import Simplex.Messaging.Notifications.Transport (NTFVersion, supportedClientNTFVRange) import Simplex.Messaging.Protocol (ErrorType) import Simplex.Messaging.Util (bshow) -type NtfClient = ProtocolClient ErrorType NtfResponse +type NtfClient = ProtocolClient NTFVersion ErrorType NtfResponse type NtfClientError = ProtocolClientError ErrorType -defaultNTFClientConfig :: ProtocolClientConfig +defaultNTFClientConfig :: ProtocolClientConfig NTFVersion defaultNTFClientConfig = defaultClientConfig supportedClientNTFVRange ntfRegisterToken :: NtfClient -> C.APrivateAuthKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519) diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 73c2dada6..943c30c5a 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -28,7 +28,7 @@ import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (ntfClientHandshake) +import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake) import Simplex.Messaging.Parsers (fromTextField_) import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..)) import Simplex.Messaging.Util (eitherToMaybe, (<$?>)) @@ -147,7 +147,7 @@ instance Encoding ANewNtfEntity where 'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP) _ -> fail "bad ANewNtfEntity" -instance Protocol ErrorType NtfResponse where +instance Protocol NTFVersion ErrorType NtfResponse where type ProtoCommand NtfResponse = NtfCmd type ProtoType NtfResponse = 'PNTF protocolClientHandshake = ntfClientHandshake @@ -184,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e) deriving instance Show NtfCmd -instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where +instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) where type Tag (NtfCommand e) = NtfCommandTag e encodeProtocol _v = \case TNEW newTkn -> e (TNEW_, ' ', newTkn) @@ -203,7 +203,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag) - fromProtocolError = fromProtocolError @ErrorType @NtfResponse + fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse {-# INLINE fromProtocolError #-} checkCredentials (auth, _, entityId, _) cmd = case cmd of @@ -223,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where | not (B.null entityId) = Left $ CMD HAS_AUTH | otherwise = Right cmd -instance ProtocolEncoding ErrorType NtfCmd where +instance ProtocolEncoding NTFVersion ErrorType NtfCmd where type Tag NtfCmd = NtfCmdTag encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c @@ -243,7 +243,7 @@ instance ProtocolEncoding ErrorType NtfCmd where SDEL_ -> pure SDEL PING_ -> pure PING - fromProtocolError = fromProtocolError @ErrorType @NtfResponse + fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse {-# INLINE fromProtocolError #-} checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c @@ -290,7 +290,7 @@ data NtfResponse | NRPong deriving (Show) -instance ProtocolEncoding ErrorType NtfResponse where +instance ProtocolEncoding NTFVersion ErrorType NtfResponse where type Tag NtfResponse = NtfResponseTag encodeProtocol _v = \case NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey) diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 754aa6d62..7ae657fd1 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -338,7 +338,7 @@ updateTknStatus NtfTknData {ntfTknId, tknStatus} status = do old <- atomically $ stateTVar tknStatus (,status) when (old /= status) $ withNtfLog $ \sl -> logTokenStatus sl ntfTknId status -runNtfClientTransport :: Transport c => THandle c -> M () +runNtfClientTransport :: Transport c => THandleNTF c -> M () runNtfClientTransport th@THandle {params} = do qSize <- asks $ clientQSize . config ts <- liftIO getSystemTime @@ -355,7 +355,7 @@ runNtfClientTransport th@THandle {params} = do clientDisconnected :: NtfServerClient -> IO () clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False -receive :: Transport c => THandle c -> NtfServerClient -> M () +receive :: Transport c => THandleNTF c -> NtfServerClient -> M () receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do ts <- liftIO $ tGet th forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do @@ -370,7 +370,7 @@ receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ where write q t = atomically $ writeTBQueue q t -send :: Transport c => THandle c -> NtfServerClient -> IO () +send :: Transport c => THandleNTF c -> NtfServerClient -> IO () send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do t <- atomically $ readTBQueue sndQ void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)] diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 9e3013a8d..0d722dcc3 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -24,6 +24,7 @@ import Numeric.Natural import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF) import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Stats import Simplex.Messaging.Notifications.Server.Store @@ -34,7 +35,6 @@ import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport, THandleParams) import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams) -import Simplex.Messaging.Version (VersionRange) import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -60,7 +60,7 @@ data NtfServerConfig = NtfServerConfig logStatsStartTime :: Int64, serverStatsLogFile :: FilePath, serverStatsBackupFile :: Maybe FilePath, - ntfServerVRange :: VersionRange, + ntfServerVRange :: VersionRangeNTF, transportConfig :: TransportServerConfig } @@ -161,13 +161,13 @@ data NtfRequest data NtfServerClient = NtfServerClient { rcvQ :: TBQueue NtfRequest, sndQ :: TBQueue (Transmission NtfResponse), - ntfThParams :: THandleParams, + ntfThParams :: THandleParams NTFVersion, connected :: TVar Bool, rcvActiveAt :: TVar SystemTime, sndActiveAt :: TVar SystemTime } -newNtfServerClient :: Natural -> THandleParams -> SystemTime -> STM NtfServerClient +newNtfServerClient :: Natural -> THandleParams NTFVersion -> SystemTime -> STM NtfServerClient newNtfServerClient qSize ntfThParams ts = do rcvQ <- newTBQueue qSize sndQ <- newTBQueue qSize diff --git a/src/Simplex/Messaging/Notifications/Transport.hs b/src/Simplex/Messaging/Notifications/Transport.hs index 00fd811a2..bc68fab03 100644 --- a/src/Simplex/Messaging/Notifications/Transport.hs +++ b/src/Simplex/Messaging/Notifications/Transport.hs @@ -3,6 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Notifications.Transport where @@ -12,33 +13,51 @@ import Control.Monad.Except import Data.Attoparsec.ByteString.Char8 (Parser) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Word (Word16) import qualified Data.X509 as X import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Transport import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Simplex.Messaging.Util (liftEitherWith) ntfBlockSize :: Int ntfBlockSize = 512 -authBatchCmdsNTFVersion :: Version -authBatchCmdsNTFVersion = 2 +data NTFVersion -currentClientNTFVersion :: Version -currentClientNTFVersion = 1 +instance VersionScope NTFVersion -currentServerNTFVersion :: Version -currentServerNTFVersion = 1 +type VersionNTF = Version NTFVersion -supportedClientNTFVRange :: VersionRange -supportedClientNTFVRange = mkVersionRange 1 currentClientNTFVersion +type VersionRangeNTF = VersionRange NTFVersion -supportedServerNTFVRange :: VersionRange -supportedServerNTFVRange = mkVersionRange 1 currentServerNTFVersion +pattern VersionNTF :: Word16 -> VersionNTF +pattern VersionNTF v = Version v + +initialNTFVersion :: VersionNTF +initialNTFVersion = VersionNTF 1 + +authBatchCmdsNTFVersion :: VersionNTF +authBatchCmdsNTFVersion = VersionNTF 2 + +currentClientNTFVersion :: VersionNTF +currentClientNTFVersion = VersionNTF 1 + +currentServerNTFVersion :: VersionNTF +currentServerNTFVersion = VersionNTF 1 + +supportedClientNTFVRange :: VersionRangeNTF +supportedClientNTFVRange = mkVersionRange initialNTFVersion currentClientNTFVersion + +supportedServerNTFVRange :: VersionRangeNTF +supportedServerNTFVRange = mkVersionRange initialNTFVersion currentServerNTFVersion + +type THandleNTF c = THandle NTFVersion c data NtfServerHandshake = NtfServerHandshake - { ntfVersionRange :: VersionRange, + { ntfVersionRange :: VersionRangeNTF, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. authPubKey :: Maybe (X.SignedExact X.PubKey) @@ -46,7 +65,7 @@ data NtfServerHandshake = NtfServerHandshake data NtfClientHandshake = NtfClientHandshake { -- | agreed SMP notifications server protocol version - ntfVersion :: Version, + ntfVersion :: VersionNTF, -- | server identity - CA certificate fingerprint keyHash :: C.KeyHash, -- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. @@ -66,12 +85,12 @@ instance Encoding NtfServerHandshake where authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey} -encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds :: Encoding a => VersionNTF -> Maybe a -> ByteString encodeAuthEncryptCmds v k | v >= authBatchCmdsNTFVersion = maybe "" smpEncode k | otherwise = "" -authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP :: VersionNTF -> Parser a -> Parser (Maybe a) authEncryptCmdsP v p = if v >= authBatchCmdsNTFVersion then Just <$> p else pure Nothing instance Encoding NtfClientHandshake where @@ -83,16 +102,16 @@ instance Encoding NtfClientHandshake where authPubKey <- ntfAuthPubKeyP ntfVersion pure NtfClientHandshake {ntfVersion, keyHash, authPubKey} -ntfAuthPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519) +ntfAuthPubKeyP :: VersionNTF -> Parser (Maybe C.PublicKeyX25519) ntfAuthPubKeyP v = if v >= authBatchCmdsNTFVersion then Just <$> smpP else pure Nothing -encodeNtfAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString +encodeNtfAuthPubKey :: VersionNTF -> Maybe C.PublicKeyX25519 -> ByteString encodeNtfAuthPubKey v k | v >= authBatchCmdsNTFVersion = maybe "" smpEncode k | otherwise = "" -- | Notifcations server transport handshake. -ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c) ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c let sk = C.signX509 serverSignKey $ C.publicToX509 k @@ -106,7 +125,7 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do | otherwise -> throwError $ TEHandshake VERSION -- | Notifcations server client transport handshake. -ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c) ntfClientHandshake c (k, pk) keyHash ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th @@ -122,15 +141,15 @@ ntfClientHandshake c (k, pk) keyHash ntfVRange = do pure $ ntfThHandle th v pk sk_ Nothing -> throwError $ TEHandshake VERSION -ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c +ntfThHandle :: forall c. THandleNTF c -> VersionNTF -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleNTF c ntfThHandle th@THandle {params} v privKey k_ = -- TODO drop SMP v6: make thAuth non-optional let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_ v3 = v >= authBatchCmdsNTFVersion params' = params {thVersion = v, thAuth, implySessId = v3, batch = v3} - in (th :: THandle c) {params = params'} + in (th :: THandleNTF c) {params = params'} -ntfTHandle :: Transport c => c -> THandle c +ntfTHandle :: Transport c => c -> THandleNTF c ntfTHandle c = THandle {connection = c, params} where - params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = False} + params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = VersionNTF 0, thAuth = Nothing, implySessId = False, batch = False} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 315a4e5a3..af7704acb 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -46,6 +46,10 @@ module Simplex.Messaging.Protocol e2eEncMessageLength, -- * SMP protocol types + SMPClientVersion, + VersionSMPC, + VersionRangeSMPC, + pattern VersionSMPC, ProtocolEncoding (..), Command (..), SubscriptionMode (..), @@ -117,6 +121,7 @@ module Simplex.Messaging.Protocol SMPMsgMeta (..), NMsgMeta (..), MsgFlags (..), + initialSMPClientVersion, userProtocol, rcvMessageMeta, noMsgFlags, @@ -179,6 +184,7 @@ import Data.Maybe (isJust, isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality +import Data.Word (Word16) import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) import Network.Socket (ServiceName) import qualified Simplex.Messaging.Crypto as C @@ -190,19 +196,34 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..)) import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>)) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal -- SMP client protocol version history: -- 1 - binary protocol encoding (1/1/2022) -- 2 - multiple server hostnames and versioned queue addresses (8/12/2022) -srvHostnamesSMPClientVersion :: Version -srvHostnamesSMPClientVersion = 2 +data SMPClientVersion -currentSMPClientVersion :: Version -currentSMPClientVersion = 2 +instance VersionScope SMPClientVersion -supportedSMPClientVRange :: VersionRange -supportedSMPClientVRange = mkVersionRange 1 currentSMPClientVersion +type VersionSMPC = Version SMPClientVersion + +type VersionRangeSMPC = VersionRange SMPClientVersion + +pattern VersionSMPC :: Word16 -> VersionSMPC +pattern VersionSMPC v = Version v + +initialSMPClientVersion :: VersionSMPC +initialSMPClientVersion = VersionSMPC 1 + +srvHostnamesSMPClientVersion :: VersionSMPC +srvHostnamesSMPClientVersion = VersionSMPC 2 + +currentSMPClientVersion :: VersionSMPC +currentSMPClientVersion = VersionSMPC 2 + +supportedSMPClientVRange :: VersionRangeSMPC +supportedSMPClientVRange = mkVersionRange initialSMPClientVersion currentSMPClientVersion maxMessageLength :: Int maxMessageLength = 16088 @@ -644,7 +665,7 @@ data ClientMsgEnvelope = ClientMsgEnvelope deriving (Show) data PubHeader = PubHeader - { phVersion :: Version, + { phVersion :: VersionSMPC, phE2ePubDhKey :: Maybe C.PublicKeyX25519 } deriving (Show) @@ -1053,7 +1074,7 @@ data CommandError deriving (Eq, Read, Show) -- | SMP transmission parser. -transmissionP :: THandleParams -> Parser RawTransmission +transmissionP :: THandleParams v -> Parser RawTransmission transmissionP THandleParams {sessionId, implySessId} = do authenticator <- smpP authorized <- A.takeByteString @@ -1067,16 +1088,16 @@ transmissionP THandleParams {sessionId, implySessId} = do command <- A.takeByteString pure RawTransmission {authenticator, authorized = authorized', sessId, corrId, entityId, command} -class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where +class (ProtocolEncoding v err msg, ProtocolEncoding v err (ProtoCommand msg), Show err, Show msg) => Protocol v err msg | msg -> v, msg -> err where type ProtoCommand msg = cmd | cmd -> msg type ProtoType msg = (sch :: ProtocolType) | sch -> msg - protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) + protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> ExceptT TransportError IO (THandle v c) protocolPing :: ProtoCommand msg protocolError :: msg -> Maybe err type ProtoServer msg = ProtocolServer (ProtoType msg) -instance Protocol ErrorType BrokerMsg where +instance Protocol SMPVersion ErrorType BrokerMsg where type ProtoCommand BrokerMsg = Cmd type ProtoType BrokerMsg = 'PSMP protocolClientHandshake = smpClientHandshake @@ -1085,14 +1106,14 @@ instance Protocol ErrorType BrokerMsg where ERR e -> Just e _ -> Nothing -class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where +class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -> v where type Tag msg - encodeProtocol :: Version -> msg -> ByteString - protocolP :: Version -> Tag msg -> Parser msg + encodeProtocol :: Version v -> msg -> ByteString + protocolP :: Version v -> Tag msg -> Parser msg fromProtocolError :: ProtocolErrorType -> err checkCredentials :: SignedRawTransmission -> msg -> Either err msg -instance PartyI p => ProtocolEncoding ErrorType (Command p) where +instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where type Tag (Command p) = CommandTag p encodeProtocol v = \case NEW rKey dhKey auth_ subMode @@ -1119,7 +1140,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag) - fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} checkCredentials (auth, _, queueId, _) cmd = case cmd of @@ -1141,7 +1162,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where | isNothing auth || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance ProtocolEncoding ErrorType Cmd where +instance ProtocolEncoding SMPVersion ErrorType Cmd where type Tag Cmd = CmdTag encodeProtocol v (Cmd _ c) = encodeProtocol v c @@ -1169,12 +1190,12 @@ instance ProtocolEncoding ErrorType Cmd where PING_ -> pure PING CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB - fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c -instance ProtocolEncoding ErrorType BrokerMsg where +instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where type Tag BrokerMsg = BrokerMsgTag encodeProtocol _v = \case IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh) @@ -1226,12 +1247,12 @@ instance ProtocolEncoding ErrorType BrokerMsg where | otherwise -> Right cmd -- | Parse SMP protocol commands and broker messages -parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg +parseProtocol :: forall v err msg. ProtocolEncoding v err msg => Version v -> ByteString -> Either err msg parseProtocol v s = let (tag, params) = B.break (== ' ') s in case decodeTag tag of - Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params - Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown + Just cmd -> parse (protocolP v cmd) (fromProtocolError @v @err @msg $ PECmdSyntax) params + Nothing -> Left $ fromProtocolError @v @err @msg $ PECmdUnknown checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p) checkParty c = case testEquality (sParty @p) (sParty @p') of @@ -1286,7 +1307,7 @@ instance Encoding CommandError where _ -> fail "bad command error type" -- | Send signed SMP transmission to TCP transport. -tPut :: Transport c => THandle c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] +tPut :: Transport c => THandle v c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (batch params) (blockSize params) where tPutBatch :: TransportBatch () -> IO [Either TransportError ()] @@ -1295,7 +1316,7 @@ tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (ba TBTransmissions s n _ -> replicate n <$> tPutLog th s TBTransmission s _ -> (: []) <$> tPutLog th s -tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutLog :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ()) tPutLog th s = do r <- tPutBlock th s case r of @@ -1357,7 +1378,7 @@ tEncodeBatch1 t = lenEncode 1 `B.cons` tEncodeForBatch t -- tForAuth is lazy to avoid computing it when there is no key to sign data TransmissionForAuth = TransmissionForAuth {tForAuth :: ~ByteString, tToSend :: ByteString} -encodeTransmissionForAuth :: ProtocolEncoding e c => THandleParams -> Transmission c -> TransmissionForAuth +encodeTransmissionForAuth :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> TransmissionForAuth encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t = TransmissionForAuth {tForAuth, tToSend = if implySessId then t' else tForAuth} where @@ -1365,24 +1386,24 @@ encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t' = encodeTransmission_ v t {-# INLINE encodeTransmissionForAuth #-} -encodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> ByteString +encodeTransmission :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> ByteString encodeTransmission THandleParams {thVersion = v, sessionId, implySessId} t = if implySessId then t' else smpEncode sessionId <> t' where t' = encodeTransmission_ v t {-# INLINE encodeTransmission #-} -encodeTransmission_ :: ProtocolEncoding e c => Version -> Transmission c -> ByteString +encodeTransmission_ :: ProtocolEncoding v e c => Version v -> Transmission c -> ByteString encodeTransmission_ v (CorrId corrId, queueId, command) = smpEncode (corrId, queueId) <> encodeProtocol v command {-# INLINE encodeTransmission_ #-} -- | Receive and parse transmission from the TCP transport (ignoring any trailing padding). -tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission)) +tGetParse :: Transport c => THandle v c -> IO (NonEmpty (Either TransportError RawTransmission)) tGetParse th@THandle {params} = eitherList (tParse params) <$> tGetBlock th {-# INLINE tGetParse #-} -tParse :: THandleParams -> ByteString -> NonEmpty (Either TransportError RawTransmission) +tParse :: THandleParams v -> ByteString -> NonEmpty (Either TransportError RawTransmission) tParse thParams@THandleParams {batch} s | batch = eitherList (L.map (\(Large t) -> tParse1 t)) ts | otherwise = [tParse1 s] @@ -1394,24 +1415,24 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b eitherList = either (\e -> [Left e]) -- | Receive client and server transmissions (determined by `cmd` type). -tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd)) +tGet :: forall v err cmd c. (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (NonEmpty (SignedTransmission err cmd)) tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th -tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => THandleParams -> Either TransportError RawTransmission -> SignedTransmission err cmd +tDecodeParseValidate :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v -> Either TransportError RawTransmission -> SignedTransmission err cmd tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case Right RawTransmission {authenticator, authorized, sessId, corrId, entityId, command} | implySessId || sessId == sessionId -> let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission - | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession)) + | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PESession)) Left _ -> tError "" where tError :: ByteString -> SignedTransmission err cmd - tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock)) + tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PEBlock)) tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd tParseValidate signed t@(sig, corrId, entityId, command) = - let cmd = parseProtocol @err @cmd v command >>= checkCredentials t + let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t in (sig, signed, (CorrId corrId, entityId, cmd)) $(J.deriveJSON defaultJSON ''MsgFlags) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index aaa42d91b..0dcde0350 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -380,7 +380,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do CPQuit -> pure () CPSkip -> pure () -runClientTransport :: Transport c => THandle c -> M () +runClientTransport :: Transport c => THandleSMP c -> M () runClientTransport th@THandle {params = THandleParams {thVersion, sessionId}} = do q <- asks $ tbqSize . config ts <- liftIO getSystemTime @@ -428,7 +428,7 @@ cancelSub sub = Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread _ -> return () -receive :: Transport c => THandle c -> Client -> M () +receive :: Transport c => THandleSMP c -> Client -> M () receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive" forever $ do @@ -449,7 +449,7 @@ receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActi VRFailed -> Left (corrId, queueId, ERR AUTH) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty -send :: Transport c => THandle c -> Client -> IO () +send :: Transport c => THandleSMP c -> Client -> IO () send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send" forever $ do @@ -464,7 +464,7 @@ send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do NMSG {} -> 0 _ -> 1 -disconnectTransport :: Transport c => THandle c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () +disconnectTransport :: Transport c => THandle v c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport" loop diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 82666a0fc..7b9fef0b3 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -33,9 +33,8 @@ import Simplex.Messaging.Server.Stats import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (ATransport) +import Simplex.Messaging.Transport (ATransport, VersionSMP, VersionRangeSMP) import Simplex.Messaging.Transport.Server (SocketState, TransportServerConfig, loadFingerprint, loadTLSServerParams, newSocketState) -import Simplex.Messaging.Version import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -73,7 +72,7 @@ data ServerConfig = ServerConfig privateKeyFile :: FilePath, certificateFile :: FilePath, -- | SMP client-server protocol version range - smpServerVRange :: VersionRange, + smpServerVRange :: VersionRangeSMP, -- | TCP transport config transportConfig :: TransportServerConfig, -- | run listener on control port @@ -128,7 +127,7 @@ data Client = Client sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), endThreads :: TVar (IntMap (Weak ThreadId)), endThreadSeq :: TVar Int, - thVersion :: Version, + thVersion :: VersionSMP, sessionId :: ByteString, connected :: TVar Bool, createdAt :: SystemTime, @@ -152,7 +151,7 @@ newServer = do savingLock <- createLock return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock} -newClient :: TVar Int -> Natural -> Version -> ByteString -> SystemTime -> STM Client +newClient :: TVar Int -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client newClient nextClientId qSize thVersion sessionId createdAt = do clientId <- stateTVar nextClientId $ \next -> (next, next + 1) subscriptions <- TM.empty diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 0d9552f9b..775400260 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -9,6 +9,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} @@ -27,10 +28,15 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a module Simplex.Messaging.Transport ( -- * SMP transport parameters + SMPVersion, + VersionSMP, + VersionRangeSMP, + THandleSMP, supportedClientSMPRelayVRange, supportedServerSMPRelayVRange, currentClientSMPRelayVersion, currentServerSMPRelayVersion, + batchCmdsSMPVersion, basicAuthSMPVersion, subModeSMPVersion, authCmdsSMPVersion, @@ -85,6 +91,7 @@ import qualified Data.ByteString.Lazy.Char8 as LB import Data.Default (def) import Data.Functor (($>)) import Data.Version (showVersion) +import Data.Word (Word16) import qualified Data.X509 as X import qualified Data.X509.Validation as XV import GHC.IO.Handle.Internals (ioe_EOF) @@ -98,6 +105,7 @@ import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON) import Simplex.Messaging.Transport.Buffer import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -116,30 +124,41 @@ smpBlockSize = 16384 -- 6 - allow creating queues without subscribing (9/10/2023) -- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (2/3/2024) -batchCmdsSMPVersion :: Version -batchCmdsSMPVersion = 4 +data SMPVersion -basicAuthSMPVersion :: Version -basicAuthSMPVersion = 5 +instance VersionScope SMPVersion -subModeSMPVersion :: Version -subModeSMPVersion = 6 +type VersionSMP = Version SMPVersion -authCmdsSMPVersion :: Version -authCmdsSMPVersion = 7 +type VersionRangeSMP = VersionRange SMPVersion -currentClientSMPRelayVersion :: Version -currentClientSMPRelayVersion = 6 +pattern VersionSMP :: Word16 -> VersionSMP +pattern VersionSMP v = Version v -currentServerSMPRelayVersion :: Version -currentServerSMPRelayVersion = 6 +batchCmdsSMPVersion :: VersionSMP +batchCmdsSMPVersion = VersionSMP 4 + +basicAuthSMPVersion :: VersionSMP +basicAuthSMPVersion = VersionSMP 5 + +subModeSMPVersion :: VersionSMP +subModeSMPVersion = VersionSMP 6 + +authCmdsSMPVersion :: VersionSMP +authCmdsSMPVersion = VersionSMP 7 + +currentClientSMPRelayVersion :: VersionSMP +currentClientSMPRelayVersion = VersionSMP 6 + +currentServerSMPRelayVersion :: VersionSMP +currentServerSMPRelayVersion = VersionSMP 6 -- minimal supported protocol version is 4 -- TODO remove code that supports sending commands without batching -supportedClientSMPRelayVRange :: VersionRange +supportedClientSMPRelayVRange :: VersionRangeSMP supportedClientSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentClientSMPRelayVersion -supportedServerSMPRelayVRange :: VersionRange +supportedServerSMPRelayVRange :: VersionRangeSMP supportedServerSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentServerSMPRelayVersion simplexMQVersion :: String @@ -287,16 +306,18 @@ instance Transport TLS where -- * SMP transport -- | The handle for SMP encrypted transport connection over Transport. -data THandle c = THandle +data THandle v c = THandle { connection :: c, - params :: THandleParams + params :: THandleParams v } -data THandleParams = THandleParams +type THandleSMP c = THandle SMPVersion c + +data THandleParams v = THandleParams { sessionId :: SessionId, blockSize :: Int, -- | agreed server protocol version - thVersion :: Version, + thVersion :: Version v, -- | peer public key for command authorization and shared secrets for entity ID encryption thAuth :: Maybe THandleAuth, -- | do NOT send session ID in transmission, but include it into signed message @@ -316,7 +337,7 @@ data THandleAuth = THandleAuth type SessionId = ByteString data ServerHandshake = ServerHandshake - { smpVersionRange :: VersionRange, + { smpVersionRange :: VersionRangeSMP, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey) @@ -324,7 +345,7 @@ data ServerHandshake = ServerHandshake data ClientHandshake = ClientHandshake { -- | agreed SMP server protocol version - smpVersion :: Version, + smpVersion :: VersionSMP, -- | server identity - CA certificate fingerprint keyHash :: C.KeyHash, -- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. @@ -358,12 +379,12 @@ instance Encoding ServerHandshake where C.SignedObject key <- smpP pure (cert, key) -encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds :: Encoding a => VersionSMP -> Maybe a -> ByteString encodeAuthEncryptCmds v k | v >= authCmdsSMPVersion = maybe "" smpEncode k | otherwise = "" -authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP :: VersionSMP -> Parser a -> Parser (Maybe a) authEncryptCmdsP v p = if v >= authCmdsSMPVersion then Just <$> p else pure Nothing -- | Error of SMP encrypted transport over TCP. @@ -412,13 +433,13 @@ serializeTransportError = \case TEHandshake e -> "HANDSHAKE " <> bshow e -- | Pad and send block to SMP transport. -tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutBlock :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ()) tPutBlock THandle {connection = c, params = THandleParams {blockSize}} block = bimapM (const $ pure TELargeMsg) (cPut c) $ C.pad block blockSize -- | Receive block from SMP transport. -tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString) +tGetBlock :: Transport c => THandle v c -> IO (Either TransportError ByteString) tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do msg <- cGet c blockSize if B.length msg == blockSize @@ -428,7 +449,7 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do -- | Server SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c) smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c sk = C.signX509 serverSignKey $ C.publicToX509 k @@ -445,7 +466,7 @@ smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do -- | Client SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c) smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th @@ -465,24 +486,24 @@ smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do pure $ smpThHandle th v pk sk_ Nothing -> throwE $ TEHandshake VERSION -smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c +smpThHandle :: forall c. THandleSMP c -> VersionSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c smpThHandle th@THandle {params} v privKey k_ = -- TODO drop SMP v6: make thAuth non-optional let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_ params' = params {thVersion = v, thAuth, implySessId = v >= authCmdsSMPVersion} - in (th :: THandle c) {params = params'} + in (th :: THandleSMP c) {params = params'} -sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO () +sendHandshake :: (Transport c, Encoding smp) => THandle v c -> smp -> ExceptT TransportError IO () sendHandshake th = ExceptT . tPutBlock th . smpEncode -- ignores tail bytes to allow future extensions -getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp +getHandshake :: (Transport c, Encoding smp) => THandle v c -> ExceptT TransportError IO smp getHandshake th = ExceptT $ (first (\_ -> TEHandshake PARSE) . A.parseOnly smpP =<<) <$> tGetBlock th -smpTHandle :: Transport c => c -> THandle c +smpTHandle :: Transport c => c -> THandleSMP c smpTHandle c = THandle {connection = c, params} where - params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = True} + params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = VersionSMP 0, thAuth = Nothing, implySessId = False, batch = True} $(J.deriveJSON (sumTypeJSON id) ''HandshakeError) diff --git a/src/Simplex/Messaging/Version.hs b/src/Simplex/Messaging/Version.hs index dc8cfff68..78d290687 100644 --- a/src/Simplex/Messaging/Version.hs +++ b/src/Simplex/Messaging/Version.hs @@ -1,13 +1,15 @@ {-# LANGUAGE ConstrainedClassMethods #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeSynonymInstances #-} module Simplex.Messaging.Version ( Version, VersionRange (minVersion, maxVersion), + VersionScope, pattern VersionRange, VersionI (..), VersionRangeI (..), @@ -24,47 +26,45 @@ module Simplex.Messaging.Version where import Control.Applicative (optional) -import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Word (Word16) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Util ((<$?>)) +import Simplex.Messaging.Version.Internal (Version (..)) -pattern VersionRange :: Word16 -> Word16 -> VersionRange +pattern VersionRange :: Version v -> Version v -> VersionRange v pattern VersionRange v1 v2 <- VRange v1 v2 {-# COMPLETE VersionRange #-} -type Version = Word16 - -data VersionRange = VRange - { minVersion :: Version, - maxVersion :: Version +data VersionRange v = VRange + { minVersion :: Version v, + maxVersion :: Version v } deriving (Eq, Show) +class VersionScope v + -- | construct valid version range, to be used in constants -mkVersionRange :: Version -> Version -> VersionRange +mkVersionRange :: Version v -> Version v -> VersionRange v mkVersionRange v1 v2 | v1 <= v2 = VRange v1 v2 | otherwise = error "invalid version range" -safeVersionRange :: Version -> Version -> Maybe VersionRange +safeVersionRange :: Version v -> Version v -> Maybe (VersionRange v) safeVersionRange v1 v2 | v1 <= v2 = Just $ VRange v1 v2 | otherwise = Nothing -versionToRange :: Version -> VersionRange +versionToRange :: Version v -> VersionRange v versionToRange v = VRange v v -instance Encoding VersionRange where +instance VersionScope v => Encoding (VersionRange v) where smpEncode (VRange v1 v2) = smpEncode (v1, v2) smpP = maybe (fail "invalid version range") pure =<< safeVersionRange <$> smpP <*> smpP -instance StrEncoding VersionRange where +instance VersionScope v => StrEncoding (VersionRange v) where strEncode (VRange v1 v2) | v1 == v2 = strEncode v1 | otherwise = strEncode v1 <> "-" <> strEncode v2 @@ -73,32 +73,23 @@ instance StrEncoding VersionRange where v2 <- maybe (pure v1) (const strP) =<< optional (A.char '-') maybe (fail "invalid version range") pure $ safeVersionRange v1 v2 -instance ToJSON VersionRange where - toJSON (VRange v1 v2) = toJSON (v1, v2) - toEncoding (VRange v1 v2) = toEncoding (v1, v2) +class VersionScope v => VersionI v a | a -> v where + type VersionRangeT v a + version :: a -> Version v + toVersionRangeT :: a -> VersionRange v -> VersionRangeT v a -instance FromJSON VersionRange where - parseJSON v = - (\(v1, v2) -> maybe (Left "bad VersionRange") Right $ safeVersionRange v1 v2) - <$?> parseJSON v +class VersionScope v => VersionRangeI v a | a -> v where + type VersionT v a + versionRange :: a -> VersionRange v + toVersionT :: a -> Version v -> VersionT v a -class VersionI a where - type VersionRangeT a - version :: a -> Version - toVersionRangeT :: a -> VersionRange -> VersionRangeT a - -class VersionRangeI a where - type VersionT a - versionRange :: a -> VersionRange - toVersionT :: a -> Version -> VersionT a - -instance VersionI Version where - type VersionRangeT Version = VersionRange +instance VersionScope v => VersionI v (Version v) where + type VersionRangeT v (Version v) = VersionRange v version = id toVersionRangeT _ vr = vr -instance VersionRangeI VersionRange where - type VersionT VersionRange = Version +instance VersionScope v => VersionRangeI v (VersionRange v) where + type VersionT v (VersionRange v) = Version v versionRange = id toVersionT _ v = v @@ -109,18 +100,18 @@ pattern Compatible a <- Compatible_ a {-# COMPLETE Compatible #-} -isCompatible :: VersionI a => a -> VersionRange -> Bool +isCompatible :: VersionI v a => a -> VersionRange v -> Bool isCompatible x (VRange v1 v2) = let v = version x in v1 <= v && v <= v2 -isCompatibleRange :: VersionRangeI a => a -> VersionRange -> Bool +isCompatibleRange :: VersionRangeI v a => a -> VersionRange v -> Bool isCompatibleRange x (VRange min2 max2) = min1 <= max2 && min2 <= max1 where VRange min1 max1 = versionRange x -proveCompatible :: VersionI a => a -> VersionRange -> Maybe (Compatible a) +proveCompatible :: VersionI v a => a -> VersionRange v -> Maybe (Compatible a) proveCompatible x vr = x `mkCompatibleIf` (x `isCompatible` vr) -compatibleVersion :: VersionRangeI a => a -> VersionRange -> Maybe (Compatible (VersionT a)) +compatibleVersion :: VersionRangeI v a => a -> VersionRange v -> Maybe (Compatible (VersionT v a)) compatibleVersion x vr = toVersionT x (min max1 max2) `mkCompatibleIf` isCompatibleRange x vr where diff --git a/src/Simplex/Messaging/Version/Internal.hs b/src/Simplex/Messaging/Version/Internal.hs new file mode 100644 index 000000000..23cab1d1f --- /dev/null +++ b/src/Simplex/Messaging/Version/Internal.hs @@ -0,0 +1,25 @@ +module Simplex.Messaging.Version.Internal where + +import Data.Aeson (FromJSON (..), ToJSON (..)) +import Data.Word (Word16) +import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String + +-- Do not use constructor of this type directry +newtype Version v = Version Word16 + deriving (Eq, Ord, Show) + +instance Encoding (Version v) where + smpEncode (Version v) = smpEncode v + smpP = Version <$> smpP + +instance StrEncoding (Version v) where + strEncode (Version v) = strEncode v + strP = Version <$> strP + +instance ToJSON (Version v) where + toEncoding (Version v) = toEncoding v + toJSON (Version v) = toJSON v + +instance FromJSON (Version v) where + parseJSON v = Version <$> parseJSON v diff --git a/src/Simplex/RemoteControl/Client.hs b/src/Simplex/RemoteControl/Client.hs index c73679439..3cf1050fa 100644 --- a/src/Simplex/RemoteControl/Client.hs +++ b/src/Simplex/RemoteControl/Client.hs @@ -68,12 +68,6 @@ import Simplex.RemoteControl.Types import UnliftIO import UnliftIO.Concurrent -currentRCVersion :: Version -currentRCVersion = 1 - -supportedRCVRange :: VersionRange -supportedRCVRange = mkVersionRange 1 currentRCVersion - xrcpBlockSize :: Int xrcpBlockSize = 16384 @@ -181,7 +175,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct { ca = certFingerprint caCert, host, port = fromIntegral portNum, - v = supportedRCVRange, + v = supportedRCPVRange, app = ctrlAppInfo, ts, skey = fst sessKeys, @@ -220,7 +214,7 @@ prepareHostSession unless (ca == tlsHostFingerprint) $ throwError RCEIdentity (kemCiphertext, kemSharedKey) <- liftIO $ sntrup761Enc drg kemPubKey let hybridKey = kemHybridSecret dhPubKey dhPrivKey kemSharedKey - unless (isCompatible v supportedRCVRange) $ throwError RCEVersion + unless (isCompatible v supportedRCPVRange) $ throwError RCEVersion let keys = HostSessKeys {hybridKey, idPrivKey, sessPrivKey} knownHost' <- updateKnownHost ca dhPubKey let ctrlHello = RCCtrlHello {} @@ -334,7 +328,7 @@ prepareHostHello RCInvitation {v, dh = dhPubKey} hostAppInfo = do logDebug "Preparing session" - case compatibleVersion v supportedRCVRange of + case compatibleVersion v supportedRCPVRange of Nothing -> throwError RCEVersion Just (Compatible v') -> do nonce <- liftIO . atomically $ C.randomCbNonce drg diff --git a/src/Simplex/RemoteControl/Invitation.hs b/src/Simplex/RemoteControl/Invitation.hs index f5deac9a8..712c41a9d 100644 --- a/src/Simplex/RemoteControl/Invitation.hs +++ b/src/Simplex/RemoteControl/Invitation.hs @@ -27,7 +27,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Version (VersionRange) +import Simplex.RemoteControl.Types (VersionRangeRCP) data RCInvitation = RCInvitation { -- | CA TLS certificate fingerprint of the controller. @@ -37,7 +37,7 @@ data RCInvitation = RCInvitation host :: TransportHost, port :: Word16, -- | Supported version range for remote control protocol - v :: VersionRange, + v :: VersionRangeRCP, -- | Application information app :: J.Value, -- | Session start time in seconds since epoch diff --git a/src/Simplex/RemoteControl/Types.hs b/src/Simplex/RemoteControl/Types.hs index e1598f25c..b8a7c1141 100644 --- a/src/Simplex/RemoteControl/Types.hs +++ b/src/Simplex/RemoteControl/Types.hs @@ -5,6 +5,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} @@ -17,6 +18,7 @@ import Data.ByteString (ByteString) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) +import Data.Word (Word16) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.SNTRUP761 import Simplex.Messaging.Crypto.SNTRUP761.Bindings @@ -26,7 +28,8 @@ import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport (TLS) import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util (safeDecodeUtf8) -import Simplex.Messaging.Version (Version, VersionRange, mkVersionRange) +import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange) +import Simplex.Messaging.Version.Internal import UnliftIO data RCErrorType @@ -92,24 +95,37 @@ instance StrEncoding RCErrorType where -- * Discovery -ipProbeVersionRange :: VersionRange -ipProbeVersionRange = mkVersionRange 1 1 +data RCPVersion + +instance VersionScope RCPVersion + +type VersionRCP = Version RCPVersion + +type VersionRangeRCP = VersionRange RCPVersion + +pattern VersionRCP :: Word16 -> VersionRCP +pattern VersionRCP v = Version v + +currentRCPVersion :: VersionRCP +currentRCPVersion = VersionRCP 1 + +supportedRCPVRange :: VersionRangeRCP +supportedRCPVRange = mkVersionRange (VersionRCP 1) currentRCPVersion data IpProbe = IpProbe - { versionRange :: VersionRange, + { versionRange :: VersionRangeRCP, randomNonce :: ByteString } deriving (Show) instance Encoding IpProbe where smpEncode IpProbe {versionRange, randomNonce} = smpEncode (versionRange, 'I', randomNonce) - smpP = IpProbe <$> (smpP <* "I") *> smpP -- * Session data RCHostHello = RCHostHello - { v :: Version, + { v :: VersionRCP, ca :: C.KeyHash, app :: J.Value, kem :: KEMPublicKey diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 83548182a..a153f1d2a 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.ConnectionRequestTests where @@ -12,7 +13,7 @@ import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (ProtocolServer (..), supportedSMPClientVRange) +import Simplex.Messaging.Protocol (ProtocolServer (..), pattern VersionSMPC, supportedSMPClientVRange) import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.Version import Test.Hspec @@ -38,7 +39,7 @@ queue :: SMPQueueUri queue = SMPQueueUri supportedSMPClientVRange queueAddr queueV1 :: SMPQueueUri -queueV1 = SMPQueueUri (mkVersionRange 1 1) queueAddr +queueV1 = SMPQueueUri (mkVersionRange (VersionSMPC 1) (VersionSMPC 1)) queueAddr testDhKey :: C.PublicKeyX25519 testDhKey = "MCowBQYDK2VuAyEAjiswwI3O/NlS8Fk3HJUW870EY2bAwmttMBsvRB9eV3o=" @@ -53,7 +54,7 @@ connReqData :: ConnReqUriData connReqData = ConnReqUriData { crScheme = SSSimplex, - crAgentVRange = mkVersionRange 2 2, + crAgentVRange = mkVersionRange (VersionSMPA 2) (VersionSMPA 2), crSmpQueues = [queueV1], crClientData = Nothing } @@ -62,7 +63,7 @@ testDhPubKey :: C.PublicKeyX448 testDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U=" testE2ERatchetParams :: E2ERatchetParamsUri 'C.X448 -testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange 1 1) testDhPubKey testDhPubKey +testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange (VersionE2E 1) (VersionE2E 1)) testDhPubKey testDhPubKey testE2ERatchetParams12 :: E2ERatchetParamsUri 'C.X448 testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey @@ -98,7 +99,7 @@ connectionRequestTests = it "should serialize SMP queue URIs" $ do strEncode (queue :: SMPQueueUri) {queueAddress = queueAddrNoPort} `shouldBe` "smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri - strEncode queue {clientVRange = mkVersionRange 1 2} + strEncode queue {clientVRange = mkVersionRange (VersionSMPC 1) (VersionSMPC 2)} `shouldBe` "smp://1234-w==@smp.simplex.im:5223/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri it "should parse SMP queue URIs" $ do strDecode ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStr) diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 95e23b333..dbcff7ff0 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -12,7 +12,7 @@ module AgentTests.DoubleRatchetTests where import Control.Concurrent.STM import Control.Monad.Except import Crypto.Random (ChaChaDRG) -import Data.Aeson (FromJSON, ToJSON) +import Data.Aeson (FromJSON, ToJSON, (.=)) import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -23,6 +23,7 @@ import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util ((<$$>)) +import Simplex.Messaging.Version import Test.Hspec doubleRatchetTests :: Spec @@ -44,6 +45,7 @@ doubleRatchetTests = do testKeyJSON C.SX448 testRatchetJSON C.SX25519 testRatchetJSON C.SX448 + testVersionJSON it "should agree the same ratchet parameters" $ do testX3dh C.SX25519 testX3dh C.SX448 @@ -164,6 +166,20 @@ testRatchetJSON _ = do testEncodeDecode alice testEncodeDecode bob +testVersionJSON :: IO () +testVersionJSON = do + testEncodeDecode $ rv 1 1 + testEncodeDecode $ rv 1 2 + -- let bad = RVersions 2 1 + -- Left err <- pure $ J.eitherDecode' @RatchetVersions (J.encode bad) + -- err `shouldContain` "bad version range" + testDecodeRV $ (1 :: Int, 2 :: Int) + testDecodeRV $ J.object ["current" .= (1 :: Int), "maxSupported" .= (2 :: Int)] + where + rv v1 v2 = ratchetVersions $ mkVersionRange (VersionE2E v1) (VersionE2E v2) + testDecodeRV :: ToJSON a => a -> Expectation + testDecodeRV a = J.eitherDecode' (J.encode a) `shouldBe` Right (rv 1 2) + testEncodeDecode :: (Eq a, Show a, ToJSON a, FromJSON a) => a -> Expectation testEncodeDecode x = do let j = J.encode x @@ -182,8 +198,8 @@ testX3dh _ = do testX3dhV1 :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dhV1 _ = do g <- C.newRandom - (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g 1 - (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g 1 + (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g (VersionE2E 1) + (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g (VersionE2E 1) let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob paramsAlice `shouldBe` paramsBob diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 5870266a7..6211df932 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -50,6 +50,7 @@ import qualified Data.Set as S import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Type.Equality +import Data.Word (Word16) import qualified Database.SQLite.Simple as SQL import SMPAgentClient import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, withSmpServerV7) @@ -62,13 +63,15 @@ import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultSMPClientConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (authBatchCmdsNTFVersion) +import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF, authBatchCmdsNTFVersion) import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), supportedSMPClientVRange) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Transport (ATransport (..), authCmdsSMPVersion, basicAuthSMPVersion, currentServerSMPRelayVersion) -import Simplex.Messaging.Version +import Simplex.Messaging.Transport (ATransport (..), SMPVersion, VersionSMP, authCmdsSMPVersion, batchCmdsSMPVersion, basicAuthSMPVersion, currentServerSMPRelayVersion) +import Simplex.Messaging.Version (VersionRange (..)) +import qualified Simplex.Messaging.Version as V +import Simplex.Messaging.Version.Internal (Version (..)) import System.Directory (copyFile, renameFile) import Test.Hspec import UnliftIO @@ -127,14 +130,14 @@ pattern MsgErr msgId err msgBody <- MSG MsgMeta {recipient = (msgId, _), integri pattern Rcvd :: AgentMsgId -> ACommand 'Agent e pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] -smpCfgVPrev :: ProtocolClientConfig +smpCfgVPrev :: ProtocolClientConfig SMPVersion smpCfgVPrev = (smpCfg agentCfg) {serverVRange = prevRange $ serverVRange $ smpCfg agentCfg} -smpCfgV7 :: ProtocolClientConfig -smpCfgV7 = (smpCfg agentCfg) {serverVRange = mkVersionRange 4 authCmdsSMPVersion} +smpCfgV7 :: ProtocolClientConfig SMPVersion +smpCfgV7 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion} -ntfCfgV2 :: ProtocolClientConfig -ntfCfgV2 = (smpCfg agentCfg) {serverVRange = mkVersionRange 1 authBatchCmdsNTFVersion} +ntfCfgV2 :: ProtocolClientConfig NTFVersion +ntfCfgV2 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange (VersionNTF 1) authBatchCmdsNTFVersion} agentCfgVPrev :: AgentConfig agentCfgVPrev = @@ -157,8 +160,14 @@ agentCfgV7 = agentCfgRatchetVPrev :: AgentConfig agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = prevRange $ e2eEncryptVRange agentCfg} -prevRange :: VersionRange -> VersionRange -prevRange vr = vr {maxVersion = max (minVersion vr) (maxVersion vr - 1)} +prevRange :: VersionRange v -> VersionRange v +prevRange vr = vr {maxVersion = max (minVersion vr) (prevVersion $ maxVersion vr)} + +prevVersion :: Version v -> Version v +prevVersion (Version v) = Version (v - 1) + +mkVersionRange :: Word16 -> Word16 -> VersionRange v +mkVersionRange v1 v2 = V.mkVersionRange (Version v1) (Version v2) runRight_ :: (Eq e, Show e, HasCallStack) => ExceptT e IO () -> Expectation runRight_ action = runExceptT action `shouldReturn` Right () @@ -311,8 +320,8 @@ functionalAPITests t = do describe "should switch two connections simultaneously, abort one" $ testServerMatrix2 t testSwitch2ConnectionsAbort1 describe "SMP basic auth" $ do - let v4 = basicAuthSMPVersion - 1 - forM_ (nub [authCmdsSMPVersion - 1, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do + let v4 = prevVersion basicAuthSMPVersion + forM_ (nub [prevVersion authCmdsSMPVersion, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do describe ("v" <> show v <> ": with server auth") $ do -- allow NEW | server auth, v | clnt1 auth, v | clnt2 auth, v | 2 - success, 1 - JOIN fail, 0 - NEW fail it "success " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) `shouldReturn` 2 @@ -356,9 +365,9 @@ functionalAPITests t = do it "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t it "send delivery receipts concurrently with messages" $ testDeliveryReceiptsConcurrent t -testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int +testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> IO Int testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do - let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = mkVersionRange 4 srvVersion} + let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = V.mkVersionRange batchCmdsSMPVersion srvVersion} canCreate1 = canCreateQueue allowNewQueues srv clnt1 canCreate2 = canCreateQueue allowNewQueues srv clnt2 expected @@ -369,7 +378,7 @@ testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do created `shouldBe` expected pure created -canCreateQueue :: Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> Bool +canCreateQueue :: Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> Bool canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = let v = basicAuthSMPVersion in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) @@ -690,10 +699,10 @@ testIncreaseConnAgentVersion t = do disconnectAgentClient alice3 disconnectAgentClient bob3 -checkVersion :: AgentClient -> ConnId -> Version -> ExceptT AgentErrorType IO () +checkVersion :: AgentClient -> ConnId -> Word16 -> ExceptT AgentErrorType IO () checkVersion c connId v = do ConnectionStats {connAgentVersion} <- getConnectionServers c connId - liftIO $ connAgentVersion `shouldBe` v + liftIO $ connAgentVersion `shouldBe` VersionSMPA v testIncreaseConnAgentVersionMaxCompatible :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersionMaxCompatible t = do @@ -2225,7 +2234,7 @@ testSwitch2ConnectionsAbort1 servers = do withB :: (AgentClient -> IO a) -> IO a withB = withAgent 2 agentCfg servers testDB2 -testCreateQueueAuth :: HasCallStack => Version -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int +testCreateQueueAuth :: HasCallStack => VersionSMP -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> IO Int testCreateQueueAuth srvVersion clnt1 clnt2 = do a <- getClient 1 clnt1 testDB b <- getClient 2 clnt2 testDB2 @@ -2251,7 +2260,7 @@ testCreateQueueAuth srvVersion clnt1 clnt2 = do where getClient clientId (clntAuth, clntVersion) db = let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]} - smpCfg = (defaultSMPClientConfig :: ProtocolClientConfig) {serverVRange = mkVersionRange (basicAuthSMPVersion - 1) clntVersion} + smpCfg = (defaultSMPClientConfig :: ProtocolClientConfig SMPVersion) {serverVRange = V.mkVersionRange (prevVersion basicAuthSMPVersion) clntVersion} sndAuthAlg = if srvVersion >= authCmdsSMPVersion && clntVersion >= authCmdsSMPVersion then C.AuthAlg C.SX25519 else C.AuthAlg C.SEd25519 in getSMPAgentClient' clientId agentCfg {smpCfg, sndAuthAlg} servers db diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 714b7e15e..308b2593c 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -5,6 +5,7 @@ {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -42,7 +43,7 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Encoding.String (StrEncoding (..)) -import Simplex.Messaging.Protocol (SubscriptionMode (..)) +import Simplex.Messaging.Protocol (SubscriptionMode (..), pattern VersionSMPC) import qualified Simplex.Messaging.Protocol as SMP import System.Random import Test.Hspec @@ -174,7 +175,7 @@ testForeignKeysEnabled = `shouldThrow` (\e -> SQL.sqlError e == SQL.ErrorConstraint) cData1 :: ConnData -cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} +cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = VersionSMPA 1, enableNtfs = True, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} testPrivateAuthKey :: C.APrivateAuthKey testPrivateAuthKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe" @@ -205,7 +206,7 @@ rcvQueue1 = primary = True, dbReplaceQueueId = Nothing, rcvSwchStatus = Nothing, - smpClientVersion = 1, + smpClientVersion = VersionSMPC 1, clientNtfCreds = Nothing, deleteErrors = 0 } @@ -226,7 +227,7 @@ sndQueue1 = primary = True, dbReplaceQueueId = Nothing, sndSwchStatus = Nothing, - smpClientVersion = 1 + smpClientVersion = VersionSMPC 1 } createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue)) @@ -370,7 +371,7 @@ testUpgradeRcvConnToDuplex = sndSwchStatus = Nothing, primary = True, dbReplaceQueueId = Nothing, - smpClientVersion = 1 + smpClientVersion = VersionSMPC 1 } upgradeRcvConnToDuplex db "conn1" anotherSndQueue `shouldReturn` Left (SEBadConnType CSnd) @@ -399,7 +400,7 @@ testUpgradeSndConnToDuplex = rcvSwchStatus = Nothing, primary = True, dbReplaceQueueId = Nothing, - smpClientVersion = 1, + smpClientVersion = VersionSMPC 1, clientNtfCreds = Nothing, deleteErrors = 0 } diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index eb9b62d3d..996d2fed1 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -15,7 +15,6 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport -import Simplex.Messaging.Version (Version) import Test.Hspec batchingTests :: Spec @@ -253,27 +252,27 @@ testClientBatchWithLargeMessageV7 = do (length rs1', length rs2') `shouldBe` (74, 136) all lenOk [s1', s2'] `shouldBe` True -testClientStub :: IO (ProtocolClient ErrorType BrokerMsg) +testClientStub :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) testClientStub = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g - atomically $ clientStub g sessId (authCmdsSMPVersion - 1) Nothing + atomically $ smpClientStub g sessId subModeSMPVersion Nothing -clientStubV7 :: IO (ProtocolClient ErrorType BrokerMsg) +clientStubV7 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) clientStubV7 = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g (rKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g thAuth_ <- testTHandleAuth authCmdsSMPVersion g rKey - atomically $ clientStub g sessId authCmdsSMPVersion thAuth_ + atomically $ smpClientStub g sessId authCmdsSMPVersion thAuth_ randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) -randomSUB = randomSUB_ C.SEd25519 (authCmdsSMPVersion - 1) +randomSUB = randomSUB_ C.SEd25519 subModeSMPVersion randomSUBv7 :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUBv7 = randomSUB_ C.SEd25519 authCmdsSMPVersion -randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUB_ a v sessId = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g @@ -284,13 +283,13 @@ randomSUB_ a v sessId = do TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, rId, Cmd SRecipient SUB) pure $ (,tToSend) <$> authTransmission thAuth_ (Just rpKey) corrId tForAuth -randomSUBCmd :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmd = randomSUBCmd_ C.SEd25519 -randomSUBCmdV7 :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmdV7 = randomSUBCmd_ C.SEd25519 -- same as v6 -randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmd_ a c = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g @@ -298,12 +297,12 @@ randomSUBCmd_ a c = do mkTransmission c (Just rpKey, rId, Cmd SRecipient SUB) randomSEND :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) -randomSEND = randomSEND_ C.SEd25519 (authCmdsSMPVersion - 1) +randomSEND = randomSEND_ C.SEd25519 subModeSMPVersion randomSENDv7 :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSENDv7 = randomSEND_ C.SX25519 authCmdsSMPVersion -randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSEND_ a v sessId len = do g <- C.newRandom sId <- atomically $ C.randomBytes 24 g @@ -315,7 +314,7 @@ randomSEND_ a v sessId len = do TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, sId, Cmd SSender $ SEND noMsgFlags msg) pure $ (,tToSend) <$> authTransmission thAuth_ (Just spKey) corrId tForAuth -testTHandleParams :: Version -> ByteString -> THandleParams +testTHandleParams :: VersionSMP -> ByteString -> THandleParams SMPVersion testTHandleParams v sessionId = THandleParams { sessionId, @@ -326,20 +325,20 @@ testTHandleParams v sessionId = batch = True } -testTHandleAuth :: Version -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth) +testTHandleAuth :: VersionSMP -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth) testTHandleAuth v g (C.APublicAuthKey a k) = case a of C.SX25519 | v >= authCmdsSMPVersion -> do (_, privKey) <- atomically $ C.generateKeyPair g pure $ Just THandleAuth {peerPubKey = k, privKey} _ -> pure Nothing -randomSENDCmd :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmd = randomSENDCmd_ C.SEd25519 -randomSENDCmdV7 :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmdV7 = randomSENDCmd_ C.SX25519 -randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmd_ a c len = do g <- C.newRandom sId <- atomically $ C.randomBytes 24 g diff --git a/tests/CoreTests/ProtocolErrorTests.hs b/tests/CoreTests/ProtocolErrorTests.hs index 6dc6f2c02..7b1a7b813 100644 --- a/tests/CoreTests/ProtocolErrorTests.hs +++ b/tests/CoreTests/ProtocolErrorTests.hs @@ -12,7 +12,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import GHC.Generics (Generic) import Generic.Random (genericArbitraryU) -import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType (..)) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 70f2d93ab..ab28a145e 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -1,5 +1,6 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TypeApplications #-} module CoreTests.TRcvQueuesTests where @@ -11,7 +12,7 @@ import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId) import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..)) import qualified Simplex.Messaging.Agent.TRcvQueues as RQ import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Protocol (SMPServer) +import Simplex.Messaging.Protocol (SMPServer, pattern VersionSMPC) import Test.Hspec import UnliftIO @@ -136,7 +137,7 @@ dummyRQ userId server connId = primary = True, dbReplaceQueueId = Nothing, rcvSwchStatus = Nothing, - smpClientVersion = 123, + smpClientVersion = VersionSMPC 123, clientNtfCreds = Nothing, deleteErrors = 0 } diff --git a/tests/CoreTests/VersionRangeTests.hs b/tests/CoreTests/VersionRangeTests.hs index be02e38b7..cef556376 100644 --- a/tests/CoreTests/VersionRangeTests.hs +++ b/tests/CoreTests/VersionRangeTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} module CoreTests.VersionRangeTests where @@ -8,6 +9,7 @@ module CoreTests.VersionRangeTests where import GHC.Generics (Generic) import Generic.Random (genericArbitraryU) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Test.Hspec import Test.Hspec.QuickCheck (modifyMaxSuccess) import Test.QuickCheck @@ -16,6 +18,10 @@ data V = V1 | V2 | V3 | V4 | V5 deriving (Eq, Enum, Ord, Generic, Show) instance Arbitrary V where arbitrary = genericArbitraryU +data T + +instance VersionScope T + versionRangeTests :: Spec versionRangeTests = modifyMaxSuccess (const 1000) $ do describe "VersionRange construction" $ do @@ -25,31 +31,31 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do (pure $! vr 2 1) `shouldThrow` anyErrorCall describe "compatible version" $ do it "should choose mutually compatible max version" $ do - (vr 1 1, vr 1 1) `compatible` Just 1 - (vr 1 1, vr 1 2) `compatible` Just 1 - (vr 1 2, vr 1 2) `compatible` Just 2 - (vr 1 2, vr 2 3) `compatible` Just 2 - (vr 1 3, vr 2 3) `compatible` Just 3 - (vr 1 3, vr 2 4) `compatible` Just 3 + (vr 1 1, vr 1 1) `compatible` Just (Version 1) + (vr 1 1, vr 1 2) `compatible` Just (Version 1) + (vr 1 2, vr 1 2) `compatible` Just (Version 2) + (vr 1 2, vr 2 3) `compatible` Just (Version 2) + (vr 1 3, vr 2 3) `compatible` Just (Version 3) + (vr 1 3, vr 2 4) `compatible` Just (Version 3) (vr 1 2, vr 3 4) `compatible` Nothing it "should check if version is compatible" $ do - isCompatible (1 :: Version) (vr 1 2) `shouldBe` True - isCompatible (2 :: Version) (vr 1 2) `shouldBe` True - isCompatible (2 :: Version) (vr 1 1) `shouldBe` False - isCompatible (1 :: Version) (vr 2 2) `shouldBe` False + isCompatible @T (Version 1) (vr 1 2) `shouldBe` True + isCompatible @T (Version 2) (vr 1 2) `shouldBe` True + isCompatible @T (Version 2) (vr 1 1) `shouldBe` False + isCompatible @T (Version 1) (vr 2 2) `shouldBe` False it "compatibleVersion should pass isCompatible check" . property $ \((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) -> min1 > max1 || min2 > max2 -- one of ranges is invalid, skip testing it - || let w = fromIntegral . fromEnum - vr1 = mkVersionRange (w min1) (w max1) :: VersionRange - vr2 = mkVersionRange (w min2) (w max2) :: VersionRange + || let w = Version . fromIntegral . fromEnum + vr1 = mkVersionRange (w min1) (w max1) :: VersionRange T + vr2 = mkVersionRange (w min2) (w max2) :: VersionRange T in case compatibleVersion vr1 vr2 of Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2 _ -> True where - vr = mkVersionRange - compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation + vr v1 v2 = mkVersionRange (Version v1) (Version v2) + compatible :: (VersionRange T, VersionRange T) -> Maybe (Version T) -> Expectation (vr1, vr2) `compatible` v = do (vr1, vr2) `checkCompatible` v (vr2, vr1) `checkCompatible` v diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 43558a86c..5a2dbb8de 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -34,6 +34,7 @@ import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Notifications.Protocol (NtfResponse) import Simplex.Messaging.Notifications.Server (runNtfServerBlocking) import Simplex.Messaging.Notifications.Server.Env import Simplex.Messaging.Notifications.Server.Push.APNS @@ -70,7 +71,7 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" ntfTestStoreLogFile :: FilePath ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log" -testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a +testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a testNtfClient client = do Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do @@ -114,8 +115,8 @@ ntfServerCfg = ntfServerCfgV2 :: NtfServerConfig ntfServerCfgV2 = ntfServerCfg - { ntfServerVRange = mkVersionRange 1 authBatchCmdsNTFVersion, - smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange 4 authCmdsSMPVersion}} + { ntfServerVRange = mkVersionRange initialNTFVersion authBatchCmdsNTFVersion, + smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion}} } withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a @@ -139,7 +140,7 @@ withNtfServerOn t port' = withNtfServerThreadOn t port' . const withNtfServer :: ATransport -> IO a -> IO a withNtfServer t = withNtfServerOn t ntfTestPort -runNtfTest :: forall c a. Transport c => (THandle c -> IO a) -> IO a +runNtfTest :: forall c a. Transport c => (THandleNTF c -> IO a) -> IO a runNtfTest test = withNtfServer (transport @c) $ testNtfClient test ntfServerTest :: @@ -147,10 +148,10 @@ ntfServerTest :: (Transport c, Encoding smp) => TProxy c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) + IO (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h where - tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -159,7 +160,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) -ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation +ntfTest :: Transport c => TProxy c -> (THandleNTF c -> IO ()) -> Expectation ntfTest _ test' = runNtfTest test' `shouldReturn` () data APNSMockRequest = APNSMockRequest diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index e29a292ee..e7e2018c2 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -5,7 +5,9 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} +{-# OPTIONS_GHC -Wno-orphans #-} module NtfServerTests where @@ -37,6 +39,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS import qualified Simplex.Messaging.Notifications.Server.Push.APNS as APNS +import Simplex.Messaging.Notifications.Transport (THandleNTF) import Simplex.Messaging.Parsers (parse, parseAll) import Simplex.Messaging.Protocol hiding (notification) import Simplex.Messaging.Transport @@ -50,30 +53,32 @@ ntfServerTests t = do ntfSyntaxTests :: ATransport -> Spec ntfSyntaxTests (ATransport t) = do - it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN) + it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", NRErr $ CMD UNKNOWN) describe "NEW" $ do - it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX) - it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX) - it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH) - it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH) + it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", NRErr $ CMD SYNTAX) + it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", NRErr $ CMD SYNTAX) + it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", NRErr $ CMD NO_AUTH) + it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", NRErr $ CMD HAS_AUTH) where (>#>) :: Encoding smp => (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) -> + (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) -> Expectation command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command)) -sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) +deriving instance Eq NtfResponse + +sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) +signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index f1ed84d68..e63678c7d 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -26,7 +26,7 @@ import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.Server -import Simplex.Messaging.Version (VersionRange, mkVersionRange) +import Simplex.Messaging.Version (mkVersionRange) import System.Environment (lookupEnv) import System.Info (os) import Test.Hspec @@ -67,10 +67,10 @@ xit'' d t = do ci <- runIO $ lookupEnv "CI" (if ci == Just "true" then xit else it) d t -testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a +testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a testSMPClient = testSMPClientVR supportedClientSMPRelayVRange -testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRange -> (THandle c -> m a) -> m a +testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a testSMPClientVR vr client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do @@ -109,7 +109,7 @@ cfg = } cfgV7 :: ServerConfig -cfgV7 = cfg {smpServerVRange = mkVersionRange 4 authCmdsSMPVersion} +cfgV7 = cfg {smpServerVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion} withSmpServerStoreMsgLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsBackupFile = Just testServerStatsBackupFile} @@ -148,16 +148,16 @@ withSmpServer t = withSmpServerOn t testPort withSmpServerV7 :: HasCallStack => ATransport -> IO a -> IO a withSmpServerV7 t = withSmpServerConfigOn t cfgV7 testPort . const -runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandle c -> IO a) -> IO a +runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandleSMP c -> IO a) -> IO a runSmpTest test = withSmpServer (transport @c) $ testSMPClient test -runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO a) -> IO a +runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a runSmpTestN = runSmpTestNCfg cfg supportedClientSMPRelayVRange -runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> Int -> (HasCallStack => [THandle c] -> IO a) -> IO a +runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a runSmpTestNCfg srvCfg clntVR nClients test = withSmpServerConfigOn (transport @c) srvCfg testPort $ \_ -> run nClients [] where - run :: Int -> [THandle c] -> IO a + run :: Int -> [THandleSMP c] -> IO a run 0 hs = test hs run n hs = testSMPClientVR clntVR $ \h -> run (n - 1) (h : hs) @@ -169,7 +169,7 @@ smpServerTest :: IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h where - tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -178,33 +178,33 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) -smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> IO ()) -> Expectation +smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> IO ()) -> Expectation smpTest _ test' = runSmpTest test' `shouldReturn` () -smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO ()) -> Expectation +smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO ()) -> Expectation smpTestN n test' = runSmpTestN n test' `shouldReturn` () -smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation +smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest2 = smpTest2Cfg cfg supportedClientSMPRelayVRange -smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation +smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest2Cfg srvCfg clntVR _ test' = runSmpTestNCfg srvCfg clntVR 2 _test `shouldReturn` () where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2] = test' h1 h2 _test _ = error "expected 2 handles" -smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest3 _ test' = smpTestN 3 _test where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2, h3] = test' h1 h2 h3 _test _ = error "expected 3 handles" -smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest4 _ test' = smpTestN 4 _test where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2, h3, h4] = test' h1 h2 h3 h4 _test _ = error "expected 4 handles" diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index d6938fa0f..5770d7922 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -74,13 +75,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh) pattern Msg :: MsgId -> MsgBody -> BrokerMsg pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) @@ -94,12 +95,12 @@ signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do _sx448 -> undefined -- ghc8107 fails to the branch excluded by types #endif -tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) +tPut1 :: Transport c => THandle v c -> SentRawTransmission -> IO (Either TransportError ()) tPut1 h t = do [r] <- tPut h [Right t] pure r -tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd) +tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd) tGet1 h = do [r] <- liftIO $ tGet h pure r @@ -380,7 +381,7 @@ testSwitchSub (ATransport t) = Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3) (ok3, OK) #== "accepts ACK from the 2nd TCP connection" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection" @@ -551,12 +552,12 @@ testWithStoreLog at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 removeFile testStoreLogFile where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () logSize :: FilePath -> IO Int @@ -649,12 +650,12 @@ testRestoreMessages at@(ATransport t) = removeFile testStoreMsgsFile removeFile testServerStatsBackupFile where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () checkStats :: ServerStatsData -> [RecipientId] -> Int -> Int -> Expectation @@ -723,15 +724,15 @@ testRestoreExpireMessages at@(ATransport t) = Right ServerStatsData {_msgExpired} <- strDecode <$> B.readFile testServerStatsBackupFile _msgExpired `shouldBe` 2 where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () -createAndSecureQueue :: Transport c => THandle c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret) +createAndSecureQueue :: Transport c => THandleSMP c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret) createAndSecureQueue h sPub = do g <- C.newRandom (rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g @@ -747,7 +748,7 @@ testTiming (ATransport t) = describe "should have similar time for auth error, whether queue exists or not, for all key types" $ forM_ timingTests $ \tst -> it (testName tst) $ - smpTest2Cfg cfgV7 (mkVersionRange 4 authCmdsSMPVersion) t $ \rh sh -> + smpTest2Cfg cfgV7 (mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion) t $ \rh sh -> testSameTiming rh sh tst where testName :: (C.AuthAlg, C.AuthAlg, Int) -> String @@ -766,7 +767,7 @@ testTiming (ATransport t) = ] timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const similarTime t1 t2 = abs (t2 / t1 - 1) < 0.15 -- normally the difference between "no queue" and "wrong key" is less than 5% - testSameTiming :: forall c. Transport c => THandle c -> THandle c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation + testSameTiming :: forall c. Transport c => THandleSMP c -> THandleSMP c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation testSameTiming rh sh (C.AuthAlg goodKeyAlg, C.AuthAlg badKeyAlg, n) = do g <- C.newRandom (rPub, rKey) <- atomically $ C.generateAuthKeyPair goodKeyAlg g @@ -787,7 +788,7 @@ testTiming (ATransport t) = runTimingTest sh badKey sId $ _SEND "hello" where - runTimingTest :: PartyI p => THandle c -> C.APrivateAuthKey -> ByteString -> Command p -> IO () + runTimingTest :: PartyI p => THandleSMP c -> C.APrivateAuthKey -> ByteString -> Command p -> IO () runTimingTest h badKey qId cmd = do threadDelay 100000 _ <- timeRepeat n $ do -- "warm up" the server @@ -837,14 +838,14 @@ testMessageNotifications (ATransport t) = Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2) (dec mId2 msg2, Right "hello again") #== "delivered from queue again" Resp "" _ (NMSG _ _) <- tGet1 nh2 - 1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there") Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" - 1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" @@ -864,7 +865,7 @@ testMsgExpireOnSend t = testSMPClient @c $ \rh -> do Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" @@ -884,7 +885,7 @@ testMsgExpireOnInterval t = signSendRecv rh rKey ("2", rId, SUB) >>= \case Resp "2" _ OK -> pure () r -> unexpected r - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing should be delivered" @@ -903,7 +904,7 @@ testMsgNOTExpireOnInterval t = testSMPClient @c $ \rh -> do Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 46d3d4dd8..999746858 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -21,7 +21,8 @@ import Data.List (find, isSuffixOf) import Data.Maybe (fromJust) import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2, testDB3) import Simplex.FileTransfer.Description (FileDescription (..), FileDescriptionURI (..), ValidFileDescription, fileDescriptionURI, mb, qrSizeLimit, pattern ValidFileDescription) -import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType (AUTH)) +import Simplex.FileTransfer.Protocol (FileParty (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType (AUTH)) import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..)) import Simplex.Messaging.Agent (AgentClient, disconnectAgentClient, testProtocolServer, xftpDeleteRcvFile, xftpDeleteSndFileInternal, xftpDeleteSndFileRemote, xftpReceiveFile, xftpSendDescription, xftpSendFile, xftpStartWorkers) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) diff --git a/tests/XFTPServerTests.hs b/tests/XFTPServerTests.hs index a11ba515a..451406275 100644 --- a/tests/XFTPServerTests.hs +++ b/tests/XFTPServerTests.hs @@ -20,9 +20,9 @@ import Data.List (isInfixOf) import ServerTests (logSize) import Simplex.FileTransfer.Client import Simplex.FileTransfer.Description (kb) -import Simplex.FileTransfer.Protocol (FileInfo (..), XFTPErrorType (..)) +import Simplex.FileTransfer.Protocol (FileInfo (..)) import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..)) -import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..)) import Simplex.Messaging.Client (ProtocolClientError (..)) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC From 52a67daea67689f1a1dc530cc8f33529632e9780 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Tue, 5 Mar 2024 11:09:07 +0000 Subject: [PATCH 07/30] agent: pass PQ encryption flag separately for each message in batch APIs (#1027) --- src/Simplex/Messaging/Agent.hs | 54 +++++++++++++++++----------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index e8af60b96..74d4e1c40 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -295,14 +295,14 @@ resubscribeConnections c = withAgentEnv c . resubscribeConnections' c sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> CR.PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, CR.PQEncryption) sendMessage c = withAgentEnv c .:: sendMessage' c -type MsgReq = (ConnId, MsgFlags, MsgBody) +type MsgReq = (ConnId, CR.PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) -sendMessages :: MonadUnliftIO m => AgentClient -> CR.PQEncryption -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] -sendMessages c = withAgentEnv c .: sendMessages' c +sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] +sendMessages c = withAgentEnv c . sendMessages' c -sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> CR.PQEncryption -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) -sendMessagesB c = withAgentEnv c .: sendMessagesB' c +sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +sendMessagesB c = withAgentEnv c . sendMessagesB' c ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () ackMessage c = withAgentEnv c .:. ackMessage' c @@ -911,29 +911,29 @@ getNotificationMessage' c nonce encNtfInfo = do -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> CR.PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, CR.PQEncryption) -sendMessage' c connId pqEnc msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c pqEnc (Identity (Right (connId, msgFlags, msg))) +sendMessage' c connId pqEnc msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg))) -- | Send multiple messages to different connections (SEND command) in Reader monad -sendMessages' :: forall m. AgentMonad' m => AgentClient -> CR.PQEncryption -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] -sendMessages' c pqEnc = sendMessagesB' c pqEnc . map Right +sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] +sendMessages' c = sendMessagesB' c . map Right -sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> CR.PQEncryption -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) -sendMessagesB' c pqEnc reqs = withConnLocks c connIds "sendMessages" $ do - reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) +sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do + reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) let reqs'' = fmap (>>= prepareConn) reqs' - enqueueMessagesB c (Just pqEnc) reqs'' + enqueueMessagesB c reqs'' where - prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) - prepareConn ((_, msgFlags, msg), SomeConn _ conn) = case conn of + prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) + prepareConn ((_, pqEnc, msgFlags, msg), SomeConn _ conn) = case conn of DuplexConnection cData _ sqs -> prepareMsg cData sqs SndConnection cData sq -> prepareMsg cData [sq] _ -> Left $ CONN SIMPLEX where - prepareMsg :: ConnData -> NonEmpty SndQueue -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) + prepareMsg :: ConnData -> NonEmpty SndQueue -> Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) prepareMsg cData sqs | ratchetSyncSendProhibited cData = Left $ CMD PROHIBITED - | otherwise = Right (cData, sqs, msgFlags, A_MSG msg) - connIds = map (\(connId, _, _) -> connId) $ rights $ toList reqs + | otherwise = Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg) + connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs -- / async command processing v v v @@ -1089,11 +1089,11 @@ enqueueMessages c cData sqs pqEnc_ msgFlags aMessage = do enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) enqueueMessages' c cData sqs pqEnc_ msgFlags aMessage = - liftEither . runIdentity =<< enqueueMessagesB c pqEnc_ (Identity (Right (cData, sqs, msgFlags, aMessage))) + liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, pqEnc_, msgFlags, aMessage))) -enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> Maybe CR.PQEncryption -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) -enqueueMessagesB c pqEnc_ reqs = do - reqs' <- enqueueMessageB c pqEnc_ reqs +enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +enqueueMessagesB c reqs = do + reqs' <- enqueueMessageB c reqs enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs' pure $ fst <$$> reqs' @@ -1102,20 +1102,20 @@ isActiveSndQ SndQueue {status} = status == Secured || status == Active enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) enqueueMessage c cData sq pqEnc_ msgFlags aMessage = - liftEither . fmap fst . runIdentity =<< enqueueMessageB c pqEnc_ (Identity (Right (cData, [sq], msgFlags, aMessage))) + liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], pqEnc_, msgFlags, aMessage))) -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> Maybe CR.PQEncryption -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, CR.PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) -enqueueMessageB c pqEnc_ reqs = do +enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, CR.PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) +enqueueMessageB c reqs = do aVRange <- asks $ maxVersion . smpAgentVRange . config reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db aVRange) reqs - forME reqMids $ \((cData, sq :| sqs, _, _), InternalId msgId, pqSecr) -> do + forME reqMids $ \((cData, sq :| sqs, _, _, _), InternalId msgId, pqSecr) -> do submitPendingMsg c cData sq let sqs' = filter isActiveSndQ sqs pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> VersionSMPA -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId, CR.PQEncryption)) - storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + storeSentMsg :: DB.Connection -> VersionSMPA -> (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage), InternalId, CR.PQEncryption)) + storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash From b050cf502794d011e52f8e37972f134294f3a5fc Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Tue, 5 Mar 2024 17:07:15 +0000 Subject: [PATCH 08/30] double ratchet versioning for post-quantum encryption (#1025) * correctly parse new Ratchet fields when omitted * rfc: migrating connection versions to pqdr * update rfc * WIP (dont commit) * rename versions * update ratchet version based on PQ encryption feature flag * remove duplicate function * synchronize ratchet, fix tests, refactor * comments * test * pattern --- rfcs/2024-03-03-pqdr-version.md | 92 ++++++++ src/Simplex/Messaging/Agent.hs | 108 +++++---- src/Simplex/Messaging/Agent/Env/SQLite.hs | 9 +- src/Simplex/Messaging/Agent/Protocol.hs | 42 +++- src/Simplex/Messaging/Agent/Store/SQLite.hs | 11 +- src/Simplex/Messaging/Crypto/Ratchet.hs | 238 ++++++++++++-------- tests/AgentTests.hs | 38 ++-- tests/AgentTests/ConnectionRequestTests.hs | 8 +- tests/AgentTests/DoubleRatchetTests.hs | 76 ++++--- tests/AgentTests/FunctionalAPITests.hs | 228 ++++++++++--------- 10 files changed, 514 insertions(+), 336 deletions(-) create mode 100644 rfcs/2024-03-03-pqdr-version.md diff --git a/rfcs/2024-03-03-pqdr-version.md b/rfcs/2024-03-03-pqdr-version.md new file mode 100644 index 000000000..5db9f23a5 --- /dev/null +++ b/rfcs/2024-03-03-pqdr-version.md @@ -0,0 +1,92 @@ +# Migrating existing connections to post-quantum double ratchet algorithm + +## Problem + +Post-quantum variant of double ratchet algorithm represents an almost full-stack change affecting all parts of the protocol stack except client-server protocol (SMP): +- double-ratchet end-to-end encryption: different encoding (additional large keys require byte-strings larger than 255 bytes with 2-byte length prefixes) and larger message headers (increased by ~2200 bytes). +- agent-agent protocol: a smaller maximum message size to accomodate larger headers and to fit in 16kb blocks, reduced by ~2200 bytes for the messages and by almost ~4000 bytes for connection information. +- chat protocol: also a smaller message size compensated by zstd comression of JSON messages. + +We want the versioning that achieves these objectives: +- all changes in all protocol layers happen at the same time, when both clients support it. +- ability to downgrade the clients to the previous version without losing connection. +- ability to opt-in into this functionality via "experimental" feature toggle, that enables post-quantum encryption in connections when both contacts enable this toggle. + +To have ability to downgrade the clients we have two options: +- roll-out this functionality in two stages: 1) roll-out clients support but do not enable the new version, and then 2) upgrade client version. The problem here is that the clients won't be able to opt-in into this experiment. +- make offered range dependent on experimental feature being enabled. Currently we have an option to enable PQ encryption in agent API, and this option can be used as a proxy to maxium supported protocol version - if the option is passed, it can be seen as an indication that higher version range (or version) should offered (or accepted). + +## Solution + +Currently ratchet state stores version range. It's unclear what was the intended semantics of that version range - it simply stores the offered/supported version range at the time ratchet was initialised, but only a high bound is used to send in message headers, and it is never upgraded. In JSON this range is encoded as tuple (an array of two elements in JSON). + +We could continue using this range with the meaning of the lower bound to be "currently used ratchet version" and the meaning of higher boundary to be "maximum supported ratchet version". We could also use the version communicated in message headers to upgrade ratchet version, with the condition that upgrade should only happen if both sides want it. Currently it's defined by pqEnableKEM property in ratchet state. We could also make it more explicit by defining maximum version to which ratchet should upgrade. Given that irreversible upgrades are not very common, it is probably ok to keep it implicit. + +We can define a better type than VersionRange to reflect semantics of the range in ratchet (current/max supported range), but for backward compatibility it needs to be encoded in the same way as now. + +To summarize, the proposed solution for ratchet versioning is: +- define ratchet versions as new type to include current and maximum allowed versions, where maximum allowed will be either the same or lower than maximum supported based on PQ option (in 5.6), and in 5.7 it will be changed to maximum supported, so version starts upgrading independently from PQ being enabled. +- make encodings in ratchet depend on current version (in curent code it depends on max version). +- include max allowed in message header. +- upgrade current if in range on each new message if less than max and higher than current (same as we do for connections). +- increase max allowed once PQ is enabled (only in 5.6). Make max allowed the same as max supported (global constant). + +```haskell +data RatchetVR = RatchetVR + { currentVersion :: Version, + maxAllowedVersion :: Version + } + +instance ToJSON RatchetVR where + toEncoding (RatchetVR v1 v2) = toEncoding (v1, v2) + toJSON (RatchetVR v1 v2) = toJSON (v1, v2) + +instance FromJSON RatchetVR where + parseJSON v = do + -- this also verifies that v2 > v1 (although we could remove JSON instances for VersionRange) + VersionRange v1 v2 <- parseJSON v + pure $ RatchetVR v1 v2 +``` + +For connections, we could also make version used for the purposes of encoding dependent on the PQ being enabled, and version for decoding taken from message header, but then we'd have to not only upgrade ratchets but the connection as well every time PQ mode changes. + +Another suggestion to ensure that correct version range is used in correct contexts could be: +- using different newtypes for different version ranges. +- define generic type class for version aware encoding that would also accept only specific type class for the version to use the correct range. This may be justified as there will be several version-aware encodings, and not just the protocol as now. + +```haskell +class Ord v => EncodingV v a where + {-# MINIMAL smpEncodeV, (smpDecodeV | smpVP) #-} + smpEncodeV :: v -> a -> ByteString + -- default decode uses parser + smpDecodeV :: v -> ByteString -> Either String a + smpDecodeV = parseAll . smpVP + -- default parser decodes from length-specified bytestring + smpVP :: v -> Parser a + smpVP v = smpDecodeV v <$?> smpP +``` + +The version will be passed from currently agreed version, it may only change when message is received, not when message is sent. The version will not be extracted from the encoding itself as it happens now in ratchet encodings. + +## Various options how the problem can be simplified + +1. Do not support connection downgrade once both devices upgraded. If applied to all existing connections then it is a bad option, as it would disrupt some important conversations. + +2. Do not provide ability to opt-in into PQ encryption until v5.7 where it will be rolled out automatically. That is also suboptimal, as it won't allow announcing technology design and have testing outside of the team devices. + +3. The logic explained above where connection upgrade and downgrade is possible and applied to all existing connections if both parties consent to it. There are these important downsides: + - complexity of this logic + - regression risks when this logic is removed. + - some non-coordinated upgrades of existing, potentially important conversations, simply because two users opt-in into the experiment without any expectation that another side also opts-in. + +4. Apply upgrade/downgrade logic and enable PQ encryption as opt-in, based on the toggle in the UX, only for the new connections. This seems the least risky, and also simpler than option 3, as it would only apply to the new connections, and both users will have to enable experimental toggle prior to connecting. + +Option 4 seems the best trade-off, and has these sub-options regarding where it is controlled: +a) in chat based on connection flag. Chat will pass PQ options only to connections that were created when experimental option was enabled. +b) in agent - there will be additional logic to ignore PQ option for existing connections. +c) both in chat and in agent. + +Option 4a seems better, as it would: +- simplify agent code +- minimise required changes when releasing v5.7 (as we do want that all direct and small groups connections migrate to PQ encryption at the time, without any toggles) +- allow tests for connection upgrade in the currect code. diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 74d4e1c40..6f3099ad6 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -558,14 +558,14 @@ newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> CR.PQEncryption -> m ConnId newConnNoQueues c userId connId enableNtfs cMode pqEncryption = do g <- asks random - connAgentVersion <- asks $ maxVersion . smpAgentVRange . config + connAgentVersion <- asks $ maxVersion . ($ pqEncryption) . smpAgentVRange . config let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} withStore c $ \db -> createNewConn db g cData cMode joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo pqEncryption subMode = do withInvLock c (strEncode cReqUri) "joinConnAsync" $ do - aVRange <- asks $ smpAgentVRange . config + aVRange <- asks $ ($ pqEncryption) . smpAgentVRange . config case crAgentVRange `compatibleVersion` aVRange of Just (Compatible connAgentVersion) -> do g <- asks random @@ -667,14 +667,16 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv when enableNtfs $ do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCCreate) - let crData = ConnReqUriData SSSimplex smpAgentVRange [qUri] clientData + let pqEnc = CR.connPQEncryption pqInitKeys + crData = ConnReqUriData SSSimplex (smpAgentVRange pqEnc) [qUri] clientData + e2eVRange = e2eEncryptVRange pqEnc case cMode of SCMContact -> pure (connId, CRContactUri crData) SCMInvitation -> do g <- asks random - (pk1, pk2, pKem, e2eRcvParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eEncryptVRange) (CR.initialPQEncryption pqInitKeys) + (pk1, pk2, pKem, e2eRcvParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eVRange) (CR.initialPQEncryption pqInitKeys) withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem - pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange) + pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eVRange) joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do @@ -687,17 +689,18 @@ joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> CR.PQEncryption -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqEncryption = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config + let e2eVRange = e2eEncryptVRange pqEncryption case ( qUri `compatibleVersion` smpClientVRange, - e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange, - crAgentVRange `compatibleVersion` smpAgentVRange + e2eRcvParamsUri `compatibleVersion` e2eVRange, + crAgentVRange `compatibleVersion` smpAgentVRange pqEncryption ) of (Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), Just aVersion@(Compatible connAgentVersion)) -> do g <- asks random (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ pqEncryption kem_) (_, rcDHRs) <- atomically $ C.generateKeyPair g - -- TODO PQ generate KEM keypair if needed - is it done? rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams - let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs rcParams + let rcVs = CR.RVersions {current = v, maxSupported = maxVersion e2eVRange} + rc = CR.initSndRatchet rcVs rcDHRr rcDHRs rcParams q <- newSndQueue userId "" qInfo let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} pure (aVersion, cData, q, rc, e2eSndParams) @@ -720,7 +723,7 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMod void $ withStore' c $ \db -> deleteConn db Nothing connId' throwError e joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo pqEnc subMode srv = do - aVRange <- asks $ smpAgentVRange . config + aVRange <- asks $ ($ pqEnc) . smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config case ( qUri `compatibleVersion` clientVRange, crAgentVRange `compatibleVersion` aVRange @@ -1107,23 +1110,24 @@ enqueueMessage c cData sq pqEnc_ msgFlags aMessage = -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, CR.PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) enqueueMessageB c reqs = do - aVRange <- asks $ maxVersion . smpAgentVRange . config - reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db aVRange) reqs + getAVRange <- asks $ smpAgentVRange . config + reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db getAVRange) reqs forME reqMids $ \((cData, sq :| sqs, _, _, _), InternalId msgId, pqSecr) -> do submitPendingMsg c cData sq let sqs' = filter isActiveSndQ sqs pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> VersionSMPA -> (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage), InternalId, CR.PQEncryption)) - storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + storeSentMsg :: DB.Connection -> (CR.PQEncryption -> VersionRangeSMPA) -> (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage), InternalId, CR.PQEncryption)) + storeSentMsg db getAVRange req@(ConnData {connId, connAgentVersion = v}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash agentMsg = AgentMessage privHeader aMessage agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - (encAgentMessage, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr e2eEncUserMsgLength pqEnc_ - let msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} + (encAgentMessage, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr (e2eEncUserMsgLength v) pqEnc_ + let agentVersion = maxVersion . getAVRange $ fromMaybe CR.PQEncOff pqEnc_ + msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} msgType = agentMessageType agentMsg msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, pqEncryption, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData @@ -1402,18 +1406,19 @@ abortConnectionSwitch' c connId = synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> CR.PQEncryption -> Bool -> m ConnectionStats synchronizeRatchet' c connId pqEnc force = withConnLock c connId "synchronizeRatchet" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection cData rqs sqs) + SomeConn _ (DuplexConnection cData@ConnData {pqEncryption} rqs sqs) | ratchetSyncAllowed cData || force -> do -- check queues are not switching? + cData' <- if pqEncryption == pqEnc then pure cData else withStore' c $ \db -> setConnPQEncryption db cData pqEnc AgentConfig {e2eEncryptVRange} <- asks config g <- asks random - (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eEncryptVRange) pqEnc - enqueueRatchetKeyMsgs c cData sqs e2eParams + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion $ e2eEncryptVRange pqEnc) pqEnc + enqueueRatchetKeyMsgs c cData' sqs e2eParams withStore' c $ \db -> do setConnRatchetSync db connId RSStarted setRatchetX3dhKeys db connId pk1 pk2 pKem - let cData' = cData {ratchetSyncState = RSStarted} :: ConnData - conn' = DuplexConnection cData' rqs sqs + let cData'' = cData' {ratchetSyncState = RSStarted} :: ConnData + conn' = DuplexConnection cData'' rqs sqs pure $ connectionStats conn' | otherwise -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED @@ -2064,8 +2069,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> prohibited >> ack _ -> prohibited >> ack updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> m (Connection c) - updateConnVersion conn' cData' msgAgentVersion = do - aVRange <- asks $ smpAgentVRange . config + updateConnVersion conn' cData'@ConnData {pqEncryption} msgAgentVersion = do + aVRange <- asks $ ($ pqEncryption) . smpAgentVRange . config let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion case msgAVRange `compatibleVersion` aVRange of Just (Compatible av) @@ -2126,21 +2131,27 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, parseMessage :: Encoding a => ByteString -> m a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) + -- TODO PQ make sure pqEncryption in conn' is set correctly smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> m () smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config + let ConnData {pqEncryption} = toConnData conn' + aVRange = smpAgentVRange pqEncryption + e2eVRange = e2eEncryptVRange pqEncryption unless - (agentVersion `isCompatible` smpAgentVRange && smpClientVersion `isCompatible` smpClientVRange) + (agentVersion `isCompatible` aVRange && smpClientVersion `isCompatible` smpClientVRange) (throwError $ AGENT A_VERSION) case status of New -> case (conn', e2eEncryption) of -- party initiating connection - (RcvConnection ConnData {pqEncryption} _, Just (CR.AE2ERatchetParams _ e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _ _))) -> do - unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION) + (RcvConnection _ _, Just (CR.AE2ERatchetParams _ e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _ _))) -> do + unless (e2eVersion `isCompatible` e2eVRange) (throwError $ AGENT A_VERSION) (pk1, rcDHRs, pKem) <- withStore c (`getRatchetX3dhKeys` connId) rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams - let rc = CR.initRcvRatchet e2eEncryptVRange rcDHRs rcParams pqEncryption + -- TODO PQ combine isCompatible check and construction in one call + let rcVs = CR.RVersions {current = e2eVersion, maxSupported = maxVersion e2eVRange} + rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqEncryption g <- asks random (agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt g rc M.empty encConnInfo case (agentMsgBody_, skipped) of @@ -2153,7 +2164,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, processConf connInfo senderConf = do let newConfirmation = NewConfirmation {connId, senderConf, ratchetState = rc'} confId <- withStore c $ \db -> do - setConnectionVersion db connId agentVersion + setConnAgentVersion db connId agentVersion createConfirmation db g newConfirmation let srvs = map qServer $ smpReplyQueues senderConf notify $ CONF confId srvs connInfo @@ -2182,8 +2193,6 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- `sndStatus == Active` when HELLO was previously sent, and this is the reply HELLO -- this branch is executed by the accepting party in duplexHandshake mode (v2) -- (was executed by initiating party in v1 that is no longer supported) - -- - -- TODO PQ encryption mode | sndStatus == Active -> notify $ CON pqEncryption | otherwise -> enqueueDuplexHello sq _ -> pure () @@ -2328,13 +2337,17 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, DuplexConnection {} -> action conn' _ -> qError $ name <> ": message must be sent to duplex connection" + -- TODO PQ make sure pqEncryption is set correctly here newRatchetKey :: CR.RcvE2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () - newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv kem_) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = + newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv _) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId, pqEncryption} _ sqs) = unlessM ratchetExists $ do AgentConfig {e2eEncryptVRange} <- asks config - unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION) + let connE2EVRange = e2eEncryptVRange pqEncryption + unless (e2eVersion `isCompatible` connE2EVRange) (throwError $ AGENT A_VERSION) keys <- getSendRatchetKeys - initRatchet e2eEncryptVRange keys + -- TODO PQ combine with `isCompatible` check above + let rcVs = CR.RVersions {current = e2eVersion, maxSupported = maxVersion connE2EVRange} + initRatchet rcVs keys notifyAgreed where rkHashRcv = rkHash k1Rcv k2Rcv @@ -2360,8 +2373,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, where sendReplyKey = do g <- asks random - -- TODO PQ the decision to use KEM should depend on connection - (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g e2eVersion CR.PQEncOn + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g e2eVersion pqEncryption enqueueRatchetKeyMsgs c cData' sqs e2eParams pure (pk1, pk2, pKem) notifyRatchetSyncError = do @@ -2380,16 +2392,15 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, createRatchet db connId rc -- compare public keys `k1` in AgentRatchetKey messages sent by self and other party -- to determine ratchet initilization ordering - initRatchet :: CR.VersionRangeE2E -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m () - initRatchet e2eEncryptVRange (pk1, pk2, pKem) + initRatchet :: CR.RatchetVersions -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m () + initRatchet rcVs (pk1, pk2, pKem) | rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 pk2 pKem e2eOtherPartyParams - -- TODO PQ the decision to use KEM should either depend on the global setting or on whether it was enabled in connection before - recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 rcParams $ CR.PQEncryption (isJust kem_) + recreateRatchet $ CR.initRcvRatchet rcVs pk2 rcParams pqEncryption | otherwise = do (_, rcDHRs) <- atomically . C.generateKeyPair =<< asks random rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 (CR.APRKP CR.SRKSProposed <$> pKem) e2eOtherPartyParams - recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs rcParams + recreateRatchet $ CR.initSndRatchet rcVs k2Rcv rcDHRs rcParams void . enqueueMessages' c cData' sqs Nothing SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity @@ -2432,7 +2443,7 @@ confirmQueueAsync c cData sq srv connInfo e2eEncryption_ pqEnc subMode = do submitPendingMsg c cData sq confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> SubscriptionMode -> m () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ pqEnc_ subMode = do +confirmQueue (Compatible agentVersion) c cData@ConnData {connId, connAgentVersion = v} sq srv connInfo e2eEncryption_ pqEnc_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed @@ -2440,7 +2451,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo mkConfirmation :: AgentMessage -> m MsgBody mkConfirmation aMessage = withStore c $ \db -> runExceptT $ do void . liftIO $ updateSndIds db connId - (encConnInfo, _) <- agentRatchetEncrypt db connId (smpEncode aMessage) e2eEncConnInfoLength pqEnc_ + (encConnInfo, _) <- agentRatchetEncrypt db connId (smpEncode aMessage) (e2eEncConnInfoLength v) pqEnc_ pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} mkAgentConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> m AgentMessage @@ -2454,13 +2465,13 @@ enqueueConfirmation c cData sq connInfo e2eEncryption_ pqEnc_ = do submitPendingMsg c cData sq storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> AgentMessage -> m () -storeConfirmation c ConnData {connId, connAgentVersion} sq e2eEncryption_ pqEnc_ agentMsg = withStore c $ \db -> runExceptT $ do +storeConfirmation c ConnData {connId, connAgentVersion = v} sq e2eEncryption_ pqEnc_ agentMsg = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - (encConnInfo, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr e2eEncConnInfoLength pqEnc_ - let msgBody = smpEncode $ AgentConfirmation {agentVersion = connAgentVersion, e2eEncryption_, encConnInfo} + (encConnInfo, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr (e2eEncConnInfoLength v) pqEnc_ + let msgBody = smpEncode $ AgentConfirmation {agentVersion = v, e2eEncryption_, encConnInfo} msgType = agentMessageType agentMsg msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData @@ -2472,8 +2483,8 @@ enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do mapM_ (enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m AgentMsgId -enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do - aVRange <- asks $ smpAgentVRange . config +enqueueRatchetKey c cData@ConnData {connId, pqEncryption} sq e2eEncryption = do + aVRange <- asks $ ($ pqEncryption) . smpAgentVRange . config msgId <- storeRatchetKey $ maxVersion aVRange submitPendingMsg c cData sq pure $ unId msgId @@ -2487,8 +2498,7 @@ enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do internalHash = C.sha256Hash agentMsgStr let msgBody = smpEncode $ AgentRatchetKey {agentVersion, e2eEncryption, info = agentMsgStr} msgType = agentMessageType agentMsg - -- TODO PQ set pqEncryption based on connection mode - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption = CR.PQEncOff, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId pure internalId diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 7a879bb22..a292e5db6 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -56,15 +56,14 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRange) +import Simplex.Messaging.Crypto.Ratchet (PQEncryption, VersionRangeE2E, supportedE2EEncryptVRange) import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig) import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) -import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (TLS, Transport (..)) +import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..)) import Simplex.Messaging.Transport.Client (defaultSMPPort) import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors) import System.Random (StdGen, newStdGen) @@ -117,8 +116,8 @@ data AgentConfig = AgentConfig caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, - e2eEncryptVRange :: VersionRangeE2E, - smpAgentVRange :: VersionRangeSMPA, + e2eEncryptVRange :: PQEncryption -> VersionRangeE2E, + smpAgentVRange :: PQEncryption -> VersionRangeSMPA, smpClientVRange :: VersionRangeSMPC } diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index b4383cf9b..af271531b 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -39,6 +39,7 @@ module Simplex.Messaging.Agent.Protocol pattern VersionSMPA, ratchetSyncSMPAgentVersion, deliveryRcptsSMPAgentVersion, + pqdrSMPAgentVersion, supportedSMPAgentVRange, e2eEncConnInfoLength, e2eEncUserMsgLength, @@ -187,7 +188,15 @@ import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.Messaging.Agent.QueryString import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOff, RcvE2ERatchetParams, RcvE2ERatchetParamsUri, SndE2ERatchetParams) +import Simplex.Messaging.Crypto.Ratchet + ( InitialKeys (..), + PQEncryption (..), + pattern PQEncOff, + pattern PQEncOn, + RcvE2ERatchetParams, + RcvE2ERatchetParamsUri, + SndE2ERatchetParams + ) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers @@ -234,6 +243,7 @@ import UnliftIO.Exception (Exception) -- 2 - "duplex" (more efficient) connection handshake (6/9/2022) -- 3 - support ratchet renegotiation (6/30/2023) -- 4 - delivery receipts (7/13/2023) +-- 5 - post-quantum double ratchet (3/14/2024) data SMPAgentVersion @@ -255,24 +265,34 @@ ratchetSyncSMPAgentVersion = VersionSMPA 3 deliveryRcptsSMPAgentVersion :: VersionSMPA deliveryRcptsSMPAgentVersion = VersionSMPA 4 +pqdrSMPAgentVersion :: VersionSMPA +pqdrSMPAgentVersion = VersionSMPA 5 + +-- TODO v5.7 increase to 5 currentSMPAgentVersion :: VersionSMPA currentSMPAgentVersion = VersionSMPA 4 -supportedSMPAgentVRange :: VersionRangeSMPA -supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentSMPAgentVersion +-- TODO v5.7 remove dependency of version range on whether PQ encryption is used +supportedSMPAgentVRange :: PQEncryption -> VersionRangeSMPA +supportedSMPAgentVRange pq = + mkVersionRange duplexHandshakeSMPAgentVersion $ case pq of + PQEncOn -> pqdrSMPAgentVersion + PQEncOff -> currentSMPAgentVersion -- it is shorter to allow all handshake headers, -- including E2E (double-ratchet) parameters and -- signing key of the sender for the server --- TODO PQ this should be version-dependent --- previously it was 14848, reduced by 3700 (roughly the increase of message ratchet header size + key and ciphertext in reply link) -e2eEncConnInfoLength :: Int -e2eEncConnInfoLength = 11148 +e2eEncConnInfoLength :: VersionSMPA -> Int +e2eEncConnInfoLength v + -- reduced by 3700 (roughly the increase of message ratchet header size + key and ciphertext in reply link) + | v >= pqdrSMPAgentVersion = 11148 + | otherwise = 14848 --- TODO PQ this should be version-dependent --- previously it was 15856, reduced by 2200 (roughly the increase of message ratchet header size) -e2eEncUserMsgLength :: Int -e2eEncUserMsgLength = 13656 +e2eEncUserMsgLength :: VersionSMPA -> Int +e2eEncUserMsgLength v + -- reduced by 2200 (roughly the increase of message ratchet header size) + | v >= pqdrSMPAgentVersion = 13656 + | otherwise = 15856 -- | Raw (unparsed) SMP agent protocol transmission. type ARawTransmission = (ByteString, ByteString, ByteString) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index a2d01c201..78660ce6c 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -58,6 +58,7 @@ module Simplex.Messaging.Agent.Store.SQLite getConnData, setConnDeleted, setConnAgentVersion, + setConnPQEncryption, getDeletedConnIds, getDeletedWaitingDeliveryConnIds, setConnRatchetSync, @@ -93,7 +94,6 @@ module Simplex.Messaging.Agent.Store.SQLite getAcceptedConfirmation, removeConfirmations, -- Invitations - sent via Contact connections - setConnectionVersion, createInvitation, getInvitation, acceptInvitation, @@ -889,10 +889,6 @@ removeConfirmations db connId = |] [":conn_id" := connId] -setConnectionVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () -setConnectionVersion db connId aVersion = - DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) - createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId) createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} = createWithRandomId gVar $ \invitationId -> @@ -1956,6 +1952,11 @@ setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnAgentVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) +setConnPQEncryption :: DB.Connection -> ConnData -> CR.PQEncryption -> IO ConnData +setConnPQEncryption db cData@ConnData {connId} pqEnc = do + DB.execute db "UPDATE connections SET pq_encryption = ? WHERE conn_id = ?" (pqEnc, connId) + pure (cData :: ConnData) {pqEncryption = pqEnc} + getDeletedConnIds :: DB.Connection -> IO [ConnId] getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True) diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index a6faf49c7..9d4b919ca 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -23,6 +23,8 @@ module Simplex.Messaging.Crypto.Ratchet SkippedMsgDiff (..), SkippedMsgKeys, InitialKeys (..), + pattern IKPQOn, + pattern IKPQOff, PQEncryption (..), pattern PQEncOn, pattern PQEncOff, @@ -40,8 +42,9 @@ module Simplex.Messaging.Crypto.Ratchet VersionE2E, VersionRangeE2E, pattern VersionE2E, + RatchetVersions (..), kdfX3DHE2EEncryptVersion, - pqRatchetVersion, + pqRatchetE2EEncryptVersion, currentE2EEncryptVersion, supportedE2EEncryptVRange, generateRcvE2EParams, @@ -58,7 +61,6 @@ module Simplex.Messaging.Crypto.Ratchet rcDecrypt, -- used in tests MsgHeader (..), - RatchetVersions (..), RatchetInitParams (..), UseKEM (..), RKEMParams (..), @@ -71,6 +73,8 @@ module Simplex.Messaging.Crypto.Ratchet ratchetVersions, fullHeaderLen, applySMDiff, + encodeMsgHeader, + msgHeaderP, ) where @@ -85,7 +89,7 @@ import Crypto.Random (ChaChaDRG) import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson as J import qualified Data.Aeson.TH as JQ -import Data.Attoparsec.ByteString (Parser) +import Data.Attoparsec.ByteString (Parser, peekWord8') import qualified Data.Attoparsec.ByteString.Char8 as A import qualified Data.ByteArray as BA import Data.ByteString.Char8 (ByteString) @@ -131,14 +135,19 @@ pattern VersionE2E v = Version v kdfX3DHE2EEncryptVersion :: VersionE2E kdfX3DHE2EEncryptVersion = VersionE2E 2 -pqRatchetVersion :: VersionE2E -pqRatchetVersion = VersionE2E 3 +pqRatchetE2EEncryptVersion :: VersionE2E +pqRatchetE2EEncryptVersion = VersionE2E 3 +-- TODO v5.7 increase to 3 currentE2EEncryptVersion :: VersionE2E -currentE2EEncryptVersion = VersionE2E 3 +currentE2EEncryptVersion = VersionE2E 2 -supportedE2EEncryptVRange :: VersionRangeE2E -supportedE2EEncryptVRange = mkVersionRange kdfX3DHE2EEncryptVersion currentE2EEncryptVersion +-- TODO v5.7 remove dependency of version range on whether PQ encryption is used +supportedE2EEncryptVRange :: PQEncryption -> VersionRangeE2E +supportedE2EEncryptVRange pq = + mkVersionRange kdfX3DHE2EEncryptVersion $ case pq of + PQEncOn -> pqRatchetE2EEncryptVersion + PQEncOff -> currentE2EEncryptVersion data RatchetKEMState = RKSProposed -- only KEM encapsulation key @@ -215,7 +224,7 @@ deriving instance Show AnyE2ERatchetParams instance (RatchetKEMStateI s, AlgorithmI a) => Encoding (E2ERatchetParams s a) where smpEncode (E2ERatchetParams v k1 k2 kem_) - | v >= pqRatchetVersion = smpEncode (v, k1, k2, kem_) + | v >= pqRatchetE2EEncryptVersion = smpEncode (v, k1, k2, kem_) | otherwise = smpEncode (v, k1, k2) smpP = toParams <$?> smpP where @@ -243,7 +252,7 @@ instance Encoding AnyE2ERatchetParams where where kemP :: VersionE2E -> Parser (Maybe ARKEMParams) kemP v - | v >= pqRatchetVersion = smpP + | v >= pqRatchetE2EEncryptVersion = smpP | otherwise = pure Nothing instance VersionI E2EVersion (E2ERatchetParams s a) where @@ -283,7 +292,7 @@ instance (RatchetKEMStateI s, AlgorithmI a) => StrEncoding (E2ERatchetParamsUri <> maybe [] encodeKem kem_ where encodeKem kem - | maxVersion vs < pqRatchetVersion = [] + | maxVersion vs < pqRatchetE2EEncryptVersion = [] | otherwise = case kem of RKParamsProposed k -> [("kem_key", strEncode k)] RKParamsAccepted ct k -> [("kem_ct", strEncode ct), ("kem_key", strEncode k)] @@ -313,7 +322,7 @@ instance StrEncoding AnyE2ERatchetParamsUri where _ -> fail "bad e2e params" where kemP vr query - | maxVersion vr >= pqRatchetVersion = + | maxVersion vr >= pqRatchetE2EEncryptVersion = queryParam_ "kem_key" query $>>= \k -> Just . kemParams k <$> queryParam_ "kem_ct" query | otherwise = pure Nothing @@ -366,7 +375,7 @@ generateE2EParams g v useKEM_ = do where kemParams :: IO (Maybe (RKEMParams s, PrivRKEMParams s)) kemParams = case useKEM_ of - Just useKem | v >= pqRatchetVersion -> Just <$> do + Just useKem | v >= pqRatchetE2EEncryptVersion -> Just <$> do ks@(k, _) <- sntrup761Keypair g case useKem of ProposeKEM -> pure (RKParamsProposed k, PrivateRKParamsProposed ks) @@ -414,7 +423,7 @@ pqX3dhSnd spk1 spk2 spKem_ (E2ERatchetParams v rk1 rk2 rKem_) = do where sndPq :: Either CryptoError (Maybe KEMKeyPair, Maybe RatchetKEMAccepted) sndPq = case spKem_ of - Just (APRKP _ ps) | v >= pqRatchetVersion -> case (ps, rKem_) of + Just (APRKP _ ps) | v >= pqRatchetE2EEncryptVersion -> case (ps, rKem_) of (PrivateRKParamsAccepted ct shared ks, Just (RKParamsProposed k)) -> Right (Just ks, Just $ RatchetKEMAccepted k shared ct) (PrivateRKParamsProposed ks, _) -> Right (Just ks, Nothing) -- both parties can send "proposal" in case of ratchet renegotiation _ -> Left CERatchetKEMState @@ -430,7 +439,7 @@ pqX3dhRcv rpk1 rpk2 rpKem_ (E2ERatchetParams v sk1 sk2 sKem_) = do where rcvPq :: ExceptT CryptoError IO (Maybe (KEMKeyPair, RatchetKEMAccepted)) rcvPq = case sKem_ of - Just (RKParamsAccepted ct k') | v >= pqRatchetVersion -> case rpKem_ of + Just (RKParamsAccepted ct k') | v >= pqRatchetE2EEncryptVersion -> case rpKem_ of Just (PrivateRKParamsProposed ks@(_, pk)) -> do shared <- liftIO $ sntrup761Dec ct pk pure $ Just (ks, RatchetKEMAccepted k' shared ct) @@ -457,7 +466,6 @@ data Ratchet a = Ratchet rcAD :: Str, rcDHRs :: PrivateKey a, rcKEM :: Maybe RatchetKEM, - -- TODO PQ make them optional via JSON parser for PQEncryption rcEnableKEM :: PQEncryption, -- will enable KEM on the next ratchet step rcSndKEM :: PQEncryption, -- used KEM hybrid secret for sending ratchet rcRcvKEM :: PQEncryption, -- used KEM hybrid secret for receiving ratchet @@ -584,12 +592,12 @@ instance FromField MessageKey where fromField = blobFieldDecoder smpDecode -- // above added for KEM -- @ initSndRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a -initSndRatchet v rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do + forall a. (AlgorithmI a, DhAlgorithm a) => RatchetVersions -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a +initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted) in Ratchet - { rcVersion = ratchetVersions v, + { rcVersion, rcAD = assocData, rcDHRs, rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, @@ -613,10 +621,10 @@ initSndRatchet v rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, -- Please note that the public part of rcDHRs was sent to the sender -- as part of the connection request and random salt was received from the sender. initRcvRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a -initRcvRatchet v rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM = + forall a. (AlgorithmI a, DhAlgorithm a) => RatchetVersions -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a +initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM = Ratchet - { rcVersion = ratchetVersions v, + { rcVersion, rcAD = assocData, rcDHRs, -- rcKEM: @@ -654,56 +662,69 @@ data MsgHeader a = MsgHeader -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 -- TODO PQ this must be version-dependent -- TODO this is the exact size, some reserve should be added -paddedHeaderLen :: Int -paddedHeaderLen = 2284 +paddedHeaderLen :: VersionE2E -> Int +paddedHeaderLen v + | v >= pqRatchetE2EEncryptVersion = 2284 + | otherwise = 88 -- only used in tests to validate correct padding -- (2 bytes - version size, 1 byte - header size, not to have it fixed or version-dependent) -fullHeaderLen :: Int -fullHeaderLen = 2 + 1 + paddedHeaderLen + authTagSize + ivSize @AES256 +fullHeaderLen :: VersionE2E -> Int +fullHeaderLen v = 2 + 1 + paddedHeaderLen v + authTagSize + ivSize @AES256 -instance AlgorithmI a => Encoding (MsgHeader a) where - smpEncode MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} - | msgMaxVersion >= pqRatchetVersion = smpEncode (msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs) - | otherwise = smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) - smpP = do - msgMaxVersion <- smpP - msgDHRs <- smpP - msgKEM <- if msgMaxVersion >= pqRatchetVersion then smpP else pure Nothing - msgPN <- smpP - msgNs <- smpP - pure MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} +-- pass the current version, as MsgHeader only includes the max supported version that can be different from the current +encodeMsgHeader :: AlgorithmI a => VersionE2E -> MsgHeader a -> ByteString +encodeMsgHeader v MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} + | v >= pqRatchetE2EEncryptVersion = smpEncode (msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs) + | otherwise = smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) + +-- pass the current version, as MsgHeader only includes the max supported version that can be different from the current +msgHeaderP :: AlgorithmI a => VersionE2E -> Parser (MsgHeader a) +msgHeaderP v = do + msgMaxVersion <- smpP + msgDHRs <- smpP + msgKEM <- if v >= pqRatchetE2EEncryptVersion then smpP else pure Nothing + msgPN <- smpP + msgNs <- smpP + pure MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} data EncMessageHeader = EncMessageHeader - { ehVersion :: VersionE2E, + { ehVersion :: VersionE2E, -- this is current ratchet version ehIV :: IV, ehAuthTag :: AuthTag, ehBody :: ByteString } +-- this encoding depends on version in EncMessageHeader because it is "current" ratchet version instance Encoding EncMessageHeader where smpEncode EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} - | ehVersion >= pqRatchetVersion = smpEncode (ehVersion, ehIV, ehAuthTag, Large ehBody) + | ehVersion >= pqRatchetE2EEncryptVersion = smpEncode (ehVersion, ehIV, ehAuthTag, Large ehBody) | otherwise = smpEncode (ehVersion, ehIV, ehAuthTag, ehBody) smpP = do (ehVersion, ehIV, ehAuthTag) <- smpP - ehBody <- if ehVersion >= pqRatchetVersion then unLarge <$> smpP else smpP + ehBody <- if ehVersion >= pqRatchetE2EEncryptVersion then unLarge <$> smpP else smpP pure EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} +-- the header is length-prefixed to parse it as string and use as part of associated data for authenticated encryption data EncRatchetMessage = EncRatchetMessage { emHeader :: ByteString, emAuthTag :: AuthTag, emBody :: ByteString } +-- the encoder always uses 2-byte lengths for the new version, even for short headers without PQ keys. encodeEncRatchetMessage :: VersionE2E -> EncRatchetMessage -> ByteString encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} - | v >= pqRatchetVersion = smpEncode (Large emHeader, emAuthTag, Tail emBody) + | v >= pqRatchetE2EEncryptVersion = smpEncode (Large emHeader, emAuthTag, Tail emBody) | otherwise = smpEncode (emHeader, emAuthTag, Tail emBody) -encRatchetMessageP :: VersionE2E -> Parser EncRatchetMessage -encRatchetMessageP v = do - emHeader <- if v >= pqRatchetVersion then unLarge <$> smpP else smpP +-- This parser relies on the fact that header cannot be shorter than 32 bytes (it is ~69 bytes without PQ KEM), +-- therefore if the first byte is less or equal to 31 (x1F), then we have 2 byte-length limited to 8191. +-- This allows upgrading the current version in one message. +encRatchetMessageP :: Parser EncRatchetMessage +encRatchetMessageP = do + len1 <- peekWord8' + emHeader <- if len1 < 32 then unLarge <$> smpP else smpP (emAuthTag, Tail emBody) <- smpP pure EncRatchetMessage {emHeader, emBody, emAuthTag} @@ -724,6 +745,7 @@ instance ToJSON PQEncryption where instance FromJSON PQEncryption where parseJSON v = PQEncryption <$> parseJSON v + omittedField = Just PQEncOff replyKEM_ :: PQEncryption -> Maybe (RKEMParams 'RKSProposed) -> Maybe AUseKEM replyKEM_ pqEnc kem_ = case pqEnc of @@ -747,6 +769,12 @@ instance StrEncoding PQEncryption where data InitialKeys = IKUsePQ | IKNoPQ PQEncryption deriving (Eq, Show) +pattern IKPQOn :: InitialKeys +pattern IKPQOn = IKNoPQ PQEncOn + +pattern IKPQOff :: InitialKeys +pattern IKPQOff = IKNoPQ PQEncOff + instance StrEncoding InitialKeys where strEncode = \case IKUsePQ -> "pq=invitation" @@ -773,25 +801,30 @@ joinContactInitialKeys = \case rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> ExceptT CryptoError IO (ByteString, Ratchet a) rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ = throwE CERatchetState -rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcVersion} paddedMsgLen msg pqMode_ = do +rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcVersion} paddedMsgLen msg pqEnc_ = do -- state.CKs, mk = KDF_CK(state.CKs) let (ck', mk, iv, ehIV) = chainKdf rcCKs -- enc_header = HENCRYPT(state.HKs, header) - (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV paddedHeaderLen rcAD msgHeader + let v = current rcVersion + (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen v) rcAD (msgHeader v) -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) - -- TODO PQ versioning in Ratchet should change: we should use "current" version here - let emHeader = smpEncode EncMessageHeader {ehVersion = maxSupported rcVersion, ehBody, ehAuthTag, ehIV} + let emHeader = smpEncode EncMessageHeader {ehVersion = v, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg - let msg' = encodeEncRatchetMessage (maxSupported rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag} + let msg' = encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} -- state.Ns += 1 - rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1} - rc'' = case pqMode_ of + -- TODO v5.8 remove comments below + -- Note that maxSupported will not downgrade here below current. + -- TODO v5.7 remove comments below + -- It will downgrade when decrypting the message when the current version downgrades to remove support for PQ encryption. + -- TODO v5.8 replace `max v currentE2EEncryptVersion` with `v` (to allow downgrade when app downgraded) + rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1, rcVersion = rcVersion {maxSupported = max v currentE2EEncryptVersion}} + rc'' = case pqEnc_ of Nothing -> rc' - Just rcEnableKEM - | enablePQ rcEnableKEM -> rc' {rcEnableKEM} - | otherwise -> - let rcKEM' = (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM - in rc' {rcEnableKEM, rcKEM = rcKEM'} + -- This sets max version to support PQ encryption. + -- Current version upgrade happens when peer decrypts the message. + -- TODO v5.7 remove version upgrade here, as it's already upgraded above + Just PQEncOn -> rc' {rcEnableKEM = PQEncOn, rcVersion = rcVersion {maxSupported = pqRatchetE2EEncryptVersion}} + Just PQEncOff -> rc' {rcEnableKEM = PQEncOff, rcKEM = (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM} pure (msg', rc'') where -- header = HEADER_PQ2( @@ -801,8 +834,9 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, -- pn = state.PN, -- n = state.Ns -- ) - msgHeader = - smpEncode + msgHeader v = + encodeMsgHeader + v MsgHeader { msgMaxVersion = maxSupported rcVersion, msgDHRs = publicKey rcDHRs, @@ -837,7 +871,7 @@ rcDecrypt :: ExceptT CryptoError IO (DecryptResult a) rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do -- TODO PQ versioning should change - encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxSupported rcVersion) msg' + encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError encRatchetMessageP msg' encHdr <- parseE CryptoHeaderError smpP emHeader -- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD) decryptSkipped encHdr encMsg >>= \case @@ -851,9 +885,16 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do SMMessage r -> pure r where decryptRcMessage :: RatchetStep -> MsgHeader a -> EncRatchetMessage -> ExceptT CryptoError IO (DecryptResult a) - decryptRcMessage rcStep MsgHeader {msgDHRs, msgKEM, msgPN, msgNs} encMsg = do + decryptRcMessage rcStep hdr@MsgHeader {msgMaxVersion, msgPN, msgNs} encMsg = do -- if dh_ratchet: - (rc', smks1) <- ratchetStep rcStep + (rc', smks1) <- case rcStep of + SameRatchet -> pure (upgradedRatchet, M.empty) + AdvanceRatchet -> do + -- SkipMessageKeysHE(state, header.pn) + (rc', hmks) <- liftEither $ skipMessageKeys msgPN upgradedRatchet + -- DHRatchetPQ2HE(state, header) + (,hmks) <$> ratchetStep rc' hdr + -- SkipMessageKeysHE(state, header.n) case skipMessageKeys msgNs rc' of Left e -> pure (Left e, rc', smkDiff smks1) Right (rc''@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr}, rcNr}, smks2) -> do @@ -863,47 +904,44 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do msg <- decryptMessage (MessageKey mk iv) encMsg -- state . Nr += 1 pure (msg, rc'' {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr + 1}, smkDiff $ smks1 <> smks2) - Right (rc'', smks2) -> do + Right (rc'', smks2) -> pure (Left CERatchetState, rc'', smkDiff $ smks1 <> smks2) where + upgradedRatchet :: Ratchet a + upgradedRatchet + | msgMaxVersion > current rcVersion = rc {rcVersion = rcVersion {current = min msgMaxVersion $ maxSupported rcVersion}} + | otherwise = rc smkDiff :: SkippedMsgKeys -> SkippedMsgDiff smkDiff smks = if M.null smks then SMDNoChange else SMDAdd smks - ratchetStep :: RatchetStep -> ExceptT CryptoError IO (Ratchet a, SkippedMsgKeys) - ratchetStep SameRatchet = pure (rc, M.empty) - ratchetStep AdvanceRatchet = - -- SkipMessageKeysHE(state, header.pn) - case skipMessageKeys msgPN rc of - Left e -> throwE e - Right (rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr}, hmks) -> do - -- DHRatchetPQ2HE(state, header) - (kemSS, kemSS', rcKEM') <- pqRatchetStep rc' msgKEM - -- state.DHRs = GENERATE_DH() - (_, rcDHRs') <- atomically $ generateKeyPair @a g - -- state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || ss) - let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs rcDHRs kemSS - -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || state.PQRss) - (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' kemSS' - sndKEM = isJust kemSS' - rcvKEM = isJust kemSS - rc'' = - rc' - { rcDHRs = rcDHRs', - rcKEM = rcKEM', - rcEnableKEM = PQEncryption $ sndKEM || rcvKEM, - rcSndKEM = PQEncryption sndKEM, - rcRcvKEM = PQEncryption rcvKEM, - rcRK = rcRK'', - rcSnd = Just SndRatchet {rcDHRr = msgDHRs, rcCKs = rcCKs', rcHKs = rcNHKs}, - rcRcv = Just RcvRatchet {rcCKr = rcCKr', rcHKr = rcNHKr}, - rcPN = rcNs rc, - rcNs = 0, - rcNr = 0, - rcNHKs = rcNHKs', - rcNHKr = rcNHKr' - } - pure (rc'', hmks) + ratchetStep :: Ratchet a -> MsgHeader a -> ExceptT CryptoError IO (Ratchet a) + ratchetStep rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr} MsgHeader {msgDHRs, msgKEM} = do + (kemSS, kemSS', rcKEM') <- pqRatchetStep rc' msgKEM + -- state.DHRs = GENERATE_DH() + (_, rcDHRs') <- atomically $ generateKeyPair @a g + -- state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || ss) + let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs rcDHRs kemSS + -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || state.PQRss) + (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' kemSS' + sndKEM = isJust kemSS' + rcvKEM = isJust kemSS + pure + rc' + { rcDHRs = rcDHRs', + rcKEM = rcKEM', + rcEnableKEM = PQEncryption $ sndKEM || rcvKEM, + rcSndKEM = PQEncryption sndKEM, + rcRcvKEM = PQEncryption rcvKEM, + rcRK = rcRK'', + rcSnd = Just SndRatchet {rcDHRr = msgDHRs, rcCKs = rcCKs', rcHKs = rcNHKs}, + rcRcv = Just RcvRatchet {rcCKr = rcCKr', rcHKr = rcNHKr}, + rcPN = rcNs rc, + rcNs = 0, + rcNr = 0, + rcNHKs = rcNHKs', + rcNHKr = rcNHKr' + } pqRatchetStep :: Ratchet a -> Maybe ARKEMParams -> ExceptT CryptoError IO (Maybe KEMSharedKey, Maybe KEMSharedKey, Maybe RatchetKEM) - pqRatchetStep Ratchet {rcKEM, rcEnableKEM = PQEncryption pqEnc} = \case + pqRatchetStep Ratchet {rcKEM, rcEnableKEM = PQEncryption pqEnc, rcVersion = rv} = \case -- received message does not have KEM in header, -- but the user enabled KEM when sending previous message Nothing -> case rcKEM of @@ -913,7 +951,7 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do _ -> pure (Nothing, Nothing, Nothing) -- received message has KEM in header. Just (ARKP _ ps) - | pqEnc -> do + | pqEnc && current rv >= pqRatchetE2EEncryptVersion -> do -- state.PQRr = header.kem (ss, rcPQRr) <- sharedSecret -- state.PQRct = PQKEM-ENC(state.PQRr, state.PQRss) // encapsulated additional shared secret KEM #1 @@ -981,9 +1019,9 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do e -> throwE e -- header = HDECRYPT(state.NHKr, enc_header) decryptNextHeader hdr = (AdvanceRatchet,) <$> decryptHeader (rcNHKr rc) hdr - decryptHeader k EncMessageHeader {ehBody, ehAuthTag, ehIV} = do + decryptHeader k EncMessageHeader {ehVersion, ehBody, ehAuthTag, ehIV} = do header <- decryptAEAD k ehIV rcAD ehBody ehAuthTag `catchE` \_ -> throwE CERatchetHeader - parseE' CryptoHeaderError smpP header + parseE' CryptoHeaderError (msgHeaderP ehVersion) header decryptMessage :: MessageKey -> EncRatchetMessage -> ExceptT CryptoError IO (Either CryptoError ByteString) decryptMessage (MessageKey mk iv) EncRatchetMessage {emHeader, emBody, emAuthTag} = -- DECRYPT(mk, cipher-text, CONCAT(AD, enc_header)) diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 09e9e5002..6ff9db523 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -29,7 +29,7 @@ import SMPAgentClient import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) import Simplex.Messaging.Agent.Protocol hiding (MID) import qualified Simplex.Messaging.Agent.Protocol as A -import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOn, pattern PQEncOff) +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQEncOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (ErrorType (..)) @@ -184,10 +184,10 @@ pqMatrix2NoInv = pqMatrix2_ False pqMatrix2_ :: Bool -> PQMatrix2 c pqMatrix2_ pqInv _ smpTest test = do - it "dh/dh handshake" $ smpTest $ \a b -> test (a, ikPQOff) (b, PQEncOff) - it "dh/pq handshake" $ smpTest $ \a b -> test (a, ikPQOff) (b, PQEncOn) - it "pq/dh handshake" $ smpTest $ \a b -> test (a, ikPQOn) (b, PQEncOff) - it "pq/pq handshake" $ smpTest $ \a b -> test (a, ikPQOn) (b, PQEncOn) + it "dh/dh handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQEncOff) + it "dh/pq handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQEncOn) + it "pq/dh handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQEncOff) + it "pq/pq handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQEncOn) when pqInv $ do it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQEncOff) it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQEncOn) @@ -199,17 +199,17 @@ pqMatrix3 :: (HasCallStack => (c, InitialKeys) -> (c, PQEncryption) -> (c, PQEncryption) -> IO ()) -> Spec pqMatrix3 _ smpTest test = do - it "dh" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOff) (c, PQEncOff) - it "dh/dh/pq" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOff) (c, PQEncOn) - it "dh/pq/dh" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOn) (c, PQEncOff) - it "dh/pq/pq" $ smpTest $ \a b c -> test (a, ikPQOff) (b, PQEncOn) (c, PQEncOn) - it "pq/dh/dh" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOff) (c, PQEncOff) - it "pq/dh/pq" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOff) (c, PQEncOn) - it "pq/pq/dh" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOn) (c, PQEncOff) - it "pq" $ smpTest $ \a b c -> test (a, ikPQOn) (b, PQEncOn) (c, PQEncOn) + it "dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOff) (c, PQEncOff) + it "dh/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOff) (c, PQEncOn) + it "dh/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOn) (c, PQEncOff) + it "dh/pq/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOn) (c, PQEncOn) + it "pq/dh/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOff) (c, PQEncOff) + it "pq/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOff) (c, PQEncOn) + it "pq/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOn) (c, PQEncOff) + it "pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOn) (c, PQEncOn) testDuplexConnection :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () -testDuplexConnection _ alice bob = testDuplexConnection' (alice, ikPQOn) (bob, PQEncOn) +testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, PQEncOn) testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQEncryption) -> IO () testDuplexConnection' (alice, aPQ) (bob, bPQ) = do @@ -246,7 +246,7 @@ testDuplexConnection' (alice, aPQ) (bob, bPQ) = do alice #:# "nothing else should be delivered to alice" testDuplexConnRandomIds :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () -testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, ikPQOn) (bob, PQEncOn) +testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) (bob, PQEncOn) testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQEncryption) -> IO () testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do @@ -546,14 +546,8 @@ testResumeDeliveryQuotaExceeded _ alice bob = do -- message 8 is skipped because of alice agent sending "QCONT" message bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK) -ikPQOn :: InitialKeys -ikPQOn = IKNoPQ PQEncOn - -ikPQOff :: InitialKeys -ikPQOff = IKNoPQ PQEncOff - connect :: Transport c => (c, ByteString) -> (c, ByteString) -> IO () -connect (h1, name1) (h2, name2) = connect' (h1, name1, ikPQOn) (h2, name2, PQEncOn) +connect (h1, name1) (h2, name2) = connect' (h1, name1, IKPQOn) (h2, name2, PQEncOn) connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQEncryption) -> IO () connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 758c2d66d..8228ae4cd 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -81,7 +81,7 @@ testE2ERatchetParams :: RcvE2ERatchetParamsUri 'C.X448 testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange (VersionE2E 1) (VersionE2E 1)) testDhPubKey testDhPubKey Nothing testE2ERatchetParams12 :: RcvE2ERatchetParamsUri 'C.X448 -testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey Nothing +testE2ERatchetParams12 = E2ERatchetParamsUri (supportedE2EEncryptVRange PQEncOn) testDhPubKey testDhPubKey Nothing connectionRequest :: AConnectionRequestUri connectionRequest = @@ -95,7 +95,7 @@ connectionRequestCurrentRange :: AConnectionRequestUri connectionRequestCurrentRange = ACR SCMInvitation $ CRInvitationUri - connReqData {crAgentVRange = supportedSMPAgentVRange, crSmpQueues = [queueV1, queueV1]} + connReqData {crAgentVRange = supportedSMPAgentVRange PQEncOn, crSmpQueues = [queueV1, queueV1]} testE2ERatchetParams12 connectionRequestClientDataEmpty :: AConnectionRequestUri @@ -135,7 +135,7 @@ connectionRequestTests = <> urlEncode True testDhKeyStrUri <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" strEncode connectionRequestCurrentRange - `shouldBe` "simplex:/invitation#/?v=2-4&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" + `shouldBe` "simplex:/invitation#/?v=2-5&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> urlEncode True testDhKeyStrUri <> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> urlEncode True testDhKeyStrUri @@ -185,7 +185,7 @@ connectionRequestTests = <> testDhKeyStrUri <> "&e2e=extra_key%3Dnew%26v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" <> "&some_new_param=abc" - <> "&v=2-4" + <> "&v=2-5" ) `shouldBe` Right connectionRequestCurrentRange strDecode diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 87c9801ee..0e4b9079e 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -39,12 +39,13 @@ doubleRatchetTests = do describe "double-ratchet encryption/decryption" $ do it "should serialize and parse message header" $ do testAlgs $ testMessageHeader kdfX3DHE2EEncryptVersion - testAlgs $ testMessageHeader $ max pqRatchetVersion currentE2EEncryptVersion + testAlgs $ testMessageHeader $ max pqRatchetE2EEncryptVersion currentE2EEncryptVersion describe "message tests" $ runMessageTests initRatchets False it "should encode/decode ratchet as JSON" $ do testAlgs testKeyJSON testAlgs testRatchetJSON testVersionJSON + it "should decode v2 Ratchet with default field values" $ testDecodeV2RatchetJSON it "should agree the same ratchet parameters" $ testAlgs testX3dh it "should agree the same ratchet parameters with version 1" $ testAlgs testX3dhV1 describe "post-quantum hybrid KEM double-ratchet algorithm" $ do @@ -89,15 +90,15 @@ paddedMsgLen :: Int paddedMsgLen = 100 fullMsgLen :: VersionE2E -> Int -fullMsgLen v = headerLenLength + fullHeaderLen + C.authTagSize + paddedMsgLen +fullMsgLen v = headerLenLength + fullHeaderLen v + C.authTagSize + paddedMsgLen where - headerLenLength = if v < pqRatchetVersion then 1 else 3 -- two bytes are added because of two Large used in new encoding + headerLenLength = if v < pqRatchetE2EEncryptVersion then 1 else 3 -- two bytes are added because of two Large used in new encoding testMessageHeader :: forall a. AlgorithmI a => VersionE2E -> C.SAlgorithm a -> Expectation testMessageHeader v _ = do (k, _) <- atomically . C.generateKeyPair @a =<< C.newRandom let hdr = MsgHeader {msgMaxVersion = v, msgDHRs = k, msgKEM = Nothing, msgPN = 0, msgNs = 0} - parseAll (smpP @(MsgHeader a)) (smpEncode hdr) `shouldBe` Right hdr + parseAll (msgHeaderP v) (encodeMsgHeader v hdr) `shouldBe` Right hdr testKEMParams :: Expectation testKEMParams = do @@ -115,15 +116,15 @@ testMessageHeaderKEM _ = do g <- C.newRandom (k, _) <- atomically $ C.generateKeyPair @a g (kem, _) <- sntrup761Keypair g - let msgMaxVersion = max pqRatchetVersion currentE2EEncryptVersion + let msgMaxVersion = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion msgKEM = Just . ARKP SRKSProposed $ RKParamsProposed kem hdr = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM, msgPN = 0, msgNs = 0} - parseAll (smpP @(MsgHeader a)) (smpEncode hdr) `shouldBe` Right hdr + parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr) `shouldBe` Right hdr (kem', _) <- sntrup761Keypair g (ct, _) <- sntrup761Enc g kem let msgKEM' = Just . ARKP SRKSAccepted $ RKParamsAccepted ct kem' hdr' = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM = msgKEM', msgPN = 0, msgNs = 0} - parseAll (smpP @(MsgHeader a)) (smpEncode hdr') `shouldBe` Right hdr' + parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr') `shouldBe` Right hdr' pattern Decrypted :: ByteString -> Either CryptoError (Either CryptoError ByteString) pattern Decrypted msg <- Right (Right msg) @@ -350,6 +351,14 @@ testVersionJSON = do testDecodeRV :: ToJSON a => a -> Expectation testDecodeRV a = J.eitherDecode' (J.encode a) `shouldBe` Right (rv 1 2) +testDecodeV2RatchetJSON :: IO () +testDecodeV2RatchetJSON = do + let v2RatchetJSON = "{\"rcVersion\":[2,2],\"rcAD\":\"2GEJrq48TmQse6NR16I-hrI0tSySZQ57E_g46nDceAPRAiF6j0drq26RTE7be6X7uiB4RaGJGf4QRXzcYuVtWw==\",\"rcDHRs\":\"TUM0Q0FRQXdCUVlESzJWdUJDSUVJRkNYbUxtSHQ3SUNfeHpGTi1Qb3ZqTVQ3S2p6XzZlZlBjOG9fRFY2RWxKOQ==\",\"rcRK\":\"BOX2X7YW5qDSp2XknY_lqacSrtDqQNPvS6iJlZIs3G0=\",\"rcNs\":0,\"rcNr\":0,\"rcPN\":0,\"rcNHKs\":\"IMouSkXUvzT_mo0WM-pqEUK09-HTLk9WOTCFQglyQxU=\",\"rcNHKr\":\"g-tus1clYPV0rGlzkf5a959tUqDYQVZ1FpcPeXdKwxI=\"}" + Right (r :: Ratchet X25519) <- pure $ J.eitherDecodeStrict' v2RatchetJSON + rcEnableKEM r `shouldBe` PQEncOff + rcSndKEM r `shouldBe` PQEncOff + rcRcvKEM r `shouldBe` PQEncOff + testEncodeDecode :: (Eq a, Show a, ToJSON a, FromJSON a) => a -> Expectation testEncodeDecode x = do let j = J.encode x @@ -359,7 +368,7 @@ testEncodeDecode x = do testX3dh :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dh _ = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice @@ -378,7 +387,7 @@ testX3dhV1 _ = do testPqX3dhProposeInReply :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testPqX3dhProposeInReply _ = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (no KEM) (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff -- propose KEM in reply @@ -390,7 +399,7 @@ testPqX3dhProposeInReply _ = do testPqX3dhProposeAccept :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testPqX3dhProposeAccept _ = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice @@ -403,7 +412,7 @@ testPqX3dhProposeAccept _ = do testPqX3dhProposeReject :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testPqX3dhProposeReject _ = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice @@ -416,7 +425,7 @@ testPqX3dhProposeReject _ = do testPqX3dhAcceptWithoutProposalError :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testPqX3dhAcceptWithoutProposalError _ = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (no KEM) (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff E2ERatchetParams _ _ _ Nothing <- pure e2eAlice @@ -430,7 +439,7 @@ testPqX3dhAcceptWithoutProposalError _ = do testPqX3dhProposeAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testPqX3dhProposeAgain _ = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice @@ -452,9 +461,9 @@ compatibleRatchets _ -> expectationFailure "RatchetInitParams params are not compatible" encryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Maybe PQEncryption -> (Ratchet a -> ()) -> (Ratchet a -> ()) -> EncryptDecryptSpec a -encryptDecrypt pqEnc invalidSnd invalidRcv (alice, msg) bob = do - Right msg' <- withTVar (encrypt_ pqEnc) invalidSnd alice msg - Decrypted msg'' <- decrypt' invalidRcv bob msg' +encryptDecrypt pqEnc validSnd validRcv (alice, msg) bob = do + Right msg' <- withTVar (encrypt_ pqEnc) validSnd alice msg + Decrypted msg'' <- decrypt' validRcv bob msg' msg'' `shouldBe` msg -- enable KEM (currently disabled) @@ -493,20 +502,21 @@ withRatchets_ initRatchets_ test = do initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) initRatchets = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion (pkBob1, pkBob2, _pKemParams@Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v Nothing (pkAlice1, pkAlice2, _pKem@Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOff Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOff + let vs = testRatchetVersions PQEncOff + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOff pure (alice, bob, encrypt' noSndKEM, decrypt' noRcvKEM, (\#>)) initRatchetsKEMProposed :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) initRatchetsKEMProposed = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (no KEM) (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOff -- propose KEM in reply @@ -515,14 +525,15 @@ initRatchetsKEMProposed = do Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOn + let vs = testRatchetVersions PQEncOn + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOn pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) initRatchetsKEMAccepted :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) initRatchetsKEMAccepted = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose) (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOn E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice @@ -532,14 +543,15 @@ initRatchetsKEMAccepted = do Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOn + let vs = testRatchetVersions PQEncOn + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOn pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) initRatchetsKEMProposedAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) initRatchetsKEMProposedAgain = do g <- C.newRandom - let v = max pqRatchetVersion currentE2EEncryptVersion + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOn -- propose KEM again in reply @@ -548,10 +560,16 @@ initRatchetsKEMProposedAgain = do Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice PQEncOn + let vs = testRatchetVersions PQEncOn + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOn pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) +testRatchetVersions :: PQEncryption -> RatchetVersions +testRatchetVersions pq = + let v = maxVersion $ supportedE2EEncryptVRange pq + in RVersions v v + encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) encrypt_ enableKem (_, rc, _) msg = -- print msg >> @@ -559,7 +577,7 @@ encrypt_ enableKem (_, rc, _) msg = >>= either (pure . Left) checkLength where checkLength (msg', rc') = do - B.length msg' `shouldBe` fullMsgLen (maxSupported $ rcVersion rc) + B.length msg' `shouldBe` fullMsgLen (current $ rcVersion rc) pure $ Right (msg', rc', SMDNoChange) decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff)) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 11adc2d0a..bc328bbad 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -63,11 +63,11 @@ import qualified Database.SQLite.Simple as SQL import SMPAgentClient import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, withSmpServerV7) import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) -import qualified Simplex.Messaging.Agent as Agent +import qualified Simplex.Messaging.Agent as A import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore) import Simplex.Messaging.Agent.Protocol hiding (CON) -import qualified Simplex.Messaging.Agent.Protocol as Agent +import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteStore (dbNew)) import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultSMPClientConfig) @@ -145,7 +145,7 @@ pGet c = do _ -> pure t pattern CON :: ACommand 'Agent 'AEConn -pattern CON = Agent.CON PQEncOn +pattern CON = A.CON PQEncOn pattern Msg :: MsgBody -> ACommand 'Agent e pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk, pqEncryption = PQEncOn} _ msgBody @@ -168,13 +168,14 @@ smpCfgV7 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange batchCmdsSMPVersio ntfCfgV2 :: ProtocolClientConfig NTFVersion ntfCfgV2 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange (VersionNTF 1) authBatchCmdsNTFVersion} +-- TODO PQ test next version with PQ agentCfgVPrev :: AgentConfig agentCfgVPrev = agentCfg { sndAuthAlg = C.AuthAlg C.SEd25519, - smpAgentVRange = prevRange $ smpAgentVRange agentCfg, + smpAgentVRange = \_ -> prevRange $ smpAgentVRange agentCfg PQEncOff, smpClientVRange = prevRange $ smpClientVRange agentCfg, - e2eEncryptVRange = prevRange $ e2eEncryptVRange agentCfg, + e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQEncOff, smpCfg = smpCfgVPrev } @@ -187,7 +188,7 @@ agentCfgV7 = } agentCfgRatchetVPrev :: AgentConfig -agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = prevRange $ e2eEncryptVRange agentCfg} +agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQEncOff} prevRange :: VersionRange v -> VersionRange v prevRange vr = vr {maxVersion = max (minVersion vr) (prevVersion $ maxVersion vr)} @@ -223,14 +224,14 @@ inAnyOrder g rs = do expected r rp = rp r createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) -createConnection c userId enableNtfs cMode clientData = Agent.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQEncOn) +createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQEncOn) joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnection c userId enableNtfs cReq connInfo = Agent.joinConnection c userId enableNtfs cReq connInfo PQEncOn +joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQEncOn sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> m AgentMsgId sendMessage c connId msgFlags msgBody = do - (msgId, pqEnc) <- Agent.sendMessage c connId PQEncOn msgFlags msgBody + (msgId, pqEnc) <- A.sendMessage c connId PQEncOn msgFlags msgBody liftIO $ pqEnc `shouldBe` PQEncOn pure msgId @@ -267,6 +268,7 @@ functionalAPITests t = do testIncreaseConnAgentVersionMaxCompatible t it "should increase when connection was negotiated on different versions" $ testIncreaseConnAgentVersionStartDifferentVersion t + -- TODO PQ tests for upgrading connection to PQ encryption it "should deliver message after client restart" $ testDeliverClientRestart t it "should deliver messages to the user once, even if repeat delivery is made by the server (no ACK)" $ @@ -424,29 +426,25 @@ canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = let v = basicAuthSMPVersion in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) -testMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +-- TODO PQ test next version with PQ +testMatrix2 :: ATransport -> (PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2 t runTest = do - it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 runTest - it "v7 to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 runTest - it "current to v7" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 runTest - it "current with v7 server" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfg 3 runTest - it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 runTest - skip "TODO PQ versioning" $ describe "TODO fails with previous version" $ do - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 runTest - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 runTest - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 runTest + it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQEncOn + it "v7 to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 $ runTest PQEncOn + it "current to v7" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 $ runTest PQEncOn + it "current with v7 server" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQEncOn + it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQEncOn + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQEncOff + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQEncOff + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQEncOff -testRatchetMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +-- TODO PQ test next version with PQ +testRatchetMatrix2 :: ATransport -> (PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testRatchetMatrix2 t runTest = do - it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 runTest - skip "TODO PQ versioning" $ describe "TODO fails with previous version" $ do - pendingV "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 runTest - pendingV "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 runTest - pendingV "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 runTest - where - pendingV d = - let vr = e2eEncryptVRange agentCfg - in if minVersion vr == maxVersion vr then skip "previous version is not supported" . it d else it d + it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQEncOn + it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 $ runTest PQEncOff + it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 $ runTest PQEncOff + it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 $ runTest PQEncOff testServerMatrix2 :: ATransport -> (InitialAgentServers -> IO ()) -> Spec testServerMatrix2 t runTest = do @@ -468,40 +466,40 @@ withAgentClientsCfg2 aCfg bCfg runTest = do withAgentClients2 :: (AgentClient -> AgentClient -> IO ()) -> IO () withAgentClients2 = withAgentClientsCfg2 agentCfg agentCfg -runAgentClientTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientTest alice@AgentClient {} bob baseId = +runAgentClientTest :: HasCallStack => PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientTest pqEnc alice@AgentClient {} bob baseId = runRight_ $ do - (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (IKNoPQ pqEnc) SMSubscribe + aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqEnc SMSubscribe ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" - get alice ##> ("", bobId, CON) + get alice ##> ("", bobId, A.CON pqEnc) get bob ##> ("", aliceId, INFO "alice's connInfo") - get bob ##> ("", aliceId, CON) + get bob ##> ("", aliceId, A.CON pqEnc) -- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4 - 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + 1 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" get alice ##> ("", bobId, SENT $ baseId + 1) - 2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 2 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "how are you?" get alice ##> ("", bobId, SENT $ baseId + 2) - get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "hello") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 1) Nothing - get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "how are you?") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 2) Nothing - 3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + 3 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" get bob ##> ("", aliceId, SENT $ baseId + 3) - 4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1" + 4 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 1" get bob ##> ("", aliceId, SENT $ baseId + 4) - get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "hello too") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 3) Nothing - get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "message 1") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 4) Nothing suspendConnection alice bobId - 5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + 5 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 2" get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" where - msgId = subtract baseId + msgId = subtract baseId . fst testAgentClient3 :: HasCallStack => IO () testAgentClient3 = do @@ -529,42 +527,42 @@ testAgentClient3 = do get c =##> \case ("", connId, Msg "c5") -> connId == aIdForC; _ -> False ackMessage c aIdForC 5 Nothing -runAgentClientContactTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientContactTest alice bob baseId = +runAgentClientContactTest :: HasCallStack => PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientContactTest pqEnc alice bob baseId = runRight_ $ do - (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing (IKNoPQ pqEnc) SMSubscribe + aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqEnc SMSubscribe ("", _, REQ invId _ "bob's connInfo") <- get alice bobId <- acceptContact alice True invId "alice's connInfo" PQEncOn SMSubscribe ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" get alice ##> ("", bobId, INFO "bob's connInfo") - get alice ##> ("", bobId, CON) - get bob ##> ("", aliceId, CON) + get alice ##> ("", bobId, A.CON pqEnc) + get bob ##> ("", aliceId, A.CON pqEnc) -- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4 - 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + 1 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" get alice ##> ("", bobId, SENT $ baseId + 1) - 2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 2 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "how are you?" get alice ##> ("", bobId, SENT $ baseId + 2) - get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "hello") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 1) Nothing - get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "how are you?") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 2) Nothing - 3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + 3 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" get bob ##> ("", aliceId, SENT $ baseId + 3) - 4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1" + 4 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 1" get bob ##> ("", aliceId, SENT $ baseId + 4) - get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "hello too") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 3) Nothing - get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "message 1") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 4) Nothing suspendConnection alice bobId - 5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + 5 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 2" get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" where - msgId = subtract baseId + msgId = subtract baseId . fst noMessages :: HasCallStack => AgentClient -> String -> Expectation noMessages c err = tryGet `shouldReturn` () @@ -688,12 +686,12 @@ testAllowConnectionClientRestart t = do testIncreaseConnAgentVersion :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersion t = do - alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB - bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 + alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection alice bob - exchangeGreetingsMsgId 4 alice bobId bob aliceId + (aliceId, bobId) <- makeConnection_ PQEncOff alice bob + exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 pure (aliceId, bobId) @@ -701,42 +699,42 @@ testIncreaseConnAgentVersion t = do -- version doesn't increase if incompatible disconnectAgentClient alice - alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB + alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB runRight_ $ do subscribeConnection alice2 bobId - exchangeGreetingsMsgId 6 alice2 bobId bob aliceId + exchangeGreetingsMsgId_ PQEncOff 6 alice2 bobId bob aliceId checkVersion alice2 bobId 2 checkVersion bob aliceId 2 -- version increases if compatible disconnectAgentClient bob - bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 + bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 runRight_ $ do subscribeConnection bob2 aliceId - exchangeGreetingsMsgId 8 alice2 bobId bob2 aliceId + exchangeGreetingsMsgId_ PQEncOff 8 alice2 bobId bob2 aliceId checkVersion alice2 bobId 3 checkVersion bob2 aliceId 3 -- version doesn't decrease, even if incompatible disconnectAgentClient alice2 - alice3 <- getSMPAgentClient' 5 agentCfg {smpAgentVRange = mkVersionRange 2 2} initAgentServers testDB + alice3 <- getSMPAgentClient' 5 agentCfg {smpAgentVRange = \_ -> mkVersionRange 2 2} initAgentServers testDB runRight_ $ do subscribeConnection alice3 bobId - exchangeGreetingsMsgId 10 alice3 bobId bob2 aliceId + exchangeGreetingsMsgId_ PQEncOff 10 alice3 bobId bob2 aliceId checkVersion alice3 bobId 3 checkVersion bob2 aliceId 3 disconnectAgentClient bob2 - bob3 <- getSMPAgentClient' 6 agentCfg {smpAgentVRange = mkVersionRange 1 1} initAgentServers testDB2 + bob3 <- getSMPAgentClient' 6 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 1} initAgentServers testDB2 runRight_ $ do subscribeConnection bob3 aliceId - exchangeGreetingsMsgId 12 alice3 bobId bob3 aliceId + exchangeGreetingsMsgId_ PQEncOff 12 alice3 bobId bob3 aliceId checkVersion alice3 bobId 3 checkVersion bob3 aliceId 3 disconnectAgentClient alice3 @@ -749,12 +747,12 @@ checkVersion c connId v = do testIncreaseConnAgentVersionMaxCompatible :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersionMaxCompatible t = do - alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB - bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 + alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection alice bob - exchangeGreetingsMsgId 4 alice bobId bob aliceId + (aliceId, bobId) <- makeConnection_ PQEncOff alice bob + exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 pure (aliceId, bobId) @@ -762,14 +760,14 @@ testIncreaseConnAgentVersionMaxCompatible t = do -- version increases to max compatible disconnectAgentClient alice - alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB + alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB disconnectAgentClient bob - bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = mkVersionRange 1 4} initAgentServers testDB2 + bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = supportedSMPAgentVRange} initAgentServers testDB2 runRight_ $ do subscribeConnection alice2 bobId subscribeConnection bob2 aliceId - exchangeGreetingsMsgId 6 alice2 bobId bob2 aliceId + exchangeGreetingsMsgId_ PQEncOff 6 alice2 bobId bob2 aliceId checkVersion alice2 bobId 3 checkVersion bob2 aliceId 3 disconnectAgentClient alice2 @@ -777,12 +775,12 @@ testIncreaseConnAgentVersionMaxCompatible t = do testIncreaseConnAgentVersionStartDifferentVersion :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersionStartDifferentVersion t = do - alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB - bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 + alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection alice bob - exchangeGreetingsMsgId 4 alice bobId bob aliceId + (aliceId, bobId) <- makeConnection_ PQEncOff alice bob + exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 pure (aliceId, bobId) @@ -790,11 +788,11 @@ testIncreaseConnAgentVersionStartDifferentVersion t = do -- version increases to max compatible disconnectAgentClient alice - alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB + alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB runRight_ $ do subscribeConnection alice2 bobId - exchangeGreetingsMsgId 6 alice2 bobId bob aliceId + exchangeGreetingsMsgId_ PQEncOff 6 alice2 bobId bob aliceId checkVersion alice2 bobId 3 checkVersion bob aliceId 3 disconnectAgentClient alice2 @@ -1075,13 +1073,13 @@ setupDesynchronizedRatchet alice bob = do runRight_ $ do subscribeConnection bob2 aliceId - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQEncOn False + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQEncOn False 8 <- sendMessage alice bobId SMP.noMsgFlags "hello 5" get alice ##> ("", bobId, SENT 8) get bob2 =##> ratchetSyncP aliceId RSRequired - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6" + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6" pure () pure (aliceId, bobId, bob2) @@ -1265,13 +1263,13 @@ makeConnectionForUsers = makeConnectionForUsers_ PQEncOn makeConnectionForUsers_ :: PQEncryption -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnectionForUsers_ pqEnc alice aliceUserId bob bobUserId = do - (bobId, qInfo) <- Agent.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqEnc) SMSubscribe - aliceId <- Agent.joinConnection bob bobUserId True qInfo "bob's connInfo" pqEnc SMSubscribe + (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqEnc) SMSubscribe + aliceId <- A.joinConnection bob bobUserId True qInfo "bob's connInfo" pqEnc SMSubscribe ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" - get alice ##> ("", bobId, Agent.CON pqEnc) + get alice ##> ("", bobId, A.CON pqEnc) get bob ##> ("", aliceId, INFO "alice's connInfo") - get bob ##> ("", aliceId, Agent.CON pqEnc) + get bob ##> ("", aliceId, A.CON pqEnc) pure (aliceId, bobId) testInactiveNoSubs :: ATransport -> IO () @@ -2032,7 +2030,7 @@ testAbortSwitchStarted servers = do liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted] phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing] -- repeat switch is prohibited - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ switchConnectionAsync a "" bId + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ switchConnectionAsync a "" bId -- abort current switch stats' <- abortConnectionSwitch a bId liftIO $ rcvSwchStatuses' stats' `shouldMatchList` [Nothing] @@ -2154,7 +2152,7 @@ testCannotAbortSwitchSecured servers = do withA' $ \a -> do phaseRcv a bId SPConfirmed [Just RSSendingQADD, Nothing] phaseRcv a bId SPSecured [Just RSSendingQUSE, Nothing] - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ abortConnectionSwitch a bId + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ abortConnectionSwitch a bId pure () withA $ \a -> withB $ \b -> runRight_ $ do subscribeConnection a bId @@ -2346,53 +2344,61 @@ testDeliveryReceipts = get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False ackMessage a bId 6 $ Just "" get b =##> \case ("", c, Rcvd 6) -> c == aId; _ -> False - ackMessage b aId 7 (Just "") `catchError` \e -> liftIO $ e `shouldBe` Agent.CMD PROHIBITED + ackMessage b aId 7 (Just "") `catchError` \e -> liftIO $ e `shouldBe` A.CMD PROHIBITED ackMessage b aId 7 Nothing testDeliveryReceiptsVersion :: HasCallStack => ATransport -> IO () testDeliveryReceiptsVersion t = do - a <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB - b <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 + a <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB + b <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aId, bId) <- runRight $ do - (aId, bId) <- makeConnection a b + (aId, bId) <- makeConnection_ PQEncOff a b checkVersion a bId 3 checkVersion b aId 3 - 4 <- sendMessage a bId SMP.noMsgFlags "hello" + (4, _) <- A.sendMessage a bId PQEncOff SMP.noMsgFlags "hello" get a ##> ("", bId, SENT 4) - get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False + get b =##> \case ("", c, Msg' 4 PQEncOff "hello") -> c == aId; _ -> False ackMessage b aId 4 $ Just "" liftIO $ noMessages a "no delivery receipt (unsupported version)" - 5 <- sendMessage b aId SMP.noMsgFlags "hello too" + (5, _) <- A.sendMessage b aId PQEncOff SMP.noMsgFlags "hello too" get b ##> ("", aId, SENT 5) - get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False + get a =##> \case ("", c, Msg' 5 PQEncOff "hello too") -> c == bId; _ -> False ackMessage a bId 5 $ Just "" liftIO $ noMessages b "no delivery receipt (unsupported version)" pure (aId, bId) disconnectAgentClient a disconnectAgentClient b - a' <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 4} initAgentServers testDB - b' <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = mkVersionRange 1 4} initAgentServers testDB2 + a' <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = supportedSMPAgentVRange} initAgentServers testDB + b' <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = supportedSMPAgentVRange} initAgentServers testDB2 runRight_ $ do subscribeConnection a' bId subscribeConnection b' aId - exchangeGreetingsMsgId 6 a' bId b' aId + exchangeGreetingsMsgId_ PQEncOff 6 a' bId b' aId checkVersion a' bId 4 checkVersion b' aId 4 - 8 <- sendMessage a' bId SMP.noMsgFlags "hello" + (8, PQEncOff) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello" get a' ##> ("", bId, SENT 8) - get b' =##> \case ("", c, Msg "hello") -> c == aId; _ -> False + get b' =##> \case ("", c, Msg' 8 PQEncOff "hello") -> c == aId; _ -> False ackMessage b' aId 8 $ Just "" get a' =##> \case ("", c, Rcvd 8) -> c == bId; _ -> False ackMessage a' bId 9 Nothing - 10 <- sendMessage b' aId SMP.noMsgFlags "hello too" + (10, PQEncOff) <- A.sendMessage b' aId PQEncOn SMP.noMsgFlags "hello too" get b' ##> ("", aId, SENT 10) - get a' =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False + get a' =##> \case ("", c, Msg' 10 PQEncOff "hello too") -> c == bId; _ -> False ackMessage a' bId 10 $ Just "" get b' =##> \case ("", c, Rcvd 10) -> c == aId; _ -> False ackMessage b' aId 11 Nothing + -- TODO PQ this part hangs when waiting for Rcvd, because connection tries to upgrade to PQ encryption. + -- replacing 2 PQEncOn with PQEncOff above prevents hanging. + -- (12, _) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello 2" + -- get a' ##> ("", bId, SENT 12) + -- get b' =##> \case ("", c, Msg' 12 PQEncOff "hello 2") -> c == aId; _ -> False + -- ackMessage b' aId 12 $ Just "" + -- get a' =##> \case ("", c, Rcvd 12) -> c == bId; _ -> False + -- ackMessage a' bId 13 Nothing disconnectAgentClient a' disconnectAgentClient b' @@ -2562,12 +2568,12 @@ exchangeGreetingsMsgId = exchangeGreetingsMsgId_ PQEncOn exchangeGreetingsMsgId_ :: HasCallStack => PQEncryption -> Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetingsMsgId_ pqEnc msgId alice bobId bob aliceId = do - msgId1 <- Agent.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" + msgId1 <- A.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" liftIO $ msgId1 `shouldBe` (msgId, pqEnc) get alice ##> ("", bobId, SENT msgId) get bob =##> \case ("", c, Msg' mId pq "hello") -> c == aliceId && mId == msgId && pq == pqEnc; _ -> False ackMessage bob aliceId msgId Nothing - msgId2 <- Agent.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" + msgId2 <- A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" let msgId' = msgId + 1 liftIO $ msgId2 `shouldBe` (msgId', pqEnc) get bob ##> ("", aliceId, SENT msgId') From e04705d9c5e6b3d3652f909a5176c375acf29411 Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Wed, 6 Mar 2024 11:10:49 +0200 Subject: [PATCH 09/30] utils: add generic batching and compression (#1018) * extract batchTransmissions_ * add Simplex.Messaging.Compression * add combined compression/batching * force NonEmpty for batching * hide FFI and allocation related IO * split packing * remove batch compression, tweak API * OCD over API * remove Empty, extract passthrough const --------- Co-authored-by: Evgeny Poberezkin --- package.yaml | 1 + simplexmq.cabal | 8 +++ src/Simplex/Messaging/Compression.hs | 74 ++++++++++++++++++++++++++++ src/Simplex/Messaging/Protocol.hs | 23 ++++++--- 4 files changed, 98 insertions(+), 8 deletions(-) create mode 100644 src/Simplex/Messaging/Compression.hs diff --git a/package.yaml b/package.yaml index 4dbc971a1..76de72ac6 100644 --- a/package.yaml +++ b/package.yaml @@ -73,6 +73,7 @@ dependencies: - unliftio-core == 0.2.* - websockets == 0.12.* - yaml == 0.11.* + - zstd == 0.1.3.* flags: swift: diff --git a/simplexmq.cabal b/simplexmq.cabal index 9242f6168..35f916cc8 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -108,6 +108,7 @@ library Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent + Simplex.Messaging.Compression Simplex.Messaging.Crypto Simplex.Messaging.Crypto.File Simplex.Messaging.Crypto.Lazy @@ -228,6 +229,7 @@ library , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -300,6 +302,7 @@ executable ntf-server , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -372,6 +375,7 @@ executable smp-agent , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -444,6 +448,7 @@ executable smp-server , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -516,6 +521,7 @@ executable xftp , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -588,6 +594,7 @@ executable xftp-server , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -701,6 +708,7 @@ test-suite simplexmq-test , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON diff --git a/src/Simplex/Messaging/Compression.hs b/src/Simplex/Messaging/Compression.hs new file mode 100644 index 000000000..c6664a179 --- /dev/null +++ b/src/Simplex/Messaging/Compression.hs @@ -0,0 +1,74 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Simplex.Messaging.Compression where + +import qualified Codec.Compression.Zstd.FFI as Z +import Control.Monad (forM) +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import qualified Data.ByteString.Unsafe as B +import Data.List.NonEmpty (NonEmpty) +import Foreign +import Foreign.C.Types +import GHC.IO (unsafePerformIO) +import Simplex.Messaging.Encoding +import UnliftIO.Exception (bracket) + +data Compressed + = -- | Short messages are left intact to skip copying and FFI festivities. + Passthrough ByteString + | -- | Generic compression using no extra context. + Compressed Large + +-- | Messages below this length are not encoded to avoid compression overhead. +maxLengthPassthrough :: Int +maxLengthPassthrough = 181 -- Sampled from real client data. Messages with length >=181 rapidly gain compression ratio. + +instance Encoding Compressed where + smpEncode = \case + Passthrough bytes -> "0" <> smpEncode bytes + Compressed bytes -> "1" <> smpEncode bytes + smpP = + smpP >>= \case + '0' -> Passthrough <$> smpP + '1' -> Compressed <$> smpP + x -> fail $ "unknown Compressed tag: " <> show x + +type CompressCtx = (Ptr Z.CCtx, Ptr CChar, Int) + +withCompressCtx :: Int -> (CompressCtx -> IO a) -> IO a +withCompressCtx scratchSize action = + bracket Z.createCCtx Z.freeCCtx $ \cctx -> + allocaBytes scratchSize $ \scratchPtr -> + action (cctx, scratchPtr, scratchSize) + +compress :: CompressCtx -> ByteString -> IO (Either String Compressed) +compress (cctx, scratchPtr, scratchSize) bs + | B.length bs < maxLengthPassthrough = pure . Right $ Passthrough bs + | otherwise = + B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> do + res <- Z.checkError $ Z.compressCCtx cctx scratchPtr (fromIntegral scratchSize) sourcePtr (fromIntegral sourceSize) 3 + case res of + Left e -> pure $ Left e -- should not happen, unless input buffer is too short + Right dstSize -> Right . Compressed . Large <$> B.packCStringLen (scratchPtr, fromIntegral dstSize) + +type DecompressCtx = (Ptr Z.DCtx, Ptr CChar, CSize) + +withDecompressCtx :: Int -> (DecompressCtx -> IO a) -> IO a +withDecompressCtx maxUnpackedSize action = + bracket Z.createDCtx Z.freeDCtx $ \dctx -> + allocaBytes maxUnpackedSize $ \scratchPtr -> + action (dctx, scratchPtr, fromIntegral maxUnpackedSize) + +decompress :: DecompressCtx -> Compressed -> IO (Either String ByteString) +decompress (dctx, scratchPtr, scratchSize) = \case + Passthrough bs -> pure $ Right bs + Compressed (Large bs) -> + B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> do + res <- Z.checkError $ Z.decompressDCtx dctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize) + forM res $ \dstSize -> B.packCStringLen (scratchPtr, fromIntegral dstSize) + +decompressBatch :: Int -> NonEmpty Compressed -> NonEmpty (Either String ByteString) +decompressBatch maxUnpackedSize items = unsafePerformIO $ withDecompressCtx maxUnpackedSize $ forM items . decompress +{-# NOINLINE decompressBatch #-} -- prevent double-evaluation under unsafePerformIO diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 8f36d8e7a..3aef08622 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,4 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -157,6 +158,7 @@ module Simplex.Messaging.Protocol tEncodeBatch1, batchTransmissions, batchTransmissions', + batchTransmissions_, -- * exports for tests CommandTag (..), @@ -171,6 +173,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Bifunctor (first) import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -1325,11 +1328,11 @@ data TransportBatch r = TBTransmissions ByteString Int [r] | TBTransmission Byte batchTransmissions :: Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission) -> [TransportBatch ()] batchTransmissions batch bSize = batchTransmissions' batch bSize . L.map (,()) --- | encodes and batches transmissions into blocks, +-- | encodes and batches transmissions into blocks batchTransmissions' :: forall r. Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission, r) -> [TransportBatch r] -batchTransmissions' batch bSize - | batch = addBatch . foldr addTransmission ([], 0, 0, [], []) - | otherwise = map mkBatch1 . L.toList +batchTransmissions' batch bSize ts + | batch = batchTransmissions_ bSize $ L.map (first $ fmap tEncodeForBatch) ts + | otherwise = map mkBatch1 $ L.toList ts where mkBatch1 :: (Either TransportError SentRawTransmission, r) -> TransportBatch r mkBatch1 (t_, r) = case t_ of @@ -1340,17 +1343,21 @@ batchTransmissions' batch bSize | otherwise -> TBError TELargeMsg r where s = tEncode t + +-- | Pack encoded transmissions into batches +batchTransmissions_ :: Int -> NonEmpty (Either TransportError ByteString, r) -> [TransportBatch r] +batchTransmissions_ bSize = addBatch . foldr addTransmission ([], 0, 0, [], []) + where -- 3 = 2 bytes reserved for pad size + 1 for transmission count bSize' = bSize - 3 - addTransmission :: (Either TransportError SentRawTransmission, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r]) - addTransmission (t_, r) acc@(bs, len, n, ss, rs) = case t_ of + addTransmission :: (Either TransportError ByteString, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r]) + addTransmission (t_, r) acc@(bs, !len, !n, ss, rs) = case t_ of Left e -> (TBError e r : addBatch acc, 0, 0, [], []) - Right t + Right s | len' <= bSize' && n < 255 -> (bs, len', 1 + n, s : ss, r : rs) | sLen <= bSize' -> (addBatch acc, sLen, 1, [s], [r]) | otherwise -> (TBError TELargeMsg r : addBatch acc, 0, 0, [], []) where - s = tEncodeForBatch t sLen = B.length s len' = len + sLen addBatch :: ([TransportBatch r], Int, Int, [ByteString], [r]) -> [TransportBatch r] From b435a4dacbdbda7830fe4118e1e205a104801ed9 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Wed, 6 Mar 2024 16:38:30 +0000 Subject: [PATCH 10/30] envelope sizes dependent on PQ encryption (#1028) * envelope sizes dependent on PQ encryption (WIP) * add "supported" flag to ratchets, update this flag on ratchet resync * change connection PQ status on sendMessage * comment, fix * refactor --- src/Simplex/Messaging/Agent.hs | 168 +++++++++++--------- src/Simplex/Messaging/Agent/Protocol.hs | 16 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 9 +- src/Simplex/Messaging/Crypto/Ratchet.hs | 95 +++++++---- tests/AgentTests/DoubleRatchetTests.hs | 36 +++-- 5 files changed, 191 insertions(+), 133 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 6f3099ad6..d0a232131 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -9,6 +9,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -140,6 +141,7 @@ import Data.Text (Text) import qualified Data.Text as T import Data.Time.Clock import Data.Time.Clock.System (systemToUTCTime) +import Data.Traversable (mapAccumL) import Data.Word (Word16) import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFileRemote, deleteSndFilesInternal, deleteSndFilesRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile') import Simplex.FileTransfer.Description (ValidFileDescription) @@ -158,6 +160,7 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client (ProtocolClient (..), ServerTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) +import Simplex.Messaging.Crypto.Ratchet (PQEncryption, pattern PQEncOn, pattern PQEncOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -222,7 +225,7 @@ createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => A createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs -- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id -joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:: joinConnAsync c userId aCorrId enableNtfs -- | Allow connection to continue after CONF notification (LET command), no synchronous response @@ -230,7 +233,7 @@ allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c -- | Accept contact after REQ notification (ACPT command) asynchronously, synchronous response is new connection id -acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:: acceptContactAsync' c aCorrId enableNtfs -- | Acknowledge message (ACK command) asynchronously, no synchronous response @@ -254,7 +257,7 @@ createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConne createConnection c userId enableNtfs = withAgentEnv c .:: newConn c userId "" enableNtfs -- | Join SMP agent connection (JOIN command) -joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId joinConnection c userId enableNtfs = withAgentEnv c .:: joinConn c userId "" enableNtfs -- | Allow connection to continue after CONF notification (LET command) @@ -262,7 +265,7 @@ allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId allowConnection c = withAgentEnv c .:. allowConnection' c -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId acceptContact c enableNtfs = withAgentEnv c .:: acceptContact' c "" enableNtfs -- | Reject contact (RJCT command) @@ -292,16 +295,16 @@ resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map resubscribeConnections c = withAgentEnv c . resubscribeConnections' c -- | Send message to the connection (SEND command) -sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> CR.PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, CR.PQEncryption) +sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, PQEncryption) sendMessage c = withAgentEnv c .:: sendMessage' c -type MsgReq = (ConnId, CR.PQEncryption, MsgFlags, MsgBody) +type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) -sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] +sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, PQEncryption)] sendMessages c = withAgentEnv c . sendMessages' c -sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB c = withAgentEnv c . sendMessagesB' c ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () @@ -316,7 +319,7 @@ abortConnectionSwitch :: AgentErrorMonad m => AgentClient -> ConnId -> m Connect abortConnectionSwitch c = withAgentEnv c . abortConnectionSwitch' c -- | Re-synchronize connection ratchet keys -synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> CR.PQEncryption -> Bool -> m ConnectionStats +synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> PQEncryption -> Bool -> m ConnectionStats synchronizeRatchet c = withAgentEnv c .:. synchronizeRatchet' c -- | Suspend SMP agent connection (OFF command) @@ -555,14 +558,14 @@ newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) pqInitKeys subMode pure connId -newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> CR.PQEncryption -> m ConnId +newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> PQEncryption -> m ConnId newConnNoQueues c userId connId enableNtfs cMode pqEncryption = do g <- asks random connAgentVersion <- asks $ maxVersion . ($ pqEncryption) . smpAgentVRange . config let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} withStore c $ \db -> createNewConn db g cData cMode -joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo pqEncryption subMode = do withInvLock c (strEncode cReqUri) "joinConnAsync" $ do aVRange <- asks $ ($ pqEncryption) . smpAgentVRange . config @@ -584,7 +587,7 @@ allowConnectionAsync' c corrId connId confId ownConnInfo = enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo _ -> throwError $ CMD PROHIBITED -acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqEnc subMode = do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case @@ -678,7 +681,7 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eVRange) -joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> @@ -686,7 +689,7 @@ joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do _ -> getSMPServer c userId joinConnSrv c userId connId enableNtfs cReq cInfo pqEnc subMode srv -startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> CR.PQEncryption -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) +startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQEncryption -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqEncryption = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config let e2eVRange = e2eEncryptVRange pqEncryption @@ -706,7 +709,7 @@ startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {cr pure (aVersion, cData, q, rc, e2eSndParams) _ -> throwError $ AGENT A_VERSION -joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m ConnId +joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m ConnId joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqEnc @@ -716,7 +719,7 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMod liftIO $ createRatchet db connId' rc pure r let cData' = (cData :: ConnData) {connId = connId'} - tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) (Just pqEnc) subMode) >>= \case + tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) pqEnc subMode) >>= \case Right _ -> pure connId' Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md @@ -734,7 +737,7 @@ joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRan pure connId' _ -> throwError $ AGENT A_VERSION -joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m () +joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m () joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMode srv = do (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqEnc q' <- withStore c $ \db -> runExceptT $ do @@ -772,7 +775,7 @@ allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConne _ -> throwError $ CMD PROHIBITED -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> CR.PQEncryption -> SubscriptionMode -> m ConnId +acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId acceptContact' c connId enableNtfs invId ownConnInfo pqEnc subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case @@ -913,29 +916,36 @@ getNotificationMessage' c nonce encNtfInfo = do Nothing -> SMP.notification msgFlags -- | Send message to the connection (SEND command) in Reader monad -sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> CR.PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, CR.PQEncryption) +sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, PQEncryption) sendMessage' c connId pqEnc msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg))) -- | Send multiple messages to different connections (SEND command) in Reader monad -sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, CR.PQEncryption)] +sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, PQEncryption)] sendMessages' c = sendMessagesB' c . map Right -sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) - let reqs'' = fmap (>>= prepareConn) reqs' + let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' + void $ withStoreBatch' c $ \db -> map (enableConnPQEncryption db) toEnable enqueueMessagesB c reqs'' where - prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) - prepareConn ((_, pqEnc, msgFlags, msg), SomeConn _ conn) = case conn of + prepareConn :: [ConnId] -> Either AgentErrorType (MsgReq, SomeConn) -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareConn acc (Left e) = (acc, Left e) + prepareConn acc (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of DuplexConnection cData _ sqs -> prepareMsg cData sqs SndConnection cData sq -> prepareMsg cData [sq] - _ -> Left $ CONN SIMPLEX + _ -> (acc, Left $ CONN SIMPLEX) where - prepareMsg :: ConnData -> NonEmpty SndQueue -> Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) - prepareMsg cData sqs - | ratchetSyncSendProhibited cData = Left $ CMD PROHIBITED - | otherwise = Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg) + prepareMsg :: ConnData -> NonEmpty SndQueue -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareMsg cData@ConnData {connId, pqEncryption} sqs + | ratchetSyncSendProhibited cData = (acc, Left $ CMD PROHIBITED) + -- connection is only updated if PQ encryption was disabled, and now it has to be enabled. + -- support for PQ encryption (small message envelopes) will not be disabled when message is sent. + | pqEnc == PQEncOn && pqEncryption == PQEncOff = + let cData' = cData {pqEncryption = pqEnc} :: ConnData + in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg)) + | otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg)) connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs -- / async command processing v v v @@ -1007,7 +1017,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do _ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd) ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do secure rq senderKey - void $ enqueueMessage c cData sq Nothing SMP.MsgFlags {notification = True} HELLO + void $ enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO -- ICDeleteConn is no longer used, but it can be present in old client databases ICDeleteConn -> withStore' c (`deleteCommand` cmdId) ICDeleteRcvQueue rId -> withServer $ \srv -> tryWithLock "ICDeleteRcvQueue" $ do @@ -1022,7 +1032,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do Just rq1 -> when (status == Confirmed) $ do secureQueue c rq' senderKey withStore' c $ \db -> setRcvQueueStatus db rq' Secured - void . enqueueMessages c cData sqs Nothing SMP.noMsgFlags $ QUSE [((server, sndId), True)] + void . enqueueMessages c cData sqs SMP.noMsgFlags $ QUSE [((server, sndId), True)] rq1' <- withStore' c $ \db -> setRcvSwitchStatus db rq1 $ Just RSSendingQUSE let rqs' = updatedQs rq1' rqs conn' = DuplexConnection cData rqs' sqs @@ -1085,16 +1095,16 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd) -- ^ ^ ^ async command processing / -enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) -enqueueMessages c cData sqs pqEnc_ msgFlags aMessage = do +enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, PQEncryption) +enqueueMessages c cData sqs msgFlags aMessage = do when (ratchetSyncSendProhibited cData) $ throwError $ INTERNAL "enqueueMessages: ratchet is not synchronized" - enqueueMessages' c cData sqs pqEnc_ msgFlags aMessage + enqueueMessages' c cData sqs msgFlags aMessage -enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) -enqueueMessages' c cData sqs pqEnc_ msgFlags aMessage = - liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, pqEnc_, msgFlags, aMessage))) +enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) +enqueueMessages' c cData sqs msgFlags aMessage = + liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, Nothing, msgFlags, aMessage))) -enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, CR.PQEncryption))) +enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) enqueueMessagesB c reqs = do reqs' <- enqueueMessageB c reqs enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs' @@ -1103,12 +1113,12 @@ enqueueMessagesB c reqs = do isActiveSndQ :: SndQueue -> Bool isActiveSndQ SndQueue {status} = status == Secured || status == Active -enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe CR.PQEncryption -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) -enqueueMessage c cData sq pqEnc_ msgFlags aMessage = - liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], pqEnc_, msgFlags, aMessage))) +enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, PQEncryption) +enqueueMessage c cData sq msgFlags aMessage = + liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], Nothing, msgFlags, aMessage))) -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, CR.PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) +enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) enqueueMessageB c reqs = do getAVRange <- asks $ smpAgentVRange . config reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db getAVRange) reqs @@ -1117,22 +1127,23 @@ enqueueMessageB c reqs = do let sqs' = filter isActiveSndQ sqs pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> (CR.PQEncryption -> VersionRangeSMPA) -> (ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe CR.PQEncryption, MsgFlags, AMessage), InternalId, CR.PQEncryption)) - storeSentMsg db getAVRange req@(ConnData {connId, connAgentVersion = v}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + storeSentMsg :: DB.Connection -> (PQEncryption -> VersionRangeSMPA) -> (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage), InternalId, PQEncryption)) + storeSentMsg db getAVRange req@(cData@ConnData {connId, pqEncryption}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash agentMsg = AgentMessage privHeader aMessage agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - (encAgentMessage, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr (e2eEncUserMsgLength v) pqEnc_ - let agentVersion = maxVersion . getAVRange $ fromMaybe CR.PQEncOff pqEnc_ + (encAgentMessage, pqEnc) <- agentRatchetEncrypt db cData agentMsgStr e2eEncUserMsgLength pqEnc_ + -- agent version range is determined by the connection suppport of PQ encryption, that is may be enabled when message is sent + let agentVersion = maxVersion $ getAVRange pqEncryption msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, pqEncryption, internalHash, prevMsgHash} + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, pqEncryption = pqEnc, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId - pure (req, internalId, pqEncryption) + pure (req, internalId, pqEnc) enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId) @@ -1344,7 +1355,7 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do when (connAgentVersion >= deliveryRcptsSMPAgentVersion) $ do let RcvMsg {msgMeta = MsgMeta {sndMsgId}, internalHash} = msg rcpt = A_RCVD [AMessageReceipt {agentMsgId = sndMsgId, msgHash = internalHash, rcptInfo}] - void $ enqueueMessages c cData sqs Nothing SMP.MsgFlags {notification = False} rcpt + void $ enqueueMessages c cData sqs SMP.MsgFlags {notification = False} rcpt Nothing -> case (msgType, msgReceipt) of -- only remove sent message if receipt hash was Ok, both to debug and for future redundancy (AM_A_RCVD_, Just MsgReceipt {agentMsgId = sndMsgId, msgRcptStatus = MROk}) -> @@ -1374,7 +1385,7 @@ switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs s let rq' = (q :: NewRcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} rq'' <- withStore c $ \db -> addConnRcvQueue db connId rq' addSubscription c rq'' - void . enqueueMessages c cData sqs Nothing SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] + void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSendingQADD let rqs' = updatedQs rq1 rqs <> [rq''] pure . connectionStats $ DuplexConnection cData rqs' sqs @@ -1403,16 +1414,20 @@ abortConnectionSwitch' c connId = _ -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED -synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> CR.PQEncryption -> Bool -> m ConnectionStats +synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> PQEncryption -> Bool -> m ConnectionStats synchronizeRatchet' c connId pqEnc force = withConnLock c connId "synchronizeRatchet" $ do withStore c (`getConn` connId) >>= \case SomeConn _ (DuplexConnection cData@ConnData {pqEncryption} rqs sqs) | ratchetSyncAllowed cData || force -> do -- check queues are not switching? - cData' <- if pqEncryption == pqEnc then pure cData else withStore' c $ \db -> setConnPQEncryption db cData pqEnc + pqEnc' <- + if pqEnc == PQEncOn && pqEncryption == PQEncOff + then PQEncOn <$ withStore' c (`enableConnPQEncryption` connId) + else pure pqEncryption + let cData' = cData {pqEncryption = pqEnc'} :: ConnData AgentConfig {e2eEncryptVRange} <- asks config g <- asks random - (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion $ e2eEncryptVRange pqEnc) pqEnc + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion $ e2eEncryptVRange pqEnc') pqEnc' enqueueRatchetKeyMsgs c cData' sqs e2eParams withStore' c $ \db -> do setConnRatchetSync db connId RSStarted @@ -1948,7 +1963,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') where queueDrained = case conn of - DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs Nothing SMP.noMsgFlags $ QCONT (sndAddress rq) + DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq) _ -> pure () processClientMsg srvTs msgFlags msgBody = do clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <- @@ -2200,7 +2215,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, enqueueDuplexHello :: SndQueue -> m () enqueueDuplexHello sq = do let cData' = toConnData conn' - void $ enqueueMessage c cData' sq Nothing SMP.MsgFlags {notification = True} HELLO + void $ enqueueMessage c cData' sq SMP.MsgFlags {notification = True} HELLO continueSending :: SMP.MsgId -> (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m () continueSending srvMsgId addr (DuplexConnection _ _ sqs) = @@ -2257,7 +2272,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, (Just sndPubKey, Just dhPublicKey) -> do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId <> " " <> logSecret (senderId queueAddress) let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}} - void . enqueueMessages c cData' sqs Nothing SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)] + void . enqueueMessages c cData' sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)] sq1 <- withStore' c $ \db -> setSndSwitchStatus db sq $ Just SSSendingQKEY let sqs'' = updatedQs sq1 sqs' <> [sq2] conn' = DuplexConnection cData' rqs sqs'' @@ -2302,7 +2317,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, withStore' c $ \db -> setSndQueueStatus db sq' Secured let sq'' = (sq' :: SndQueue) {status = Secured} -- sending QTEST to the new queue only, the old one will be removed if sent successfully - void $ enqueueMessages c cData' [sq''] Nothing SMP.noMsgFlags $ QTEST [addr] + void $ enqueueMessages c cData' [sq''] SMP.noMsgFlags $ QTEST [addr] sq1' <- withStore' c $ \db -> setSndSwitchStatus db sq1 $ Just SSSendingQTEST let sqs' = updatedQs sq1' sqs conn' = DuplexConnection cData' rqs sqs' @@ -2318,7 +2333,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, let CR.Ratchet {rcSnd} = rcPrev -- if ratchet was initialized as receiving, it means EREADY wasn't sent on key negotiation when (isNothing rcSnd) . void $ - enqueueMessages' c cData' sqs Nothing SMP.MsgFlags {notification = True} (EREADY lastExternalSndId) + enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} (EREADY lastExternalSndId) smpInvitation :: SMP.MsgId -> Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () smpInvitation srvMsgId conn' connReq@(CRInvitationUri crData _) cInfo = do @@ -2401,7 +2416,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, (_, rcDHRs) <- atomically . C.generateKeyPair =<< asks random rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 (CR.APRKP CR.SRKSProposed <$> pKem) e2eOtherPartyParams recreateRatchet $ CR.initSndRatchet rcVs k2Rcv rcDHRs rcParams - void . enqueueMessages' c cData' sqs Nothing SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId + void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash @@ -2428,22 +2443,22 @@ switchStatusError q expected actual = <> (", actual=" <> show actual) connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m () -connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do +connectReplyQueues c cData@ConnData {userId, connId, pqEncryption} ownConnInfo (qInfo :| _) = do clientVRange <- asks $ smpClientVRange . config case qInfo `proveCompatible` clientVRange of Nothing -> throwError $ AGENT A_VERSION Just qInfo' -> do sq <- newSndQueue userId connId qInfo' sq' <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq - enqueueConfirmation c cData sq' ownConnInfo Nothing Nothing + enqueueConfirmation c cData sq' ownConnInfo Nothing pqEncryption -confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> CR.PQEncryption -> SubscriptionMode -> m () +confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> SubscriptionMode -> m () confirmQueueAsync c cData sq srv connInfo e2eEncryption_ pqEnc subMode = do - storeConfirmation c cData sq e2eEncryption_ (Just pqEnc) =<< mkAgentConfirmation c cData sq srv connInfo subMode + storeConfirmation c cData sq e2eEncryption_ pqEnc =<< mkAgentConfirmation c cData sq srv connInfo subMode submitPendingMsg c cData sq -confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> SubscriptionMode -> m () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId, connAgentVersion = v} sq srv connInfo e2eEncryption_ pqEnc_ subMode = do +confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> SubscriptionMode -> m () +confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ pqEnc subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed @@ -2451,7 +2466,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId, connAgentVersio mkConfirmation :: AgentMessage -> m MsgBody mkConfirmation aMessage = withStore c $ \db -> runExceptT $ do void . liftIO $ updateSndIds db connId - (encConnInfo, _) <- agentRatchetEncrypt db connId (smpEncode aMessage) (e2eEncConnInfoLength v) pqEnc_ + (encConnInfo, _) <- agentRatchetEncrypt db cData (smpEncode aMessage) e2eEncConnInfoLength (Just pqEnc) pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} mkAgentConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> m AgentMessage @@ -2459,18 +2474,18 @@ mkAgentConfirmation c cData sq srv connInfo subMode = do qInfo <- createReplyQueue c cData sq subMode srv pure $ AgentConnInfoReply (qInfo :| []) connInfo -enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> m () -enqueueConfirmation c cData sq connInfo e2eEncryption_ pqEnc_ = do - storeConfirmation c cData sq e2eEncryption_ pqEnc_ $ AgentConnInfo connInfo +enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> m () +enqueueConfirmation c cData sq connInfo e2eEncryption_ pqEnc = do + storeConfirmation c cData sq e2eEncryption_ pqEnc $ AgentConnInfo connInfo submitPendingMsg c cData sq -storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> AgentMessage -> m () -storeConfirmation c ConnData {connId, connAgentVersion = v} sq e2eEncryption_ pqEnc_ agentMsg = withStore c $ \db -> runExceptT $ do +storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> AgentMessage -> m () +storeConfirmation c cData@ConnData {connId, connAgentVersion = v} sq e2eEncryption_ pqEnc agentMsg = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - (encConnInfo, pqEncryption) <- agentRatchetEncrypt db connId agentMsgStr (e2eEncConnInfoLength v) pqEnc_ + (encConnInfo, pqEncryption) <- agentRatchetEncrypt db cData agentMsgStr e2eEncConnInfoLength (Just pqEnc) let msgBody = smpEncode $ AgentConfirmation {agentVersion = v, e2eEncryption_, encConnInfo} msgType = agentMessageType agentMsg msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} @@ -2504,20 +2519,21 @@ enqueueRatchetKey c cData@ConnData {connId, pqEncryption} sq e2eEncryption = do pure internalId -- encoded AgentMessage -> encoded EncAgentMessage -agentRatchetEncrypt :: DB.Connection -> ConnId -> ByteString -> Int -> Maybe CR.PQEncryption -> ExceptT StoreError IO (ByteString, CR.PQEncryption) -agentRatchetEncrypt db connId msg paddedLen pqEnc_ = do +agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (PQEncryption -> Int) -> Maybe PQEncryption -> ExceptT StoreError IO (ByteString, PQEncryption) +agentRatchetEncrypt db ConnData {connId, pqEncryption} msg getPaddedLen pqEnc_ = do rc <- ExceptT $ getRatchet db connId + let paddedLen = getPaddedLen pqEncryption (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ liftIO $ updateRatchet db connId rc' CR.SMDNoChange pure (encMsg, CR.rcSndKEM rc') -- encoded EncAgentMessage -> encoded AgentMessage -agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, CR.PQEncryption) +agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption) agentRatchetDecrypt g db connId encAgentMsg = do rc <- ExceptT $ getRatchet db connId agentRatchetDecrypt' g db connId rc encAgentMsg -agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, CR.PQEncryption) +agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption) agentRatchetDecrypt' g db connId rc encAgentMsg = do skipped <- liftIO $ getSkippedMsgKeys db connId (agentMsgBody_, rc', skippedDiff) <- liftE (SEAgentError . cryptoError) $ CR.rcDecrypt g rc skipped encAgentMsg diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index af271531b..d3008e790 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -282,17 +282,17 @@ supportedSMPAgentVRange pq = -- it is shorter to allow all handshake headers, -- including E2E (double-ratchet) parameters and -- signing key of the sender for the server -e2eEncConnInfoLength :: VersionSMPA -> Int -e2eEncConnInfoLength v +e2eEncConnInfoLength :: PQEncryption -> Int +e2eEncConnInfoLength = \case -- reduced by 3700 (roughly the increase of message ratchet header size + key and ciphertext in reply link) - | v >= pqdrSMPAgentVersion = 11148 - | otherwise = 14848 + PQEncOn -> 11148 + PQEncOff -> 14848 -e2eEncUserMsgLength :: VersionSMPA -> Int -e2eEncUserMsgLength v +e2eEncUserMsgLength :: PQEncryption -> Int +e2eEncUserMsgLength = \case -- reduced by 2200 (roughly the increase of message ratchet header size) - | v >= pqdrSMPAgentVersion = 13656 - | otherwise = 15856 + PQEncOn -> 13656 + PQEncOff -> 15856 -- | Raw (unparsed) SMP agent protocol transmission. type ARawTransmission = (ByteString, ByteString, ByteString) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 78660ce6c..f04bc7904 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -58,7 +58,7 @@ module Simplex.Messaging.Agent.Store.SQLite getConnData, setConnDeleted, setConnAgentVersion, - setConnPQEncryption, + enableConnPQEncryption, getDeletedConnIds, getDeletedWaitingDeliveryConnIds, setConnRatchetSync, @@ -1952,10 +1952,9 @@ setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnAgentVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) -setConnPQEncryption :: DB.Connection -> ConnData -> CR.PQEncryption -> IO ConnData -setConnPQEncryption db cData@ConnData {connId} pqEnc = do - DB.execute db "UPDATE connections SET pq_encryption = ? WHERE conn_id = ?" (pqEnc, connId) - pure (cData :: ConnData) {pqEncryption = pqEnc} +enableConnPQEncryption :: DB.Connection -> ConnId -> IO () +enableConnPQEncryption db connId = + DB.execute db "UPDATE connections SET pq_encryption = ? WHERE conn_id = ?" (CR.PQEncOn, connId) getDeletedConnIds :: DB.Connection -> IO [ConnId] getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True) diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 9d4b919ca..b0292f9b5 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -466,6 +466,7 @@ data Ratchet a = Ratchet rcAD :: Str, rcDHRs :: PrivateKey a, rcKEM :: Maybe RatchetKEM, + rcSupportKEM :: PQEncryption, -- defines header size, can only be enabled once rcEnableKEM :: PQEncryption, -- will enable KEM on the next ratchet step rcSndKEM :: PQEncryption, -- used KEM hybrid secret for sending ratchet rcRcvKEM :: PQEncryption, -- used KEM hybrid secret for receiving ratchet @@ -596,12 +597,14 @@ initSndRatchet :: initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted) + pqEnc = PQEncryption $ isJust rcPQRs_ in Ratchet { rcVersion, rcAD = assocData, rcDHRs, rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, - rcEnableKEM = PQEncryption $ isJust rcPQRs_, + rcSupportKEM = pqEnc, + rcEnableKEM = pqEnc, rcSndKEM = PQEncryption $ isJust kemAccepted, rcRcvKEM = PQEncOff, rcRK, @@ -622,7 +625,7 @@ initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey -- as part of the connection request and random salt was received from the sender. initRcvRatchet :: forall a. (AlgorithmI a, DhAlgorithm a) => RatchetVersions -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a -initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM = +initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) pqEnc = Ratchet { rcVersion, rcAD = assocData, @@ -633,7 +636,8 @@ initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK -- state.PQRss = None -- state.PQRct = None rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, - rcEnableKEM, + rcSupportKEM = pqEnc, + rcEnableKEM = pqEnc, rcSndKEM = PQEncOff, rcRcvKEM = PQEncOff, rcRK = ratchetKey, @@ -662,15 +666,15 @@ data MsgHeader a = MsgHeader -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 -- TODO PQ this must be version-dependent -- TODO this is the exact size, some reserve should be added -paddedHeaderLen :: VersionE2E -> Int -paddedHeaderLen v - | v >= pqRatchetE2EEncryptVersion = 2284 - | otherwise = 88 +paddedHeaderLen :: PQEncryption -> Int +paddedHeaderLen = \case + PQEncOn -> 2288 + PQEncOff -> 88 -- only used in tests to validate correct padding -- (2 bytes - version size, 1 byte - header size, not to have it fixed or version-dependent) -fullHeaderLen :: VersionE2E -> Int -fullHeaderLen v = 2 + 1 + paddedHeaderLen v + authTagSize + ivSize @AES256 +fullHeaderLen :: PQEncryption -> Int +fullHeaderLen pq = 2 + 1 + paddedHeaderLen pq + authTagSize + ivSize @AES256 -- pass the current version, as MsgHeader only includes the max supported version that can be different from the current encodeMsgHeader :: AlgorithmI a => VersionE2E -> MsgHeader a -> ByteString @@ -698,13 +702,27 @@ data EncMessageHeader = EncMessageHeader -- this encoding depends on version in EncMessageHeader because it is "current" ratchet version instance Encoding EncMessageHeader where smpEncode EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} - | ehVersion >= pqRatchetE2EEncryptVersion = smpEncode (ehVersion, ehIV, ehAuthTag, Large ehBody) - | otherwise = smpEncode (ehVersion, ehIV, ehAuthTag, ehBody) + = smpEncode (ehVersion, ehIV, ehAuthTag) <> encodeLarge ehVersion ehBody smpP = do (ehVersion, ehIV, ehAuthTag) <- smpP - ehBody <- if ehVersion >= pqRatchetE2EEncryptVersion then unLarge <$> smpP else smpP + ehBody <- largeP pure EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} +-- the encoder always uses 2-byte lengths for the new version, even for short headers without PQ keys. +encodeLarge :: VersionE2E -> ByteString -> ByteString +encodeLarge v s + -- the condition for length is not necessary, it's here as a fallback. + | v >= pqRatchetE2EEncryptVersion || B.length s > 255 = smpEncode $ Large s + | otherwise = smpEncode s + +-- This parser relies on the fact that header cannot be shorter than 32 bytes (it is ~69 bytes without PQ KEM), +-- therefore if the first byte is less or equal to 31 (x1F), then we have 2 byte-length limited to 8191. +-- This allows upgrading the current version in one message. +largeP :: Parser ByteString +largeP = do + len1 <- peekWord8' + if len1 < 32 then unLarge <$> smpP else smpP + -- the header is length-prefixed to parse it as string and use as part of associated data for authenticated encryption data EncRatchetMessage = EncRatchetMessage { emHeader :: ByteString, @@ -712,19 +730,13 @@ data EncRatchetMessage = EncRatchetMessage emBody :: ByteString } --- the encoder always uses 2-byte lengths for the new version, even for short headers without PQ keys. encodeEncRatchetMessage :: VersionE2E -> EncRatchetMessage -> ByteString encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} - | v >= pqRatchetE2EEncryptVersion = smpEncode (Large emHeader, emAuthTag, Tail emBody) - | otherwise = smpEncode (emHeader, emAuthTag, Tail emBody) + = encodeLarge v emHeader <> smpEncode (emAuthTag, Tail emBody) --- This parser relies on the fact that header cannot be shorter than 32 bytes (it is ~69 bytes without PQ KEM), --- therefore if the first byte is less or equal to 31 (x1F), then we have 2 byte-length limited to 8191. --- This allows upgrading the current version in one message. encRatchetMessageP :: Parser EncRatchetMessage encRatchetMessageP = do - len1 <- peekWord8' - emHeader <- if len1 < 32 then unLarge <$> smpP else smpP + emHeader <- largeP (emAuthTag, Tail emBody) <- smpP pure EncRatchetMessage {emHeader, emBody, emAuthTag} @@ -801,30 +813,42 @@ joinContactInitialKeys = \case rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> ExceptT CryptoError IO (ByteString, Ratchet a) rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ = throwE CERatchetState -rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcVersion} paddedMsgLen msg pqEnc_ = do +rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcSupportKEM, rcEnableKEM, rcVersion} paddedMsgLen msg pqEnc_ = do -- state.CKs, mk = KDF_CK(state.CKs) let (ck', mk, iv, ehIV) = chainKdf rcCKs + v = current rcVersion + -- PQ encryption can be enabled or disabled + rcEnableKEM' = fromMaybe rcEnableKEM pqEnc_ + -- support for PQ encryption (and therefore large headers/small envelopes) can only be enabled, it cannot be disabled + rcSupportKEM' = PQEncryption $ enablePQ rcSupportKEM || enablePQ rcEnableKEM' -- enc_header = HENCRYPT(state.HKs, header) - let v = current rcVersion - (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen v) rcAD (msgHeader v) + (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen rcSupportKEM') rcAD (msgHeader v) -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) let emHeader = smpEncode EncMessageHeader {ehVersion = v, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg let msg' = encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} - -- state.Ns += 1 -- TODO v5.8 remove comments below -- Note that maxSupported will not downgrade here below current. -- TODO v5.7 remove comments below - -- It will downgrade when decrypting the message when the current version downgrades to remove support for PQ encryption. - -- TODO v5.8 replace `max v currentE2EEncryptVersion` with `v` (to allow downgrade when app downgraded) - rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1, rcVersion = rcVersion {maxSupported = max v currentE2EEncryptVersion}} + -- TODO PQ It will downgrade when decrypting the message when the current version downgrades to remove support for PQ encryption. + -- TODO v5.8 possibly, replace `max v currentE2EEncryptVersion` with `v` (to allow downgrade when app downgraded)? + -- + -- state.Ns += 1 + rc' = + rc + { rcSnd = Just sr {rcCKs = ck'}, + rcNs = rcNs + 1, + rcSupportKEM = rcSupportKEM', + rcEnableKEM = rcEnableKEM', + rcVersion = rcVersion {maxSupported = max v currentE2EEncryptVersion} + } rc'' = case pqEnc_ of Nothing -> rc' -- This sets max version to support PQ encryption. -- Current version upgrade happens when peer decrypts the message. -- TODO v5.7 remove version upgrade here, as it's already upgraded above - Just PQEncOn -> rc' {rcEnableKEM = PQEncOn, rcVersion = rcVersion {maxSupported = pqRatchetE2EEncryptVersion}} - Just PQEncOff -> rc' {rcEnableKEM = PQEncOff, rcKEM = (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM} + Just PQEncOn -> rc' {rcVersion = rcVersion {maxSupported = max v pqRatchetE2EEncryptVersion}} + Just PQEncOff -> rc' {rcKEM = (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM} pure (msg', rc'') where -- header = HEADER_PQ2( @@ -870,7 +894,6 @@ rcDecrypt :: ByteString -> ExceptT CryptoError IO (DecryptResult a) rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do - -- TODO PQ versioning should change encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError encRatchetMessageP msg' encHdr <- parseE CryptoHeaderError smpP emHeader -- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD) @@ -909,12 +932,14 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do where upgradedRatchet :: Ratchet a upgradedRatchet - | msgMaxVersion > current rcVersion = rc {rcVersion = rcVersion {current = min msgMaxVersion $ maxSupported rcVersion}} + | msgMaxVersion > current = rc {rcVersion = rcVersion {current = max current $ min msgMaxVersion maxSupported}} | otherwise = rc + where + RVersions {current, maxSupported} = rcVersion smkDiff :: SkippedMsgKeys -> SkippedMsgDiff smkDiff smks = if M.null smks then SMDNoChange else SMDAdd smks ratchetStep :: Ratchet a -> MsgHeader a -> ExceptT CryptoError IO (Ratchet a) - ratchetStep rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr} MsgHeader {msgDHRs, msgKEM} = do + ratchetStep rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr, rcSupportKEM} MsgHeader {msgDHRs, msgKEM} = do (kemSS, kemSS', rcKEM') <- pqRatchetStep rc' msgKEM -- state.DHRs = GENERATE_DH() (_, rcDHRs') <- atomically $ generateKeyPair @a g @@ -924,11 +949,13 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' kemSS' sndKEM = isJust kemSS' rcvKEM = isJust kemSS + enableKEM = sndKEM || rcvKEM || isJust rcKEM' pure rc' { rcDHRs = rcDHRs', rcKEM = rcKEM', - rcEnableKEM = PQEncryption $ sndKEM || rcvKEM, + rcSupportKEM = PQEncryption $ enablePQ rcSupportKEM || enableKEM, + rcEnableKEM = PQEncryption enableKEM, rcSndKEM = PQEncryption sndKEM, rcRcvKEM = PQEncryption rcvKEM, rcRK = rcRK'', @@ -945,7 +972,7 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do -- received message does not have KEM in header, -- but the user enabled KEM when sending previous message Nothing -> case rcKEM of - Nothing | pqEnc -> do + Nothing | pqEnc && current rv >= pqRatchetE2EEncryptVersion -> do rcPQRs <- liftIO $ sntrup761Keypair g pure (Nothing, Nothing, Just RatchetKEM {rcPQRs, rcKEMs = Nothing}) _ -> pure (Nothing, Nothing, Nothing) diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 0e4b9079e..5c5241849 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -89,10 +89,16 @@ testAlgs test = test C.SX25519 >> test C.SX448 paddedMsgLen :: Int paddedMsgLen = 100 -fullMsgLen :: VersionE2E -> Int -fullMsgLen v = headerLenLength + fullHeaderLen v + C.authTagSize + paddedMsgLen +fullMsgLen :: Ratchet a -> Int +fullMsgLen Ratchet {rcSupportKEM} = headerLenLength + fullHeaderLen rcSupportKEM + C.authTagSize + paddedMsgLen where - headerLenLength = if v < pqRatchetE2EEncryptVersion then 1 else 3 -- two bytes are added because of two Large used in new encoding + -- v = current rcVersion + headerLenLength = case rcSupportKEM of + PQEncOn -> 3 -- two bytes are added because of two Large used in new encoding + PQEncOff -> 1 + -- TODO PQ below should work too + -- | v >= pqRatchetE2EEncryptVersion = 3 + -- | otherwise = 1 testMessageHeader :: forall a. AlgorithmI a => VersionE2E -> C.SAlgorithm a -> Expectation testMessageHeader v _ = do @@ -302,8 +308,10 @@ testEnableKEM alice bob _ _ _ = do (alice, "accepting KEM") \#>! bob (alice, "KEM not enabled yet here too") \#>! bob (bob, "KEM is still not enabled") \#>! alice - (alice, "now KEM is enabled") !#> bob - (bob, "now KEM is enabled for both sides") !#> alice + (alice, "KEM still not enabled 2") \#>! bob + (bob, "now KEM is enabled") !#> alice + (alice, "now KEM is enabled for both sides") !#> bob + (bob, "Still enabled for both sides") !#> alice (alice, "disabling KEM") !#>\ bob (bob, "KEM not disabled yet") !#> alice (alice, "KEM disabled") \#> bob @@ -318,12 +326,20 @@ testEnableKEMStrict alice bob _ _ _ = do (alice, "accepting KEM") \#>! bob (alice, "KEM not enabled yet here too") \#>! bob (bob, "KEM is still not enabled") \#>! alice - (alice, "now KEM is enabled") !#>! bob - (bob, "now KEM is enabled for both sides") !#>! alice + (alice, "KEM still not enabled 2") \#>! bob + (bob, "now KEM is enabled") !#>! alice + (alice, "now KEM is enabled for both sides") !#>! bob + (bob, "Still enabled for both sides") !#>! alice (alice, "disabling KEM") !#>\ bob (bob, "KEM not disabled yet") !#>! alice (alice, "KEM disabled") \#>\ bob (bob, "KEM disabled on both sides") \#>! alice + (alice, "KEM still disabled 1") \#>\ bob + (bob, "KEM still disabled 2") \#>! alice + (alice, "KEM still disabled 3") \#>\ bob + (bob, "KEM still disabled 4") \#>! alice + (alice, "KEM still disabled 5") \#>\ bob + (bob, "KEM still disabled 6") \#>! alice testKeyJSON :: forall a. AlgorithmI a => C.SAlgorithm a -> IO () testKeyJSON _ = do @@ -571,13 +587,13 @@ testRatchetVersions pq = in RVersions v v encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) -encrypt_ enableKem (_, rc, _) msg = +encrypt_ pqEnc_ (_, rc, _) msg = -- print msg >> - runExceptT (rcEncrypt rc paddedMsgLen msg enableKem) + runExceptT (rcEncrypt rc paddedMsgLen msg pqEnc_) >>= either (pure . Left) checkLength where checkLength (msg', rc') = do - B.length msg' `shouldBe` fullMsgLen (current $ rcVersion rc) + B.length msg' `shouldBe` fullMsgLen rc' pure $ Right (msg', rc', SMDNoChange) decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff)) From 4ffb6a348a06cd87ec7d456bca14e155c1b6310d Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Wed, 6 Mar 2024 21:28:03 +0000 Subject: [PATCH 11/30] pqdr: use different newtypes for supporting and enabling PQ encryption in connections (#1031) * pqdr: use different newtypes for supporting and enabling PQ encryption in connections * rename field, fix test * refactor --- src/Simplex/Messaging/Agent.hs | 167 +++++++++--------- src/Simplex/Messaging/Agent/Env/SQLite.hs | 6 +- src/Simplex/Messaging/Agent/Protocol.hs | 36 ++-- src/Simplex/Messaging/Agent/Store.hs | 4 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 34 ++-- .../Migrations/M20240225_ratchet_kem.hs | 4 +- .../Store/SQLite/Migrations/agent_schema.sql | 2 +- src/Simplex/Messaging/Crypto/Ratchet.hs | 116 ++++++++---- tests/AgentTests.hs | 62 +++---- tests/AgentTests/ConnectionRequestTests.hs | 4 +- tests/AgentTests/DoubleRatchetTests.hs | 45 ++--- tests/AgentTests/FunctionalAPITests.hs | 113 ++++++------ tests/AgentTests/SQLiteTests.hs | 6 +- 13 files changed, 324 insertions(+), 275 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index d0a232131..204647ef6 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -160,7 +160,7 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client (ProtocolClient (..), ServerTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) -import Simplex.Messaging.Crypto.Ratchet (PQEncryption, pattern PQEncOn, pattern PQEncOff) +import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport, pattern PQEncOn, pattern PQEncOff, pattern PQSupportOn, pattern PQSupportOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -225,7 +225,7 @@ createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => A createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs -- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id -joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId +joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:: joinConnAsync c userId aCorrId enableNtfs -- | Allow connection to continue after CONF notification (LET command), no synchronous response @@ -233,7 +233,7 @@ allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c -- | Accept contact after REQ notification (ACPT command) asynchronously, synchronous response is new connection id -acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId +acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:: acceptContactAsync' c aCorrId enableNtfs -- | Acknowledge message (ACK command) asynchronously, no synchronous response @@ -257,7 +257,7 @@ createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConne createConnection c userId enableNtfs = withAgentEnv c .:: newConn c userId "" enableNtfs -- | Join SMP agent connection (JOIN command) -joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId +joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId joinConnection c userId enableNtfs = withAgentEnv c .:: joinConn c userId "" enableNtfs -- | Allow connection to continue after CONF notification (LET command) @@ -265,7 +265,7 @@ allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId allowConnection c = withAgentEnv c .:. allowConnection' c -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId +acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId acceptContact c enableNtfs = withAgentEnv c .:: acceptContact' c "" enableNtfs -- | Reject contact (RJCT command) @@ -319,7 +319,7 @@ abortConnectionSwitch :: AgentErrorMonad m => AgentClient -> ConnId -> m Connect abortConnectionSwitch c = withAgentEnv c . abortConnectionSwitch' c -- | Re-synchronize connection ratchet keys -synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> PQEncryption -> Bool -> m ConnectionStats +synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> PQSupport -> Bool -> m ConnectionStats synchronizeRatchet c = withAgentEnv c .:. synchronizeRatchet' c -- | Suspend SMP agent connection (OFF command) @@ -558,23 +558,23 @@ newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) pqInitKeys subMode pure connId -newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> PQEncryption -> m ConnId -newConnNoQueues c userId connId enableNtfs cMode pqEncryption = do +newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> PQSupport -> m ConnId +newConnNoQueues c userId connId enableNtfs cMode pqSupport = do g <- asks random - connAgentVersion <- asks $ maxVersion . ($ pqEncryption) . smpAgentVRange . config - let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} + connAgentVersion <- asks $ maxVersion . ($ pqSupport) . smpAgentVRange . config + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} withStore c $ \db -> createNewConn db g cData cMode -joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId -joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo pqEncryption subMode = do +joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo pqSupport subMode = do withInvLock c (strEncode cReqUri) "joinConnAsync" $ do - aVRange <- asks $ ($ pqEncryption) . smpAgentVRange . config + aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config case crAgentVRange `compatibleVersion` aVRange of Just (Compatible connAgentVersion) -> do g <- asks random - let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} + let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqEncryption subMode cInfo + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqSupport subMode cInfo pure connId _ -> throwError $ AGENT A_VERSION joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo _pqEncryption = @@ -587,13 +587,13 @@ allowConnectionAsync' c corrId connId confId ownConnInfo = enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo _ -> throwError $ CMD PROHIBITED -acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId -acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqEnc subMode = do +acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqSupport subMode = do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqEnc subMode `catchAgentError` \err -> do + joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -681,45 +681,45 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eVRange) -joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId -joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do +joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConn c userId connId enableNtfs cReq cInfo pqSupport subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> getNextServer c userId [qServer q] _ -> getSMPServer c userId - joinConnSrv c userId connId enableNtfs cReq cInfo pqEnc subMode srv + joinConnSrv c userId connId enableNtfs cReq cInfo pqSupport subMode srv -startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQEncryption -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) -startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqEncryption = do +startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) +startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqSupport = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - let e2eVRange = e2eEncryptVRange pqEncryption + let e2eVRange = e2eEncryptVRange pqSupport case ( qUri `compatibleVersion` smpClientVRange, e2eRcvParamsUri `compatibleVersion` e2eVRange, - crAgentVRange `compatibleVersion` smpAgentVRange pqEncryption + crAgentVRange `compatibleVersion` smpAgentVRange pqSupport ) of (Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), Just aVersion@(Compatible connAgentVersion)) -> do g <- asks random - (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ pqEncryption kem_) + (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ kem_ pqSupport) (_, rcDHRs) <- atomically $ C.generateKeyPair g rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams let rcVs = CR.RVersions {current = v, maxSupported = maxVersion e2eVRange} rc = CR.initSndRatchet rcVs rcDHRr rcDHRs rcParams q <- newSndQueue userId "" qInfo - let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqEncryption} + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} pure (aVersion, cData, q, rc, e2eSndParams) _ -> throwError $ AGENT A_VERSION -joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m ConnId -joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMode srv = +joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m ConnId +joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do - (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqEnc + (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSupport g <- asks random (connId', sq) <- withStore c $ \db -> runExceptT $ do r@(connId', _) <- ExceptT $ createSndConn db g cData q liftIO $ createRatchet db connId' rc pure r let cData' = (cData :: ConnData) {connId = connId'} - tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) pqEnc subMode) >>= \case + tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case Right _ -> pure connId' Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md @@ -737,14 +737,14 @@ joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRan pure connId' _ -> throwError $ AGENT A_VERSION -joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQEncryption -> SubscriptionMode -> SMPServerWithAuth -> m () -joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqEnc subMode srv = do - (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqEnc +joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m () +joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do + (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSupport q' <- withStore c $ \db -> runExceptT $ do liftIO $ createRatchet db connId rc ExceptT $ updateNewConnSnd db connId q - confirmQueueAsync c cData q' srv cInfo (Just e2eSndParams) pqEnc subMode -joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _pqEnc _srv = do + confirmQueueAsync c cData q' srv cInfo (Just e2eSndParams) subMode +joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _pqSupport _srv = do throwError $ CMD PROHIBITED createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> m SMPQueueInfo @@ -775,13 +775,13 @@ allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConne _ -> throwError $ CMD PROHIBITED -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQEncryption -> SubscriptionMode -> m ConnId -acceptContact' c connId enableNtfs invId ownConnInfo pqEnc subMode = withConnLock c connId "acceptContact" $ do +acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConn c userId connId enableNtfs connReq ownConnInfo pqEnc subMode `catchAgentError` \err -> do + joinConn c userId connId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -927,7 +927,7 @@ sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' - void $ withStoreBatch' c $ \db -> map (enableConnPQEncryption db) toEnable + void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable enqueueMessagesB c reqs'' where prepareConn :: [ConnId] -> Either AgentErrorType (MsgReq, SomeConn) -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) @@ -938,12 +938,12 @@ sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do _ -> (acc, Left $ CONN SIMPLEX) where prepareMsg :: ConnData -> NonEmpty SndQueue -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) - prepareMsg cData@ConnData {connId, pqEncryption} sqs + prepareMsg cData@ConnData {connId, pqSupport} sqs | ratchetSyncSendProhibited cData = (acc, Left $ CMD PROHIBITED) -- connection is only updated if PQ encryption was disabled, and now it has to be enabled. -- support for PQ encryption (small message envelopes) will not be disabled when message is sent. - | pqEnc == PQEncOn && pqEncryption == PQEncOff = - let cData' = cData {pqEncryption = pqEnc} :: ConnData + | pqEnc == PQEncOn && pqSupport == PQSupportOff = + let cData' = cData {pqSupport = PQSupportOn} :: ConnData in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg)) | otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg)) connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs @@ -1127,8 +1127,8 @@ enqueueMessageB c reqs = do let sqs' = filter isActiveSndQ sqs pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> (PQEncryption -> VersionRangeSMPA) -> (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage), InternalId, PQEncryption)) - storeSentMsg db getAVRange req@(cData@ConnData {connId, pqEncryption}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + storeSentMsg :: DB.Connection -> (PQSupport -> VersionRangeSMPA) -> (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage), InternalId, PQEncryption)) + storeSentMsg db getAVRange req@(cData@ConnData {connId, pqSupport}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash @@ -1137,7 +1137,7 @@ enqueueMessageB c reqs = do internalHash = C.sha256Hash agentMsgStr (encAgentMessage, pqEnc) <- agentRatchetEncrypt db cData agentMsgStr e2eEncUserMsgLength pqEnc_ -- agent version range is determined by the connection suppport of PQ encryption, that is may be enabled when message is sent - let agentVersion = maxVersion $ getAVRange pqEncryption + let agentVersion = maxVersion $ getAVRange pqSupport msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} msgType = agentMessageType agentMsg msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, pqEncryption = pqEnc, internalHash, prevMsgHash} @@ -1414,20 +1414,17 @@ abortConnectionSwitch' c connId = _ -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED -synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> PQEncryption -> Bool -> m ConnectionStats -synchronizeRatchet' c connId pqEnc force = withConnLock c connId "synchronizeRatchet" $ do +synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> PQSupport -> Bool -> m ConnectionStats +synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchronizeRatchet" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection cData@ConnData {pqEncryption} rqs sqs) + SomeConn _ (DuplexConnection cData@ConnData {pqSupport} rqs sqs) | ratchetSyncAllowed cData || force -> do -- check queues are not switching? - pqEnc' <- - if pqEnc == PQEncOn && pqEncryption == PQEncOff - then PQEncOn <$ withStore' c (`enableConnPQEncryption` connId) - else pure pqEncryption - let cData' = cData {pqEncryption = pqEnc'} :: ConnData + when (pqSupport' /= pqSupport) $ withStore' c $ \db -> setConnPQSupport db connId pqSupport' + let cData' = cData {pqSupport = pqSupport'} :: ConnData AgentConfig {e2eEncryptVRange} <- asks config g <- asks random - (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion $ e2eEncryptVRange pqEnc') pqEnc' + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion $ e2eEncryptVRange pqSupport') pqSupport' enqueueRatchetKeyMsgs c cData' sqs e2eParams withStore' c $ \db -> do setConnRatchetSync db connId RSStarted @@ -2084,8 +2081,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> prohibited >> ack _ -> prohibited >> ack updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> m (Connection c) - updateConnVersion conn' cData'@ConnData {pqEncryption} msgAgentVersion = do - aVRange <- asks $ ($ pqEncryption) . smpAgentVRange . config + updateConnVersion conn' cData'@ConnData {pqSupport} msgAgentVersion = do + aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion case msgAVRange `compatibleVersion` aVRange of Just (Compatible av) @@ -2146,14 +2143,13 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, parseMessage :: Encoding a => ByteString -> m a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) - -- TODO PQ make sure pqEncryption in conn' is set correctly smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> m () smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - let ConnData {pqEncryption} = toConnData conn' - aVRange = smpAgentVRange pqEncryption - e2eVRange = e2eEncryptVRange pqEncryption + let ConnData {pqSupport} = toConnData conn' + aVRange = smpAgentVRange pqSupport + e2eVRange = e2eEncryptVRange pqSupport unless (agentVersion `isCompatible` aVRange && smpClientVersion `isCompatible` smpClientVRange) (throwError $ AGENT A_VERSION) @@ -2166,7 +2162,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams -- TODO PQ combine isCompatible check and construction in one call let rcVs = CR.RVersions {current = e2eVersion, maxSupported = maxVersion e2eVRange} - rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqEncryption + rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqSupport g <- asks random (agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt g rc M.empty encConnInfo case (agentMsgBody_, skipped) of @@ -2354,10 +2350,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- TODO PQ make sure pqEncryption is set correctly here newRatchetKey :: CR.RcvE2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () - newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv _) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId, pqEncryption} _ sqs) = + newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv _) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId, pqSupport} _ sqs) = unlessM ratchetExists $ do AgentConfig {e2eEncryptVRange} <- asks config - let connE2EVRange = e2eEncryptVRange pqEncryption + let connE2EVRange = e2eEncryptVRange pqSupport unless (e2eVersion `isCompatible` connE2EVRange) (throwError $ AGENT A_VERSION) keys <- getSendRatchetKeys -- TODO PQ combine with `isCompatible` check above @@ -2388,7 +2384,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, where sendReplyKey = do g <- asks random - (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g e2eVersion pqEncryption + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g e2eVersion pqSupport enqueueRatchetKeyMsgs c cData' sqs e2eParams pure (pk1, pk2, pKem) notifyRatchetSyncError = do @@ -2411,7 +2407,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, initRatchet rcVs (pk1, pk2, pKem) | rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 pk2 pKem e2eOtherPartyParams - recreateRatchet $ CR.initRcvRatchet rcVs pk2 rcParams pqEncryption + recreateRatchet $ CR.initRcvRatchet rcVs pk2 rcParams pqSupport | otherwise = do (_, rcDHRs) <- atomically . C.generateKeyPair =<< asks random rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 (CR.APRKP CR.SRKSProposed <$> pKem) e2eOtherPartyParams @@ -2443,22 +2439,22 @@ switchStatusError q expected actual = <> (", actual=" <> show actual) connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m () -connectReplyQueues c cData@ConnData {userId, connId, pqEncryption} ownConnInfo (qInfo :| _) = do +connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do clientVRange <- asks $ smpClientVRange . config case qInfo `proveCompatible` clientVRange of Nothing -> throwError $ AGENT A_VERSION Just qInfo' -> do sq <- newSndQueue userId connId qInfo' sq' <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq - enqueueConfirmation c cData sq' ownConnInfo Nothing pqEncryption + enqueueConfirmation c cData sq' ownConnInfo Nothing -confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> SubscriptionMode -> m () -confirmQueueAsync c cData sq srv connInfo e2eEncryption_ pqEnc subMode = do - storeConfirmation c cData sq e2eEncryption_ pqEnc =<< mkAgentConfirmation c cData sq srv connInfo subMode +confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> m () +confirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do + storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode submitPendingMsg c cData sq -confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> SubscriptionMode -> m () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ pqEnc subMode = do +confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> m () +confirmQueue (Compatible agentVersion) c cData@ConnData {connId, pqSupport} sq srv connInfo e2eEncryption_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed @@ -2466,6 +2462,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo mkConfirmation :: AgentMessage -> m MsgBody mkConfirmation aMessage = withStore c $ \db -> runExceptT $ do void . liftIO $ updateSndIds db connId + let pqEnc = CR.pqSupportToEnc pqSupport (encConnInfo, _) <- agentRatchetEncrypt db cData (smpEncode aMessage) e2eEncConnInfoLength (Just pqEnc) pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} @@ -2474,17 +2471,18 @@ mkAgentConfirmation c cData sq srv connInfo subMode = do qInfo <- createReplyQueue c cData sq subMode srv pure $ AgentConnInfoReply (qInfo :| []) connInfo -enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> m () -enqueueConfirmation c cData sq connInfo e2eEncryption_ pqEnc = do - storeConfirmation c cData sq e2eEncryption_ pqEnc $ AgentConnInfo connInfo +enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> m () +enqueueConfirmation c cData sq connInfo e2eEncryption_ = do + storeConfirmation c cData sq e2eEncryption_ $ AgentConnInfo connInfo submitPendingMsg c cData sq -storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> PQEncryption -> AgentMessage -> m () -storeConfirmation c cData@ConnData {connId, connAgentVersion = v} sq e2eEncryption_ pqEnc agentMsg = withStore c $ \db -> runExceptT $ do +storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AgentMessage -> m () +storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq e2eEncryption_ agentMsg = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr + pqEnc = CR.pqSupportToEnc pqSupport (encConnInfo, pqEncryption) <- agentRatchetEncrypt db cData agentMsgStr e2eEncConnInfoLength (Just pqEnc) let msgBody = smpEncode $ AgentConfirmation {agentVersion = v, e2eEncryption_, encConnInfo} msgType = agentMessageType agentMsg @@ -2498,8 +2496,8 @@ enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do mapM_ (enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m AgentMsgId -enqueueRatchetKey c cData@ConnData {connId, pqEncryption} sq e2eEncryption = do - aVRange <- asks $ ($ pqEncryption) . smpAgentVRange . config +enqueueRatchetKey c cData@ConnData {connId, pqSupport} sq e2eEncryption = do + aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config msgId <- storeRatchetKey $ maxVersion aVRange submitPendingMsg c cData sq pure $ unId msgId @@ -2513,16 +2511,17 @@ enqueueRatchetKey c cData@ConnData {connId, pqEncryption} sq e2eEncryption = do internalHash = C.sha256Hash agentMsgStr let msgBody = smpEncode $ AgentRatchetKey {agentVersion, e2eEncryption, info = agentMsgStr} msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} + -- this message is e2e encrypted with queue key, not with double ratchet + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption = PQEncOff, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId pure internalId -- encoded AgentMessage -> encoded EncAgentMessage -agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (PQEncryption -> Int) -> Maybe PQEncryption -> ExceptT StoreError IO (ByteString, PQEncryption) -agentRatchetEncrypt db ConnData {connId, pqEncryption} msg getPaddedLen pqEnc_ = do +agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (PQSupport -> Int) -> Maybe PQEncryption -> ExceptT StoreError IO (ByteString, PQEncryption) +agentRatchetEncrypt db ConnData {connId, pqSupport} msg getPaddedLen pqEnc_ = do rc <- ExceptT $ getRatchet db connId - let paddedLen = getPaddedLen pqEncryption + let paddedLen = getPaddedLen pqSupport (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ liftIO $ updateRatchet db connId rc' CR.SMDNoChange pure (encMsg, CR.rcSndKEM rc') diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index a292e5db6..20a378a45 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -56,7 +56,7 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (PQEncryption, VersionRangeE2E, supportedE2EEncryptVRange) +import Simplex.Messaging.Crypto.Ratchet (PQSupport, VersionRangeE2E, supportedE2EEncryptVRange) import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig) import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types @@ -116,8 +116,8 @@ data AgentConfig = AgentConfig caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, - e2eEncryptVRange :: PQEncryption -> VersionRangeE2E, - smpAgentVRange :: PQEncryption -> VersionRangeSMPA, + e2eEncryptVRange :: PQSupport -> VersionRangeE2E, + smpAgentVRange :: PQSupport -> VersionRangeSMPA, smpClientVRange :: VersionRangeSMPC } diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index d3008e790..02aa5e260 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -192,7 +192,9 @@ import Simplex.Messaging.Crypto.Ratchet ( InitialKeys (..), PQEncryption (..), pattern PQEncOff, - pattern PQEncOn, + PQSupport, + pattern PQSupportOn, + pattern PQSupportOff, RcvE2ERatchetParams, RcvE2ERatchetParamsUri, SndE2ERatchetParams @@ -272,27 +274,27 @@ pqdrSMPAgentVersion = VersionSMPA 5 currentSMPAgentVersion :: VersionSMPA currentSMPAgentVersion = VersionSMPA 4 --- TODO v5.7 remove dependency of version range on whether PQ encryption is used -supportedSMPAgentVRange :: PQEncryption -> VersionRangeSMPA +-- TODO v5.7 remove dependency of version range on whether PQ support is needed +supportedSMPAgentVRange :: PQSupport -> VersionRangeSMPA supportedSMPAgentVRange pq = mkVersionRange duplexHandshakeSMPAgentVersion $ case pq of - PQEncOn -> pqdrSMPAgentVersion - PQEncOff -> currentSMPAgentVersion + PQSupportOn -> pqdrSMPAgentVersion + PQSupportOff -> currentSMPAgentVersion -- it is shorter to allow all handshake headers, -- including E2E (double-ratchet) parameters and -- signing key of the sender for the server -e2eEncConnInfoLength :: PQEncryption -> Int +e2eEncConnInfoLength :: PQSupport -> Int e2eEncConnInfoLength = \case -- reduced by 3700 (roughly the increase of message ratchet header size + key and ciphertext in reply link) - PQEncOn -> 11148 - PQEncOff -> 14848 + PQSupportOn -> 11148 + PQSupportOff -> 14848 -e2eEncUserMsgLength :: PQEncryption -> Int +e2eEncUserMsgLength :: PQSupport -> Int e2eEncUserMsgLength = \case -- reduced by 2200 (roughly the increase of message ratchet header size) - PQEncOn -> 13656 - PQEncOff -> 15856 + PQSupportOn -> 13656 + PQSupportOff -> 15856 -- | Raw (unparsed) SMP agent protocol transmission. type ARawTransmission = (ByteString, ByteString, ByteString) @@ -371,11 +373,11 @@ type ConnInfo = ByteString data ACommand (p :: AParty) (e :: AEntity) where NEW :: Bool -> AConnectionMode -> InitialKeys -> SubscriptionMode -> ACommand Client AEConn -- response INV INV :: AConnectionRequestUri -> ACommand Agent AEConn - JOIN :: Bool -> AConnectionRequestUri -> PQEncryption -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK + JOIN :: Bool -> AConnectionRequestUri -> PQSupport -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake LET :: ConfirmationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender - ACPT :: InvitationId -> PQEncryption -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client + ACPT :: InvitationId -> PQSupport -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client RJCT :: InvitationId -> ACommand Client AEConn INFO :: ConnInfo -> ACommand Agent AEConn CON :: PQEncryption -> ACommand Agent AEConn -- notification that connection is established @@ -1732,9 +1734,9 @@ commandP binaryP = ACmdTag SClient e cmd -> ACmd SClient e <$> case cmd of NEW_ -> s (NEW <$> strP_ <*> strP_ <*> pqIKP <*> (strP <|> pure SMP.SMSubscribe)) - JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> pqEncP <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP) + JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> pqSupP <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP) LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP) - ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> pqEncP <*> binaryP) + ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> binaryP) RJCT_ -> s (RJCT <$> A.takeByteString) SUB_ -> pure SUB SEND_ -> s (SEND <$> pqEncP <*> smpP <* A.space <*> binaryP) @@ -1781,7 +1783,9 @@ commandP binaryP = s :: Parser a -> Parser a s p = A.space *> p pqIKP :: Parser InitialKeys - pqIKP = strP_ <|> pure (IKNoPQ PQEncOff) + pqIKP = strP_ <|> pure (IKNoPQ PQSupportOff) + pqSupP :: Parser PQSupport + pqSupP = strP_ <|> pure PQSupportOff pqEncP :: Parser PQEncryption pqEncP = strP_ <|> pure PQEncOff connections :: Parser [ConnId] diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 971b38905..ce76d5c89 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -30,7 +30,7 @@ import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval (RI2State) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (RatchetX448, PQEncryption) +import Simplex.Messaging.Crypto.Ratchet (RatchetX448, PQEncryption, PQSupport) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol ( MsgBody, @@ -309,7 +309,7 @@ data ConnData = ConnData lastExternalSndId :: PrevExternalSndId, deleted :: Bool, ratchetSyncState :: RatchetSyncState, - pqEncryption :: PQEncryption + pqSupport :: PQSupport } deriving (Eq, Show) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index f04bc7904..33051d234 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -58,7 +58,7 @@ module Simplex.Messaging.Agent.Store.SQLite getConnData, setConnDeleted, setConnAgentVersion, - enableConnPQEncryption, + setConnPQSupport, getDeletedConnIds, getDeletedWaitingDeliveryConnIds, setConnRatchetSync, @@ -268,7 +268,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations (DownMigration (..), MTRE import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) -import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) +import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..)) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -576,14 +576,14 @@ createSndConn db gVar cData q@SndQueue {server} = insertSndQueue_ db connId q serverKeyHash_ createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO () -createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqEncryption} cMode = +createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqSupport} cMode = DB.execute db [sql| INSERT INTO connections - (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_encryption, duplex_handshake) VALUES (?,?,?,?,?,?,?) + (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_support, duplex_handshake) VALUES (?,?,?,?,?,?,?) |] - (userId, connId, cMode, connAgentVersion, enableNtfs, pqEncryption, True) + (userId, connId, cMode, connAgentVersion, enableNtfs, pqSupport, True) checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do @@ -1032,7 +1032,7 @@ getPendingQueueMsg db connId SndQueue {dbQueueId} = |] (connId, msgId) err = SEInternal $ "msg delivery " <> bshow msgId <> " returned []" - pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, CR.PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData + pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) = let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_ msgRetryState = RI2State <$> riSlow_ <*> riFast_ @@ -1130,7 +1130,7 @@ getLastMsg db connId msgId = |] (connId, msgId) -toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, CR.PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg +toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) = let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption} msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ @@ -1776,9 +1776,13 @@ instance ToField (Version v) where toField (Version v) = toField v instance FromField (Version v) where fromField f = Version <$> fromField f -instance ToField CR.PQEncryption where toField (CR.PQEncryption pqEnc) = toField pqEnc +instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField pqEnc -instance FromField CR.PQEncryption where fromField f = CR.PQEncryption <$> fromField f +instance FromField PQEncryption where fromField f = PQEncryption <$> fromField f + +instance ToField PQSupport where toField (PQSupport pqEnc) = toField pqEnc + +instance FromField PQSupport where fromField f = PQSupport <$> fromField f listToEither :: e -> [a] -> Either e a listToEither _ (x : _) = Right x @@ -1931,14 +1935,14 @@ getConnData db connId' = [sql| SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, - last_external_snd_msg_id, deleted, ratchet_sync_state, pq_encryption + last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support FROM connections WHERE conn_id = ? |] (Only connId') where - cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqEncryption) = - (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqEncryption}, cMode) + cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport) = + (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () setConnDeleted db waitDelivery connId @@ -1952,9 +1956,9 @@ setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnAgentVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) -enableConnPQEncryption :: DB.Connection -> ConnId -> IO () -enableConnPQEncryption db connId = - DB.execute db "UPDATE connections SET pq_encryption = ? WHERE conn_id = ?" (CR.PQEncOn, connId) +setConnPQSupport :: DB.Connection -> ConnId -> PQSupport -> IO () +setConnPQSupport db connId pqSupport = + DB.execute db "UPDATE connections SET pq_support = ? WHERE conn_id = ?" (pqSupport, connId) getDeletedConnIds :: DB.Connection -> IO [ConnId] getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs index 07ba0f135..1e8a8db4d 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs @@ -9,7 +9,7 @@ m20240225_ratchet_kem :: Query m20240225_ratchet_kem = [sql| ALTER TABLE ratchets ADD COLUMN pq_priv_kem BLOB; -ALTER TABLE connections ADD COLUMN pq_encryption INTEGER NOT NULL DEFAULT 0; +ALTER TABLE connections ADD COLUMN pq_support INTEGER NOT NULL DEFAULT 0; ALTER TABLE messages ADD COLUMN pq_encryption INTEGER NOT NULL DEFAULT 0; |] @@ -17,6 +17,6 @@ down_m20240225_ratchet_kem :: Query down_m20240225_ratchet_kem = [sql| ALTER TABLE ratchets DROP COLUMN pq_priv_kem; -ALTER TABLE connections DROP COLUMN pq_encryption; +ALTER TABLE connections DROP COLUMN pq_support; ALTER TABLE messages DROP COLUMN pq_encryption; |] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index 850199cbb..0818be904 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -28,7 +28,7 @@ CREATE TABLE connections( REFERENCES users ON DELETE CASCADE, ratchet_sync_state TEXT NOT NULL DEFAULT 'ok', deleted_at_wait_delivery TEXT, - pq_encryption INTEGER NOT NULL DEFAULT 0 + pq_support INTEGER NOT NULL DEFAULT 0 ) WITHOUT ROWID; CREATE TABLE rcv_queues( host TEXT NOT NULL, diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index b0292f9b5..38ada0f01 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -28,6 +28,9 @@ module Simplex.Messaging.Crypto.Ratchet PQEncryption (..), pattern PQEncOn, pattern PQEncOff, + PQSupport (..), + pattern PQSupportOn, + pattern PQSupportOff, AUseKEM (..), RatchetKEMState (..), SRatchetKEMState (..), @@ -53,6 +56,8 @@ module Simplex.Messaging.Crypto.Ratchet connPQEncryption, joinContactInitialKeys, replyKEM_, + pqSupportToEnc, + pqEncToSupport, pqX3dhSnd, pqX3dhRcv, initSndRatchet, @@ -143,11 +148,11 @@ currentE2EEncryptVersion :: VersionE2E currentE2EEncryptVersion = VersionE2E 2 -- TODO v5.7 remove dependency of version range on whether PQ encryption is used -supportedE2EEncryptVRange :: PQEncryption -> VersionRangeE2E +supportedE2EEncryptVRange :: PQSupport -> VersionRangeE2E supportedE2EEncryptVRange pq = mkVersionRange kdfX3DHE2EEncryptVersion $ case pq of - PQEncOn -> pqRatchetE2EEncryptVersion - PQEncOff -> currentE2EEncryptVersion + PQSupportOn -> pqRatchetE2EEncryptVersion + PQSupportOff -> currentE2EEncryptVersion data RatchetKEMState = RKSProposed -- only KEM encapsulation key @@ -385,14 +390,13 @@ generateE2EParams g v useKEM_ = do _ -> pure Nothing -- used by party initiating connection, Bob in double-ratchet spec -generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a) +generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> PQSupport -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a) generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_ where - proposeKEM_ :: PQEncryption -> Maybe (UseKEM 'RKSProposed) + proposeKEM_ :: PQSupport -> Maybe (UseKEM 'RKSProposed) proposeKEM_ = \case - PQEncOn -> Just ProposeKEM - PQEncOff -> Nothing - + PQSupportOn -> Just ProposeKEM + PQSupportOff -> Nothing -- used by party accepting connection, Alice in double-ratchet spec generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a) @@ -466,7 +470,7 @@ data Ratchet a = Ratchet rcAD :: Str, rcDHRs :: PrivateKey a, rcKEM :: Maybe RatchetKEM, - rcSupportKEM :: PQEncryption, -- defines header size, can only be enabled once + rcSupportKEM :: PQSupport, -- defines header size, can only be enabled once rcEnableKEM :: PQEncryption, -- will enable KEM on the next ratchet step rcSndKEM :: PQEncryption, -- used KEM hybrid secret for sending ratchet rcRcvKEM :: PQEncryption, -- used KEM hybrid secret for receiving ratchet @@ -597,14 +601,14 @@ initSndRatchet :: initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted) - pqEnc = PQEncryption $ isJust rcPQRs_ + pqOn = isJust rcPQRs_ in Ratchet { rcVersion, rcAD = assocData, rcDHRs, rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, - rcSupportKEM = pqEnc, - rcEnableKEM = pqEnc, + rcSupportKEM = PQSupport pqOn, + rcEnableKEM = PQEncryption pqOn, rcSndKEM = PQEncryption $ isJust kemAccepted, rcRcvKEM = PQEncOff, rcRK, @@ -624,8 +628,8 @@ initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey -- Please note that the public part of rcDHRs was sent to the sender -- as part of the connection request and random salt was received from the sender. initRcvRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => RatchetVersions -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a -initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) pqEnc = + forall a. (AlgorithmI a, DhAlgorithm a) => RatchetVersions -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQSupport -> Ratchet a +initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) pqSupport = Ratchet { rcVersion, rcAD = assocData, @@ -636,8 +640,8 @@ initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK -- state.PQRss = None -- state.PQRct = None rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, - rcSupportKEM = pqEnc, - rcEnableKEM = pqEnc, + rcSupportKEM = pqSupport, + rcEnableKEM = pqSupportToEnc pqSupport, rcSndKEM = PQEncOff, rcRcvKEM = PQEncOff, rcRK = ratchetKey, @@ -666,14 +670,14 @@ data MsgHeader a = MsgHeader -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 -- TODO PQ this must be version-dependent -- TODO this is the exact size, some reserve should be added -paddedHeaderLen :: PQEncryption -> Int +paddedHeaderLen :: PQSupport -> Int paddedHeaderLen = \case - PQEncOn -> 2288 - PQEncOff -> 88 + PQSupportOn -> 2288 + PQSupportOff -> 88 -- only used in tests to validate correct padding -- (2 bytes - version size, 1 byte - header size, not to have it fixed or version-dependent) -fullHeaderLen :: PQEncryption -> Int +fullHeaderLen :: PQSupport -> Int fullHeaderLen pq = 2 + 1 + paddedHeaderLen pq + authTagSize + ivSize @AES256 -- pass the current version, as MsgHeader only includes the max supported version that can be different from the current @@ -759,12 +763,40 @@ instance FromJSON PQEncryption where parseJSON v = PQEncryption <$> parseJSON v omittedField = Just PQEncOff -replyKEM_ :: PQEncryption -> Maybe (RKEMParams 'RKSProposed) -> Maybe AUseKEM -replyKEM_ pqEnc kem_ = case pqEnc of - PQEncOn -> Just $ case kem_ of +newtype PQSupport = PQSupport {supportPQ :: Bool} + deriving (Eq, Show) + +pattern PQSupportOn :: PQSupport +pattern PQSupportOn = PQSupport True + +pattern PQSupportOff :: PQSupport +pattern PQSupportOff = PQSupport False + +{-# COMPLETE PQSupportOn, PQSupportOff #-} + +instance ToJSON PQSupport where + toEncoding (PQSupport pq) = toEncoding pq + toJSON (PQSupport pq) = toJSON pq + +instance FromJSON PQSupport where + parseJSON v = PQSupport <$> parseJSON v + omittedField = Just PQSupportOff + +pqSupportToEnc :: PQSupport -> PQEncryption +pqSupportToEnc (PQSupport pq) = PQEncryption pq + +pqEncToSupport :: PQEncryption -> PQSupport +pqEncToSupport (PQEncryption pq) = PQSupport pq + +supportOrEnc :: PQSupport -> PQEncryption -> PQSupport +supportOrEnc (PQSupport sup) (PQEncryption enc) = PQSupport $ sup || enc + +replyKEM_ :: Maybe (RKEMParams 'RKSProposed) -> PQSupport -> Maybe AUseKEM +replyKEM_ kem_ = \case + PQSupportOn -> Just $ case kem_ of Just (RKParamsProposed k) -> AUseKEM SRKSAccepted $ AcceptKEM k Nothing -> AUseKEM SRKSProposed ProposeKEM - PQEncOff -> Nothing + PQSupportOff -> Nothing instance StrEncoding PQEncryption where strEncode pqMode @@ -778,14 +810,20 @@ instance StrEncoding PQEncryption where where pq = pure . PQEncryption -data InitialKeys = IKUsePQ | IKNoPQ PQEncryption +instance StrEncoding PQSupport where + strEncode = strEncode . pqSupportToEnc + {-# INLINE strEncode #-} + strP = pqEncToSupport <$> strP + {-# INLINE strP #-} + +data InitialKeys = IKUsePQ | IKNoPQ PQSupport deriving (Eq, Show) pattern IKPQOn :: InitialKeys -pattern IKPQOn = IKNoPQ PQEncOn +pattern IKPQOn = IKNoPQ PQSupportOn pattern IKPQOff :: InitialKeys -pattern IKPQOff = IKNoPQ PQEncOff +pattern IKPQOff = IKNoPQ PQSupportOff instance StrEncoding InitialKeys where strEncode = \case @@ -794,22 +832,22 @@ instance StrEncoding InitialKeys where strP = IKNoPQ <$> strP <|> "pq=invitation" $> IKUsePQ -- determines whether PQ key should be included in invitation link -initialPQEncryption :: InitialKeys -> PQEncryption +initialPQEncryption :: InitialKeys -> PQSupport initialPQEncryption = \case - IKUsePQ -> PQEncOn - IKNoPQ _ -> PQEncOff -- default + IKUsePQ -> PQSupportOn + IKNoPQ _ -> PQSupportOff -- default -- determines whether PQ encryption should be used in connection -connPQEncryption :: InitialKeys -> PQEncryption +connPQEncryption :: InitialKeys -> PQSupport connPQEncryption = \case - IKUsePQ -> PQEncOn + IKUsePQ -> PQSupportOn IKNoPQ pq -> pq -- default for creating connection is IKNoPQ PQEncOn -- determines whether PQ key should be included in invitation link sent to contact address -joinContactInitialKeys :: PQEncryption -> InitialKeys +joinContactInitialKeys :: PQSupport -> InitialKeys joinContactInitialKeys = \case - PQEncOn -> IKUsePQ -- default - PQEncOff -> IKNoPQ PQEncOff + PQSupportOn -> IKUsePQ -- default + PQSupportOff -> IKNoPQ PQSupportOff rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> ExceptT CryptoError IO (ByteString, Ratchet a) rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ = throwE CERatchetState @@ -820,7 +858,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, -- PQ encryption can be enabled or disabled rcEnableKEM' = fromMaybe rcEnableKEM pqEnc_ -- support for PQ encryption (and therefore large headers/small envelopes) can only be enabled, it cannot be disabled - rcSupportKEM' = PQEncryption $ enablePQ rcSupportKEM || enablePQ rcEnableKEM' + rcSupportKEM' = rcSupportKEM `supportOrEnc` rcEnableKEM' -- enc_header = HENCRYPT(state.HKs, header) (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen rcSupportKEM') rcAD (msgHeader v) -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) @@ -949,13 +987,13 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' kemSS' sndKEM = isJust kemSS' rcvKEM = isJust kemSS - enableKEM = sndKEM || rcvKEM || isJust rcKEM' + rcEnableKEM' = PQEncryption $ sndKEM || rcvKEM || isJust rcKEM' pure rc' { rcDHRs = rcDHRs', rcKEM = rcKEM', - rcSupportKEM = PQEncryption $ enablePQ rcSupportKEM || enableKEM, - rcEnableKEM = PQEncryption enableKEM, + rcSupportKEM = rcSupportKEM `supportOrEnc` rcEnableKEM', + rcEnableKEM = rcEnableKEM', rcSndKEM = PQEncryption sndKEM, rcRcvKEM = PQEncryption rcvKEM, rcRK = rcRK'', diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 6ff9db523..bb91725f8 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -29,7 +29,7 @@ import SMPAgentClient import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) import Simplex.Messaging.Agent.Protocol hiding (MID) import qualified Simplex.Messaging.Agent.Protocol as A -import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQEncOff) +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQSupportOn, pattern PQSupportOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (ErrorType (..)) @@ -173,7 +173,7 @@ type PQMatrix2 c = HasCallStack => TProxy c -> (HasCallStack => (c -> c -> IO ()) -> Expectation) -> - (HasCallStack => (c, InitialKeys) -> (c, PQEncryption) -> IO ()) -> + (HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> IO ()) -> Spec pqMatrix2 :: PQMatrix2 c @@ -184,34 +184,34 @@ pqMatrix2NoInv = pqMatrix2_ False pqMatrix2_ :: Bool -> PQMatrix2 c pqMatrix2_ pqInv _ smpTest test = do - it "dh/dh handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQEncOff) - it "dh/pq handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQEncOn) - it "pq/dh handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQEncOff) - it "pq/pq handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQEncOn) + it "dh/dh handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOff) + it "dh/pq handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOn) + it "pq/dh handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOff) + it "pq/pq handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOn) when pqInv $ do - it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQEncOff) - it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQEncOn) + it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOff) + it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOn) pqMatrix3 :: HasCallStack => TProxy c -> (HasCallStack => (c -> c -> c -> IO ()) -> Expectation) -> - (HasCallStack => (c, InitialKeys) -> (c, PQEncryption) -> (c, PQEncryption) -> IO ()) -> + (HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()) -> Spec pqMatrix3 _ smpTest test = do - it "dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOff) (c, PQEncOff) - it "dh/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOff) (c, PQEncOn) - it "dh/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOn) (c, PQEncOff) - it "dh/pq/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQEncOn) (c, PQEncOn) - it "pq/dh/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOff) (c, PQEncOff) - it "pq/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOff) (c, PQEncOn) - it "pq/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOn) (c, PQEncOff) - it "pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQEncOn) (c, PQEncOn) + it "dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOff) + it "dh/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOn) + it "dh/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOff) + it "dh/pq/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOn) + it "pq/dh/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOff) + it "pq/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOn) + it "pq/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOff) + it "pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOn) testDuplexConnection :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () -testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, PQEncOn) +testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, PQSupportOn) -testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO () testDuplexConnection' (alice, aPQ) (bob, bPQ) = do let pq = pqConnectionMode aPQ bPQ ("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") @@ -246,9 +246,9 @@ testDuplexConnection' (alice, aPQ) (bob, bPQ) = do alice #:# "nothing else should be delivered to alice" testDuplexConnRandomIds :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () -testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) (bob, PQEncOn) +testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) (bob, PQSupportOn) -testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO () testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do let pq = pqConnectionMode aPQ bPQ ("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") @@ -282,7 +282,7 @@ testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK) alice #:# "nothing else should be delivered to alice" -testContactConnection :: Transport c => (c, InitialKeys) -> (c, PQEncryption) -> (c, PQEncryption) -> IO () +testContactConnection :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO () testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do ("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq @@ -316,7 +316,7 @@ testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do tom <#= \case ("", "alice", Msg' 4 pq' "hi there") -> pq' == atPQ; _ -> False tom #: ("23", "alice", "ACK 4") #> ("23", "alice", OK) -testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO () testContactConnRandomIds (alice, aPQ) (bob, bPQ) = do let pq = pqConnectionMode aPQ bPQ ("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") @@ -380,7 +380,7 @@ testSubscrNotification t (server, _) client = do withSmpServer (ATransport t) $ client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue -testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQEncryption) -> IO () +testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO () testMsgDeliveryServerRestart (alice, aPQ) (bob, bPQ) = do let pq = pqConnectionMode aPQ bPQ withServer $ do @@ -547,9 +547,9 @@ testResumeDeliveryQuotaExceeded _ alice bob = do bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK) connect :: Transport c => (c, ByteString) -> (c, ByteString) -> IO () -connect (h1, name1) (h2, name2) = connect' (h1, name1, IKPQOn) (h2, name2, PQEncOn) +connect (h1, name1) (h2, name2) = connect' (h1, name1, IKPQOn) (h2, name2, PQSupportOn) -connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQEncryption) -> IO () +connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQSupport) -> IO () connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV" <> pqConnModeStr pqMode1 <> " subscribe") let cReq' = strEncode cReq @@ -561,15 +561,15 @@ connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do h2 <# ("", name1, CON pq) h1 <# ("", name2, CON pq) -pqConnectionMode :: InitialKeys -> PQEncryption -> PQEncryption -pqConnectionMode pqMode1 pqMode2 = PQEncryption $ enablePQ (CR.connPQEncryption pqMode1) && enablePQ pqMode2 +pqConnectionMode :: InitialKeys -> PQSupport -> PQEncryption +pqConnectionMode pqMode1 pqMode2 = PQEncryption $ supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2 -enableKEMStr :: PQEncryption -> ByteString -enableKEMStr PQEncOn = " " <> strEncode PQEncOn +enableKEMStr :: PQSupport -> ByteString +enableKEMStr PQSupportOn = " " <> strEncode PQSupportOn enableKEMStr _ = "" pqConnModeStr :: InitialKeys -> ByteString -pqConnModeStr (IKNoPQ PQEncOff) = "" +pqConnModeStr (IKNoPQ PQSupportOff) = "" pqConnModeStr pq = " " <> strEncode pq sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO () diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 8228ae4cd..f5e689e96 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -81,7 +81,7 @@ testE2ERatchetParams :: RcvE2ERatchetParamsUri 'C.X448 testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange (VersionE2E 1) (VersionE2E 1)) testDhPubKey testDhPubKey Nothing testE2ERatchetParams12 :: RcvE2ERatchetParamsUri 'C.X448 -testE2ERatchetParams12 = E2ERatchetParamsUri (supportedE2EEncryptVRange PQEncOn) testDhPubKey testDhPubKey Nothing +testE2ERatchetParams12 = E2ERatchetParamsUri (supportedE2EEncryptVRange PQSupportOn) testDhPubKey testDhPubKey Nothing connectionRequest :: AConnectionRequestUri connectionRequest = @@ -95,7 +95,7 @@ connectionRequestCurrentRange :: AConnectionRequestUri connectionRequestCurrentRange = ACR SCMInvitation $ CRInvitationUri - connReqData {crAgentVRange = supportedSMPAgentVRange PQEncOn, crSmpQueues = [queueV1, queueV1]} + connReqData {crAgentVRange = supportedSMPAgentVRange PQSupportOn, crSmpQueues = [queueV1, queueV1]} testE2ERatchetParams12 connectionRequestClientDataEmpty :: AConnectionRequestUri diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 5c5241849..5b73be893 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -94,8 +94,8 @@ fullMsgLen Ratchet {rcSupportKEM} = headerLenLength + fullHeaderLen rcSupportKEM where -- v = current rcVersion headerLenLength = case rcSupportKEM of - PQEncOn -> 3 -- two bytes are added because of two Large used in new encoding - PQEncOff -> 1 + PQSupportOn -> 3 -- two bytes are added because of two Large used in new encoding + PQSupportOff -> 1 -- TODO PQ below should work too -- | v >= pqRatchetE2EEncryptVersion = 3 -- | otherwise = 1 @@ -371,6 +371,7 @@ testDecodeV2RatchetJSON :: IO () testDecodeV2RatchetJSON = do let v2RatchetJSON = "{\"rcVersion\":[2,2],\"rcAD\":\"2GEJrq48TmQse6NR16I-hrI0tSySZQ57E_g46nDceAPRAiF6j0drq26RTE7be6X7uiB4RaGJGf4QRXzcYuVtWw==\",\"rcDHRs\":\"TUM0Q0FRQXdCUVlESzJWdUJDSUVJRkNYbUxtSHQ3SUNfeHpGTi1Qb3ZqTVQ3S2p6XzZlZlBjOG9fRFY2RWxKOQ==\",\"rcRK\":\"BOX2X7YW5qDSp2XknY_lqacSrtDqQNPvS6iJlZIs3G0=\",\"rcNs\":0,\"rcNr\":0,\"rcPN\":0,\"rcNHKs\":\"IMouSkXUvzT_mo0WM-pqEUK09-HTLk9WOTCFQglyQxU=\",\"rcNHKr\":\"g-tus1clYPV0rGlzkf5a959tUqDYQVZ1FpcPeXdKwxI=\"}" Right (r :: Ratchet X25519) <- pure $ J.eitherDecodeStrict' v2RatchetJSON + rcSupportKEM r `shouldBe` PQSupportOff rcEnableKEM r `shouldBe` PQEncOff rcSndKEM r `shouldBe` PQEncOff rcRcvKEM r `shouldBe` PQEncOff @@ -386,7 +387,7 @@ testX3dh _ = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing - (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob paramsAlice `shouldBe` paramsBob @@ -395,7 +396,7 @@ testX3dhV1 :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dhV1 _ = do g <- C.newRandom (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g (VersionE2E 1) Nothing - (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g (VersionE2E 1) PQEncOff + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g (VersionE2E 1) PQSupportOff let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob paramsAlice `shouldBe` paramsBob @@ -405,7 +406,7 @@ testPqX3dhProposeInReply _ = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (no KEM) - (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff -- propose KEM in reply (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM) Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice @@ -417,7 +418,7 @@ testPqX3dhProposeAccept _ = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) - (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice -- accept KEM (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM aliceKem) @@ -430,7 +431,7 @@ testPqX3dhProposeReject _ = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) - (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice -- reject KEM (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing @@ -443,7 +444,7 @@ testPqX3dhAcceptWithoutProposalError _ = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (no KEM) - (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOff + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff E2ERatchetParams _ _ _ Nothing <- pure e2eAlice -- incorrectly accept KEM -- we don't have key in proposal, so we just generate it @@ -457,7 +458,7 @@ testPqX3dhProposeAgain _ = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) - (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQEncOn + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice -- propose KEM again in reply - this is not an error (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM) @@ -520,13 +521,13 @@ initRatchets = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion (pkBob1, pkBob2, _pKemParams@Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v Nothing - (pkAlice1, pkAlice2, _pKem@Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOff + (pkAlice1, pkAlice2, _pKem@Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let vs = testRatchetVersions PQEncOff + let vs = testRatchetVersions PQSupportOff bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOff + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOff pure (alice, bob, encrypt' noSndKEM, decrypt' noRcvKEM, (\#>)) initRatchetsKEMProposed :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) @@ -534,16 +535,16 @@ initRatchetsKEMProposed = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (no KEM) - (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOff + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff -- propose KEM in reply let useKem = AUseKEM SRKSProposed ProposeKEM (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let vs = testRatchetVersions PQEncOn + let vs = testRatchetVersions PQSupportOn bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOn + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) initRatchetsKEMAccepted :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) @@ -551,7 +552,7 @@ initRatchetsKEMAccepted = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose) - (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOn + (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice -- accept let useKem = AUseKEM SRKSAccepted (AcceptKEM aliceKem) @@ -559,9 +560,9 @@ initRatchetsKEMAccepted = do Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let vs = testRatchetVersions PQEncOn + let vs = testRatchetVersions PQSupportOn bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOn + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) initRatchetsKEMProposedAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) @@ -569,19 +570,19 @@ initRatchetsKEMProposedAgain = do g <- C.newRandom let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion -- initiate (propose KEM) - (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQEncOn + (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn -- propose KEM again in reply let useKem = AUseKEM SRKSProposed ProposeKEM (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let vs = testRatchetVersions PQEncOn + let vs = testRatchetVersions PQSupportOn bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet vs pkAlice2 paramsAlice PQEncOn + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) -testRatchetVersions :: PQEncryption -> RatchetVersions +testRatchetVersions :: PQSupport -> RatchetVersions testRatchetVersions pq = let v = maxVersion $ supportedE2EEncryptVRange pq in RVersions v v diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index bc328bbad..706f3994f 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -72,7 +72,7 @@ import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteS import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultSMPClientConfig) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOn, pattern PQEncOff) +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern PQEncOn, pattern PQEncOff, pattern PQSupportOn, pattern PQSupportOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF, authBatchCmdsNTFVersion) @@ -173,9 +173,9 @@ agentCfgVPrev :: AgentConfig agentCfgVPrev = agentCfg { sndAuthAlg = C.AuthAlg C.SEd25519, - smpAgentVRange = \_ -> prevRange $ smpAgentVRange agentCfg PQEncOff, + smpAgentVRange = \_ -> prevRange $ smpAgentVRange agentCfg PQSupportOff, smpClientVRange = prevRange $ smpClientVRange agentCfg, - e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQEncOff, + e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQSupportOff, smpCfg = smpCfgVPrev } @@ -188,7 +188,7 @@ agentCfgV7 = } agentCfgRatchetVPrev :: AgentConfig -agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQEncOff} +agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQSupportOff} prevRange :: VersionRange v -> VersionRange v prevRange vr = vr {maxVersion = max (minVersion vr) (prevVersion $ maxVersion vr)} @@ -224,10 +224,10 @@ inAnyOrder g rs = do expected r rp = rp r createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) -createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQEncOn) +createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn) joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQEncOn +joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQSupportOn sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> m AgentMsgId sendMessage c connId msgFlags msgBody = do @@ -427,24 +427,24 @@ canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) -- TODO PQ test next version with PQ -testMatrix2 :: ATransport -> (PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testMatrix2 :: ATransport -> (PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2 t runTest = do - it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQEncOn - it "v7 to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 $ runTest PQEncOn - it "current to v7" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 $ runTest PQEncOn - it "current with v7 server" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQEncOn - it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQEncOn - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQEncOff - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQEncOff - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQEncOff + it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQSupportOn + it "v7 to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 $ runTest PQSupportOn + it "current to v7" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 $ runTest PQSupportOn + it "current with v7 server" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQSupportOn + it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQSupportOn + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQSupportOff + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff -- TODO PQ test next version with PQ -testRatchetMatrix2 :: ATransport -> (PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testRatchetMatrix2 :: ATransport -> (PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testRatchetMatrix2 t runTest = do - it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQEncOn - it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 $ runTest PQEncOff - it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 $ runTest PQEncOff - it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 $ runTest PQEncOff + it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQSupportOn + it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 $ runTest PQSupportOff + it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 $ runTest PQSupportOff + it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 $ runTest PQSupportOff testServerMatrix2 :: ATransport -> (InitialAgentServers -> IO ()) -> Spec testServerMatrix2 t runTest = do @@ -466,13 +466,14 @@ withAgentClientsCfg2 aCfg bCfg runTest = do withAgentClients2 :: (AgentClient -> AgentClient -> IO ()) -> IO () withAgentClients2 = withAgentClientsCfg2 agentCfg agentCfg -runAgentClientTest :: HasCallStack => PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientTest pqEnc alice@AgentClient {} bob baseId = +runAgentClientTest :: HasCallStack => PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientTest pqSupport alice@AgentClient {} bob baseId = runRight_ $ do - (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (IKNoPQ pqEnc) SMSubscribe - aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqEnc SMSubscribe + (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (IKNoPQ pqSupport) SMSubscribe + aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" + let pqEnc = CR.pqSupportToEnc pqSupport get alice ##> ("", bobId, A.CON pqEnc) get bob ##> ("", aliceId, INFO "alice's connInfo") get bob ##> ("", aliceId, A.CON pqEnc) @@ -527,15 +528,16 @@ testAgentClient3 = do get c =##> \case ("", connId, Msg "c5") -> connId == aIdForC; _ -> False ackMessage c aIdForC 5 Nothing -runAgentClientContactTest :: HasCallStack => PQEncryption -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientContactTest pqEnc alice bob baseId = +runAgentClientContactTest :: HasCallStack => PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientContactTest pqSupport alice bob baseId = runRight_ $ do - (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing (IKNoPQ pqEnc) SMSubscribe - aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqEnc SMSubscribe + (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing (IKNoPQ pqSupport) SMSubscribe + aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe ("", _, REQ invId _ "bob's connInfo") <- get alice - bobId <- acceptContact alice True invId "alice's connInfo" PQEncOn SMSubscribe + bobId <- acceptContact alice True invId "alice's connInfo" PQSupportOn SMSubscribe ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" + let pqEnc = CR.pqSupportToEnc pqSupport get alice ##> ("", bobId, INFO "bob's connInfo") get alice ##> ("", bobId, A.CON pqEnc) get bob ##> ("", aliceId, A.CON pqEnc) @@ -690,7 +692,7 @@ testIncreaseConnAgentVersion t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQEncOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -751,7 +753,7 @@ testIncreaseConnAgentVersionMaxCompatible t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQEncOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -779,7 +781,7 @@ testIncreaseConnAgentVersionStartDifferentVersion t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQEncOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -1029,7 +1031,7 @@ testRatchetSync t = withAgentClients2 $ \alice bob -> withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId, bob2) <- setupDesynchronizedRatchet alice bob runRight $ do - ConnectionStats {ratchetSyncState} <- synchronizeRatchet bob2 aliceId PQEncOn False + ConnectionStats {ratchetSyncState} <- synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted get alice =##> ratchetSyncP bobId RSAgreed get bob2 =##> ratchetSyncP aliceId RSAgreed @@ -1073,7 +1075,7 @@ setupDesynchronizedRatchet alice bob = do runRight_ $ do subscribeConnection bob2 aliceId - Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQEncOn False + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False 8 <- sendMessage alice bobId SMP.noMsgFlags "hello 5" get alice ##> ("", bobId, SENT 8) @@ -1104,7 +1106,7 @@ testRatchetSyncServerOffline t = withAgentClients2 $ \alice bob -> do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted withSmpServerStoreMsgLogOn t testPort $ \_ -> do @@ -1134,7 +1136,7 @@ testRatchetSyncClientRestart t = do setupDesynchronizedRatchet alice bob ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted disconnectAgentClient bob2 bob3 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB2 @@ -1161,7 +1163,7 @@ testRatchetSyncSuspendForeground t = do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted suspendAgent bob2 0 @@ -1195,10 +1197,10 @@ testRatchetSyncSimultaneous t = do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState = bRSS} <- runRight $ synchronizeRatchet bob2 aliceId PQEncOn False + ConnectionStats {ratchetSyncState = bRSS} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ bRSS `shouldBe` RSStarted - ConnectionStats {ratchetSyncState = aRSS} <- runRight $ synchronizeRatchet alice bobId PQEncOn True + ConnectionStats {ratchetSyncState = aRSS} <- runRight $ synchronizeRatchet alice bobId PQSupportOn True liftIO $ aRSS `shouldBe` RSStarted withSmpServerStoreMsgLogOn t testPort $ \_ -> do @@ -1253,20 +1255,21 @@ testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do pure r makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection = makeConnection_ PQEncOn +makeConnection = makeConnection_ PQSupportOn -makeConnection_ :: PQEncryption -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnection_ :: PQSupport -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnection_ pqEnc alice bob = makeConnectionForUsers_ pqEnc alice 1 bob 1 makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers = makeConnectionForUsers_ PQEncOn +makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn -makeConnectionForUsers_ :: PQEncryption -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers_ pqEnc alice aliceUserId bob bobUserId = do - (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqEnc) SMSubscribe - aliceId <- A.joinConnection bob bobUserId True qInfo "bob's connInfo" pqEnc SMSubscribe +makeConnectionForUsers_ :: PQSupport -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers_ pqSupport alice aliceUserId bob bobUserId = do + (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqSupport) SMSubscribe + aliceId <- A.joinConnection bob bobUserId True qInfo "bob's connInfo" pqSupport SMSubscribe ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" + let pqEnc = CR.pqSupportToEnc pqSupport get alice ##> ("", bobId, A.CON pqEnc) get bob ##> ("", aliceId, INFO "alice's connInfo") get bob ##> ("", aliceId, A.CON pqEnc) @@ -1392,7 +1395,7 @@ testBatchedSubscriptions nCreate nDel t = do a <- getSMPAgentClient' 1 agentCfg initAgentServers2 testDB b <- getSMPAgentClient' 2 agentCfg initAgentServers2 testDB2 conns <- runServers $ do - conns <- replicateM (nCreate :: Int) $ makeConnection_ PQEncOff a b + conns <- replicateM (nCreate :: Int) $ makeConnection_ PQSupportOff a b forM_ conns $ \(aId, bId) -> exchangeGreetings_ PQEncOff a bId b aId let (aIds', bIds') = unzip $ take nDel conns delete a bIds' @@ -1456,10 +1459,10 @@ testBatchedSubscriptions nCreate nDel t = do testAsyncCommands :: IO () testAsyncCommands = withAgentClients2 $ \alice bob -> runRight_ $ do - bobId <- createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQEncOn) SMSubscribe + bobId <- createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe ("1", bobId', INV (ACR _ qInfo)) <- get alice liftIO $ bobId' `shouldBe` bobId - aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" PQEncOn SMSubscribe + aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" PQSupportOn SMSubscribe ("2", aliceId', OK) <- get bob liftIO $ aliceId' `shouldBe` aliceId ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -1506,7 +1509,7 @@ testAsyncCommands = testAsyncCommandsRestore :: ATransport -> IO () testAsyncCommandsRestore t = do alice <- getSMPAgentClient' 1 agentCfg initAgentServers testDB - bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQEncOn) SMSubscribe + bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe liftIO $ noMessages alice "alice doesn't receive INV because server is down" disconnectAgentClient alice alice' <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB @@ -1523,7 +1526,7 @@ testAcceptContactAsync = (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe ("", _, REQ invId _ "bob's connInfo") <- get alice - bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" PQEncOn SMSubscribe + bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" PQSupportOn SMSubscribe get alice =##> \case ("1", c, OK) -> c == bobId; _ -> False ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" @@ -1809,10 +1812,10 @@ testJoinConnectionAsyncReplyError t = do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB b <- getSMPAgentClient' 2 agentCfg initAgentServersSrv2 testDB2 (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do - bId <- createConnectionAsync a 1 "1" True SCMInvitation (IKNoPQ PQEncOn) SMSubscribe + bId <- createConnectionAsync a 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe ("1", bId', INV (ACR _ qInfo)) <- get a liftIO $ bId' `shouldBe` bId - aId <- joinConnectionAsync b 1 "2" True qInfo "bob's connInfo" PQEncOn SMSubscribe + aId <- joinConnectionAsync b 1 "2" True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ threadDelay 500000 ConnectionStats {rcvQueuesInfo = [], sndQueuesInfo = [SndQueueInfo {}]} <- getConnectionServers b aId pure (aId, bId) @@ -2353,7 +2356,7 @@ testDeliveryReceiptsVersion t = do b <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aId, bId) <- runRight $ do - (aId, bId) <- makeConnection_ PQEncOff a b + (aId, bId) <- makeConnection_ PQSupportOff a b checkVersion a bId 3 checkVersion b aId 3 (4, _) <- A.sendMessage a bId PQEncOff SMP.noMsgFlags "hello" @@ -2392,7 +2395,7 @@ testDeliveryReceiptsVersion t = do get b' =##> \case ("", c, Rcvd 10) -> c == aId; _ -> False ackMessage b' aId 11 Nothing -- TODO PQ this part hangs when waiting for Rcvd, because connection tries to upgrade to PQ encryption. - -- replacing 2 PQEncOn with PQEncOff above prevents hanging. + -- replacing 2 PQSupportOn with PQEncOff above prevents hanging. -- (12, _) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello 2" -- get a' ##> ("", bId, SENT 12) -- get b' =##> \case ("", c, Msg' 12 PQEncOff "hello 2") -> c == aId; _ -> False diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index af91dac42..4bac4fb83 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -45,7 +45,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQEncOn) +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQSupportOn) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Encoding.String (StrEncoding (..)) @@ -190,7 +190,7 @@ cData1 = lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, - pqEncryption = CR.PQEncOn + pqSupport = CR.PQSupportOn } testPrivateAuthKey :: C.APrivateAuthKey @@ -662,7 +662,7 @@ testGetPendingServerCommand st = do Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db (Just smpServer1) corrId' `shouldBe` "4" where - command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) (IKNoPQ PQEncOn) SMSubscribe + command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) (IKNoPQ PQSupportOn) SMSubscribe corruptCmd :: DB.Connection -> ByteString -> ConnId -> IO () corruptCmd db corrId connId = DB.execute db "UPDATE commands SET command = cast('bad' as blob) WHERE conn_id = ? AND corr_id = ?" (connId, corrId) From 07fa75ec498958fdd702a3292ba579f86259cab5 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Thu, 7 Mar 2024 08:35:40 +0000 Subject: [PATCH 12/30] pqdr: agent api to confirm PQ encryption support during connection handshake, fix incorrect PQ support (#1032) * pqdr: agent api to confirm PQ encryption support during connection handshake * fix CONF, tests * fix REQ, tests * remove unused --- src/Simplex/Messaging/Agent.hs | 97 +++++++++++++++++-------- src/Simplex/Messaging/Agent/Protocol.hs | 34 ++++----- tests/AgentTests.hs | 53 +++++++++----- tests/AgentTests/FunctionalAPITests.hs | 33 +++++++-- tests/AgentTests/NotificationTests.hs | 4 +- 5 files changed, 144 insertions(+), 77 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 204647ef6..bb335d95c 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -46,6 +46,7 @@ module Simplex.Messaging.Agent withInvLock, createUser, deleteUser, + connRequestPQSupport, createConnectionAsync, joinConnectionAsync, allowConnectionAsync, @@ -160,7 +161,7 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client (ProtocolClient (..), ServerTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) -import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport, pattern PQEncOn, pattern PQEncOff, pattern PQSupportOn, pattern PQSupportOff) +import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOn, pattern PQEncOff, pattern PQSupportOn, pattern PQSupportOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -566,17 +567,17 @@ newConnNoQueues c userId connId enableNtfs cMode pqSupport = do withStore c $ \db -> createNewConn db g cData cMode joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId -joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo pqSupport subMode = do +joinConnAsync c userId corrId enableNtfs cReqUri@CRInvitationUri {} cInfo pqSup subMode = do withInvLock c (strEncode cReqUri) "joinConnAsync" $ do - aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config - case crAgentVRange `compatibleVersion` aVRange of - Just (Compatible connAgentVersion) -> do + compatibleInvitationUri cReqUri pqSup >>= \case + Just (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible connAgentVersion) -> do g <- asks random - let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} + let pqSupport = versionPQSupport_ pqSup connAgentVersion (Just v) + cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqSupport subMode cInfo pure connId - _ -> throwError $ AGENT A_VERSION + Nothing -> throwError $ AGENT A_VERSION joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo _pqEncryption = throwError $ CMD PROHIBITED @@ -690,29 +691,56 @@ joinConn c userId connId enableNtfs cReq cInfo pqSupport subMode = do joinConnSrv c userId connId enableNtfs cReq cInfo pqSupport subMode srv startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) -startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqSupport = do - AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - let e2eVRange = e2eEncryptVRange pqSupport - case ( qUri `compatibleVersion` smpClientVRange, - e2eRcvParamsUri `compatibleVersion` e2eVRange, - crAgentVRange `compatibleVersion` smpAgentVRange pqSupport - ) of - (Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), Just aVersion@(Compatible connAgentVersion)) -> do +startJoinInvitation userId connId enableNtfs cReqUri pqSup = + compatibleInvitationUri cReqUri pqSup >>= \case + Just (qInfo, (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), aVersion@(Compatible connAgentVersion)) -> do g <- asks random + let pqSupport = versionPQSupport_ pqSup connAgentVersion (Just v) (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ kem_ pqSupport) (_, rcDHRs) <- atomically $ C.generateKeyPair g rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams - let rcVs = CR.RVersions {current = v, maxSupported = maxVersion e2eVRange} + maxSupported <- asks $ maxVersion . ($ pqSup) . e2eEncryptVRange . config + let rcVs = CR.RVersions {current = v, maxSupported} rc = CR.initSndRatchet rcVs rcDHRr rcDHRs rcParams q <- newSndQueue userId "" qInfo let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} pure (aVersion, cData, q, rc, e2eSndParams) - _ -> throwError $ AGENT A_VERSION + Nothing -> throwError $ AGENT A_VERSION + +connRequestPQSupport :: AgentMonad' m => PQSupport -> ConnectionRequestUri c -> m (Maybe PQSupport) +connRequestPQSupport pqSup cReq = case cReq of + CRInvitationUri {} -> invPQSupported <$$> compatibleInvitationUri cReq pqSup + where + invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = versionPQSupport_ pqSup agentV (Just e2eV) + CRContactUri {} -> ctPQSupported <$$> compatibleContactUri cReq pqSup + where + ctPQSupported (_, Compatible agentV) = versionPQSupport_ pqSup agentV Nothing + +compatibleInvitationUri :: AgentMonad' m => ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible (CR.RcvE2ERatchetParams 'C.X448), Compatible VersionSMPA)) +compatibleInvitationUri (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqSup = do + AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config + pure $ + (,,) + <$> (qUri `compatibleVersion` smpClientVRange) + <*> (e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange pqSup) + <*> (crAgentVRange `compatibleVersion` smpAgentVRange pqSup) + +compatibleContactUri :: AgentMonad' m => ConnectionRequestUri 'CMContact -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible VersionSMPA)) +compatibleContactUri (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) pqSup = do + AgentConfig {smpClientVRange, smpAgentVRange} <- asks config + pure $ + (,) + <$> (qUri `compatibleVersion` smpClientVRange) + <*> (crAgentVRange `compatibleVersion` smpAgentVRange pqSup) + +versionPQSupport_ :: PQSupport -> VersionSMPA -> Maybe CR.VersionE2E -> PQSupport +versionPQSupport_ (PQSupport sup) agentV e2eV_ = + PQSupport $ sup && pqdrSMPAgentVersion <= agentV && maybe True (CR.pqRatchetE2EEncryptVersion <=) e2eV_ joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m ConnId -joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = +joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do - (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSupport + (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSup g <- asks random (connId', sq) <- withStore c $ \db -> runExceptT $ do r@(connId', _) <- ExceptT $ createSndConn db g cData q @@ -725,17 +753,13 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport su -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md void $ withStore' c $ \db -> deleteConn db Nothing connId' throwError e -joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo pqEnc subMode srv = do - aVRange <- asks $ ($ pqEnc) . smpAgentVRange . config - clientVRange <- asks $ smpClientVRange . config - case ( qUri `compatibleVersion` clientVRange, - crAgentVRange `compatibleVersion` aVRange - ) of - (Just qInfo, Just vrsn) -> do - (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.joinContactInitialKeys pqEnc) subMode srv +joinConnSrv c userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv = + compatibleContactUri cReqUri pqSup >>= \case + Just (qInfo, vrsn) -> do + (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.joinContactInitialKeys pqSup) subMode srv sendInvitation c userId qInfo vrsn cReq cInfo pure connId' - _ -> throwError $ AGENT A_VERSION + Nothing -> throwError $ AGENT A_VERSION joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m () joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do @@ -2162,7 +2186,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams -- TODO PQ combine isCompatible check and construction in one call let rcVs = CR.RVersions {current = e2eVersion, maxSupported = maxVersion e2eVRange} - rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqSupport + pqSupport' = versionPQSupport_ pqSupport agentVersion (Just e2eVersion) + rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqSupport' g <- asks random (agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt g rc M.empty encConnInfo case (agentMsgBody_, skipped) of @@ -2176,16 +2201,17 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, let newConfirmation = NewConfirmation {connId, senderConf, ratchetState = rc'} confId <- withStore c $ \db -> do setConnAgentVersion db connId agentVersion + when (pqSupport /= pqSupport') $ setConnPQSupport db connId pqSupport' createConfirmation db g newConfirmation let srvs = map qServer $ smpReplyQueues senderConf - notify $ CONF confId srvs connInfo + notify $ CONF confId pqSupport' srvs connInfo _ -> prohibited -- party accepting connection (DuplexConnection _ (RcvQueue {smpClientVersion = v'} :| _) _, Nothing) -> do g <- asks random withStore c (\db -> runExceptT $ agentRatchetDecrypt g db connId encConnInfo) >>= parseMessage . fst >>= \case AgentConnInfo connInfo -> do - notify $ INFO connInfo + notify $ INFO pqSupport connInfo let dhSecret = C.dh' e2ePubKey e2ePrivKey withStore' c $ \db -> setRcvQueueConfirmedE2E db rq dhSecret $ min v' smpClientVersion enqueueCmd $ ICDuplexSecure rId senderKey @@ -2336,12 +2362,19 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId case conn' of ContactConnection {} -> do + -- show connection request even if invitaion via contact address is not compatible. + -- in case invitation not compatible, assume there is no PQ encryption support. + pqSupport <- maybe PQSupportOff pqSupported <$> compatibleInvitationUri connReq PQSupportOn + liftIO $ print pqSupport g <- asks random let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo} invId <- withStore c $ \db -> createInvitation db g newInv let srvs = L.map qServer $ crSmpQueues crData - notify $ REQ invId srvs cInfo + notify $ REQ invId pqSupport srvs cInfo _ -> prohibited + where + pqSupported (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible agentVersion) = + versionPQSupport_ PQSupportOn agentVersion (Just v) qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m () qDuplex conn' name action = case conn' of diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 02aa5e260..465b19a01 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -374,12 +374,12 @@ data ACommand (p :: AParty) (e :: AEntity) where NEW :: Bool -> AConnectionMode -> InitialKeys -> SubscriptionMode -> ACommand Client AEConn -- response INV INV :: AConnectionRequestUri -> ACommand Agent AEConn JOIN :: Bool -> AConnectionRequestUri -> PQSupport -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK - CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake + CONF :: ConfirmationId -> PQSupport -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake LET :: ConfirmationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client - REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender + REQ :: InvitationId -> PQSupport -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender ACPT :: InvitationId -> PQSupport -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client RJCT :: InvitationId -> ACommand Client AEConn - INFO :: ConnInfo -> ACommand Agent AEConn + INFO :: PQSupport -> ConnInfo -> ACommand Agent AEConn CON :: PQEncryption -> ACommand Agent AEConn -- notification that connection is established SUB :: ACommand Client AEConn END :: ACommand Agent AEConn @@ -1748,9 +1748,9 @@ commandP binaryP = ACmdTag SAgent e cmd -> ACmd SAgent e <$> case cmd of INV_ -> s (INV <$> strP) - CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP) - REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP) - INFO_ -> s (INFO <$> binaryP) + CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strListP <* A.space <*> binaryP) + REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strP_ <*> binaryP) + INFO_ -> s (INFO <$> pqSupP <*> binaryP) CON_ -> s (CON <$> strP) END_ -> pure END CONNECT_ -> s (CONNECT <$> strP_ <*> strP) @@ -1805,13 +1805,13 @@ serializeCommand :: ACommand p e -> ByteString serializeCommand = \case NEW ntfs cMode pqIK subMode -> s (NEW_, ntfs, cMode, pqIK, subMode) INV cReq -> s (INV_, cReq) - JOIN ntfs cReq pqEnc subMode cInfo -> s (JOIN_, ntfs, cReq, pqEnc, subMode, Str $ serializeBinary cInfo) - CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo] + JOIN ntfs cReq pqSup subMode cInfo -> s (JOIN_, ntfs, cReq, pqSup, subMode, Str $ serializeBinary cInfo) + CONF confId pqSup srvs cInfo -> B.unwords [s CONF_, confId, s pqSup, strEncodeList srvs, serializeBinary cInfo] LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo] - REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo] - ACPT invId pqEnc cInfo -> B.unwords [s ACPT_, invId, s pqEnc, serializeBinary cInfo] + REQ invId pqSup srvs cInfo -> B.unwords [s REQ_, invId, s pqSup, s srvs, serializeBinary cInfo] + ACPT invId pqSup cInfo -> B.unwords [s ACPT_, invId, s pqSup, serializeBinary cInfo] RJCT invId -> B.unwords [s RJCT_, invId] - INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo] + INFO pqSup cInfo -> B.unwords [s INFO_, s pqSup, serializeBinary cInfo] SUB -> s SUB_ END -> s END_ CONNECT p h -> s (CONNECT_, p, h) @@ -1910,14 +1910,14 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p)) cmdWithMsgBody (APC e cmd) = APC e <$$> case cmd of - SEND kem msgFlags body -> SEND kem msgFlags <$$> getBody body + SEND pqEnc msgFlags body -> SEND pqEnc msgFlags <$$> getBody body MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body - JOIN ntfs qUri kem subMode cInfo -> JOIN ntfs qUri kem subMode <$$> getBody cInfo - CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo + JOIN ntfs qUri pqSup subMode cInfo -> JOIN ntfs qUri pqSup subMode <$$> getBody cInfo + CONF confId pqSup srvs cInfo -> CONF confId pqSup srvs <$$> getBody cInfo LET confId cInfo -> LET confId <$$> getBody cInfo - REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo - ACPT invId kem cInfo -> ACPT invId kem <$$> getBody cInfo - INFO cInfo -> INFO <$$> getBody cInfo + REQ invId pqSup srvs cInfo -> REQ invId pqSup srvs <$$> getBody cInfo + ACPT invId pqSup cInfo -> ACPT invId pqSup <$$> getBody cInfo + INFO pqSup cInfo -> INFO pqSup <$$> getBody cInfo _ -> pure $ Right cmd getBody :: ByteString -> m (Either AgentErrorType ByteString) diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index bb91725f8..34719e803 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -27,7 +27,7 @@ import GHC.Stack (withFrozenCallStack) import Network.HTTP.Types (urlEncode) import SMPAgentClient import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) -import Simplex.Messaging.Agent.Protocol hiding (MID) +import Simplex.Messaging.Agent.Protocol hiding (MID, CONF, INFO, REQ) import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQSupportOn, pattern PQSupportOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR @@ -214,12 +214,14 @@ testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, P testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO () testDuplexConnection' (alice, aPQ) (bob, bPQ) = do let pq = pqConnectionMode aPQ bPQ + pqSup = CR.pqEncToSupport pq ("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) - ("", "bob", Right (CONF confId _ "bob's connInfo")) <- (alice <#:) + ("", "bob", Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` pqSup alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK) - bob <# ("", "alice", INFO "alice's connInfo") + bob <# ("", "alice", A.INFO pqSup "alice's connInfo") bob <# ("", "alice", CON pq) alice <# ("", "bob", CON pq) -- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4 @@ -251,13 +253,15 @@ testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) ( testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO () testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do let pq = pqConnectionMode aPQ bPQ + pqSup = CR.pqEncToSupport pq ("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") - ("", bobConn', Right (CONF confId _ "bob's connInfo")) <- (alice <#:) + ("", bobConn', Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` pqSup bobConn' `shouldBe` bobConn alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False - bob <# ("", aliceConn, INFO "alice's connInfo") + bob <# ("", aliceConn, A.INFO pqSup "alice's connInfo") bob <# ("", aliceConn, CON pq) alice <# ("", bobConn, CON pq) alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, A.MID 4 pq) @@ -287,14 +291,17 @@ testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do ("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq abPQ = pqConnectionMode aPQ bPQ + abPQSup = CR.pqEncToSupport abPQ aPQMode = CR.connPQEncryption aPQ bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) - ("", "alice_contact", Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:) + ("", "alice_contact", Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` bPQ alice #: ("2", "bob", "ACPT " <> aInvId <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("2", "bob", OK) - ("", "alice", Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:) + ("", "alice", Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:) + pqSup'' `shouldBe` abPQSup bob #: ("12", "alice", "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", "alice", OK) - alice <# ("", "bob", INFO "bob's connInfo 2") + alice <# ("", "bob", A.INFO abPQSup "bob's connInfo 2") alice <# ("", "bob", CON abPQ) bob <# ("", "alice", CON abPQ) alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", A.MID 4 abPQ) @@ -303,12 +310,15 @@ testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK) let atPQ = pqConnectionMode aPQ tPQ + atPQSup = CR.pqEncToSupport atPQ tom #: ("21", "alice", "JOIN T " <> cReq' <> enableKEMStr tPQ <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK) - ("", "alice_contact", Right (REQ aInvId' _ "tom's connInfo")) <- (alice <#:) + ("", "alice_contact", Right (A.REQ aInvId' pqSup3 _ "tom's connInfo")) <- (alice <#:) + pqSup3 `shouldBe` tPQ alice #: ("4", "tom", "ACPT " <> aInvId' <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("4", "tom", OK) - ("", "alice", Right (CONF tConfId _ "alice's connInfo")) <- (tom <#:) + ("", "alice", Right (A.CONF tConfId pqSup4 _ "alice's connInfo")) <- (tom <#:) + pqSup4 `shouldBe` atPQSup tom #: ("22", "alice", "LET " <> tConfId <> " 16\ntom's connInfo 2") #> ("22", "alice", OK) - alice <# ("", "tom", INFO "tom's connInfo 2") + alice <# ("", "tom", A.INFO atPQSup "tom's connInfo 2") alice <# ("", "tom", CON atPQ) tom <# ("", "alice", CON atPQ) alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", A.MID 4 atPQ) @@ -319,19 +329,22 @@ testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO () testContactConnRandomIds (alice, aPQ) (bob, bPQ) = do let pq = pqConnectionMode aPQ bPQ + pqSup = CR.pqEncToSupport pq ("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") - ("", aliceContact', Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:) + ("", aliceContact', Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` bPQ aliceContact' `shouldBe` aliceContact ("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> enableKEMStr (CR.connPQEncryption aPQ) <> " 16\nalice's connInfo") - ("", aliceConn', Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:) + ("", aliceConn', Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:) + pqSup'' `shouldBe` pqSup aliceConn' `shouldBe` aliceConn bob #: ("12", aliceConn, "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", aliceConn, OK) - alice <# ("", bobConn, INFO "bob's connInfo 2") + alice <# ("", bobConn, A.INFO pqSup "bob's connInfo 2") alice <# ("", bobConn, CON pq) bob <# ("", aliceConn, CON pq) @@ -345,7 +358,7 @@ testRejectContactRequest _ alice bob = do ("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW T CON subscribe") let cReq' = strEncode cReq bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 10\nbob's info") #> ("11", "alice", OK) - ("", "a_contact", Right (REQ aInvId _ "bob's info")) <- (alice <#:) + ("", "a_contact", Right (A.REQ aInvId PQSupportOff _ "bob's info")) <- (alice <#:) -- RJCT must use correct contact connection alice #: ("2a", "bob", "RJCT " <> aInvId) #> ("2a", "bob", ERR $ CONN NOT_FOUND) alice #: ("2b", "a_contact", "RJCT " <> aInvId) #> ("2b", "a_contact", OK) @@ -486,7 +499,7 @@ testConcurrentMsgDelivery _ alice bob = do ("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW T INV subscribe") let cReq' = strEncode cReq bob #: ("11", "alice2", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice2", OK) - ("", "bob2", Right (CONF _confId _ "bob's connInfo")) <- (alice <#:) + ("", "bob2", Right (A.CONF _confId PQSupportOff _ "bob's connInfo")) <- (alice <#:) -- below commands would be needed to accept bob's connection, but alice does not -- alice #: ("2", "bob", "LET " <> _confId <> " 16\nalice's connInfo") #> ("2", "bob", OK) -- bob <# ("", "alice", INFO "alice's connInfo") @@ -553,11 +566,13 @@ connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteStr connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV" <> pqConnModeStr pqMode1 <> " subscribe") let cReq' = strEncode cReq + pq = pqConnectionMode pqMode1 pqMode2 + pqSup = CR.pqEncToSupport pq h2 #: ("c2", name1, "JOIN T " <> cReq' <> enableKEMStr pqMode2 <> " subscribe 5\ninfo2") #> ("c2", name1, OK) - ("", _, Right (CONF connId _ "info2")) <- (h1 <#:) + ("", _, Right (A.CONF connId pqSup' _ "info2")) <- (h1 <#:) + pqSup' `shouldBe` pqSup h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK) - h2 <# ("", name1, INFO "info1") - let pq = pqConnectionMode pqMode1 pqMode2 + h2 <# ("", name1, A.INFO pqSup "info1") h2 <# ("", name1, CON pq) h1 <# ("", name2, CON pq) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 706f3994f..424a681e2 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -36,6 +36,9 @@ module AgentTests.FunctionalAPITests (##>), (=##>), pattern CON, + pattern CONF, + pattern INFO, + pattern REQ, pattern Msg, pattern Msg', agentCfgV7, @@ -52,6 +55,7 @@ import qualified Data.ByteString.Char8 as B import Data.Either (isRight) import Data.Int (Int64) import Data.List (nub) +import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as M import Data.Maybe (isJust, isNothing) import qualified Data.Set as S @@ -66,7 +70,7 @@ import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMes import qualified Simplex.Messaging.Agent as A import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore) -import Simplex.Messaging.Agent.Protocol hiding (CON) +import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ) import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteStore (dbNew)) import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') @@ -144,6 +148,15 @@ pGet c = do DISCONNECT {} -> pGet c _ -> pure t +pattern CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand 'Agent e +pattern CONF conId srvs connInfo <- A.CONF conId PQSupportOn srvs connInfo + +pattern INFO :: ConnInfo -> ACommand 'Agent 'AEConn +pattern INFO connInfo = A.INFO PQSupportOn connInfo + +pattern REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand 'Agent e +pattern REQ invId srvs connInfo <- A.REQ invId PQSupportOn srvs connInfo + pattern CON :: ACommand 'Agent 'AEConn pattern CON = A.CON PQEncOn @@ -471,11 +484,12 @@ runAgentClientTest pqSupport alice@AgentClient {} bob baseId = runRight_ $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (IKNoPQ pqSupport) SMSubscribe aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe - ("", _, CONF confId _ "bob's connInfo") <- get alice + ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice + liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" let pqEnc = CR.pqSupportToEnc pqSupport get alice ##> ("", bobId, A.CON pqEnc) - get bob ##> ("", aliceId, INFO "alice's connInfo") + get bob ##> ("", aliceId, A.INFO pqSupport "alice's connInfo") get bob ##> ("", aliceId, A.CON pqEnc) -- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4 1 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" @@ -533,12 +547,14 @@ runAgentClientContactTest pqSupport alice bob baseId = runRight_ $ do (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing (IKNoPQ pqSupport) SMSubscribe aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe - ("", _, REQ invId _ "bob's connInfo") <- get alice + ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice + liftIO $ pqSup' `shouldBe` pqSupport bobId <- acceptContact alice True invId "alice's connInfo" PQSupportOn SMSubscribe - ("", _, CONF confId _ "alice's connInfo") <- get bob + ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get bob + liftIO $ pqSup'' `shouldBe` pqSupport allowConnection bob aliceId confId "bob's connInfo" let pqEnc = CR.pqSupportToEnc pqSupport - get alice ##> ("", bobId, INFO "bob's connInfo") + get alice ##> ("", bobId, A.INFO pqSupport "bob's connInfo") get alice ##> ("", bobId, A.CON pqEnc) get bob ##> ("", aliceId, A.CON pqEnc) -- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4 @@ -1267,11 +1283,12 @@ makeConnectionForUsers_ :: PQSupport -> AgentClient -> UserId -> AgentClient -> makeConnectionForUsers_ pqSupport alice aliceUserId bob bobUserId = do (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqSupport) SMSubscribe aliceId <- A.joinConnection bob bobUserId True qInfo "bob's connInfo" pqSupport SMSubscribe - ("", _, CONF confId _ "bob's connInfo") <- get alice + ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice + liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" let pqEnc = CR.pqSupportToEnc pqSupport get alice ##> ("", bobId, A.CON pqEnc) - get bob ##> ("", aliceId, INFO "alice's connInfo") + get bob ##> ("", aliceId, A.INFO pqSupport "alice's connInfo") get bob ##> ("", aliceId, A.CON pqEnc) pure (aliceId, bobId) diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index f815fb808..d8354efed 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -30,6 +30,8 @@ import AgentTests.FunctionalAPITests (##>), (=##>), pattern CON, + pattern CONF, + pattern INFO, pattern Msg, ) import Control.Concurrent (ThreadId, killThread, threadDelay) @@ -50,7 +52,7 @@ import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpSer import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) -import Simplex.Messaging.Agent.Protocol hiding (CON) +import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO) import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String From 11288866f90bafb0892701b0e0679eddb030b5df Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Thu, 7 Mar 2024 12:41:10 +0000 Subject: [PATCH 13/30] pqdr: refactor --- src/Simplex/Messaging/Agent.hs | 18 +++++++++--------- src/Simplex/Messaging/Agent/Store/SQLite.hs | 8 -------- src/Simplex/Messaging/Crypto/Ratchet.hs | 21 +++++++++++++++++---- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index bb335d95c..9e27ca8f1 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -572,7 +572,7 @@ joinConnAsync c userId corrId enableNtfs cReqUri@CRInvitationUri {} cInfo pqSup compatibleInvitationUri cReqUri pqSup >>= \case Just (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible connAgentVersion) -> do g <- asks random - let pqSupport = versionPQSupport_ pqSup connAgentVersion (Just v) + let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqSupport subMode cInfo @@ -695,7 +695,7 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup = compatibleInvitationUri cReqUri pqSup >>= \case Just (qInfo, (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), aVersion@(Compatible connAgentVersion)) -> do g <- asks random - let pqSupport = versionPQSupport_ pqSup connAgentVersion (Just v) + let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ kem_ pqSupport) (_, rcDHRs) <- atomically $ C.generateKeyPair g rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams @@ -711,10 +711,10 @@ connRequestPQSupport :: AgentMonad' m => PQSupport -> ConnectionRequestUri c -> connRequestPQSupport pqSup cReq = case cReq of CRInvitationUri {} -> invPQSupported <$$> compatibleInvitationUri cReq pqSup where - invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = versionPQSupport_ pqSup agentV (Just e2eV) + invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = pqSup `CR.pqSupportAnd` versionPQSupport_ agentV (Just e2eV) CRContactUri {} -> ctPQSupported <$$> compatibleContactUri cReq pqSup where - ctPQSupported (_, Compatible agentV) = versionPQSupport_ pqSup agentV Nothing + ctPQSupported (_, Compatible agentV) = pqSup `CR.pqSupportAnd` versionPQSupport_ agentV Nothing compatibleInvitationUri :: AgentMonad' m => ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible (CR.RcvE2ERatchetParams 'C.X448), Compatible VersionSMPA)) compatibleInvitationUri (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqSup = do @@ -733,9 +733,9 @@ compatibleContactUri (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = <$> (qUri `compatibleVersion` smpClientVRange) <*> (crAgentVRange `compatibleVersion` smpAgentVRange pqSup) -versionPQSupport_ :: PQSupport -> VersionSMPA -> Maybe CR.VersionE2E -> PQSupport -versionPQSupport_ (PQSupport sup) agentV e2eV_ = - PQSupport $ sup && pqdrSMPAgentVersion <= agentV && maybe True (CR.pqRatchetE2EEncryptVersion <=) e2eV_ +versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport +versionPQSupport_ agentV e2eV_ = + PQSupport $ pqdrSMPAgentVersion <= agentV && maybe True (CR.pqRatchetE2EEncryptVersion <=) e2eV_ joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m ConnId joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = @@ -2186,7 +2186,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams -- TODO PQ combine isCompatible check and construction in one call let rcVs = CR.RVersions {current = e2eVersion, maxSupported = maxVersion e2eVRange} - pqSupport' = versionPQSupport_ pqSupport agentVersion (Just e2eVersion) + pqSupport' = pqSupport `CR.pqSupportAnd` versionPQSupport_ agentVersion (Just e2eVersion) rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqSupport' g <- asks random (agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt g rc M.empty encConnInfo @@ -2374,7 +2374,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> prohibited where pqSupported (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible agentVersion) = - versionPQSupport_ PQSupportOn agentVersion (Just v) + PQSupportOn `CR.pqSupportAnd` versionPQSupport_ agentVersion (Just v) qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m () qDuplex conn' name action = case conn' of diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 33051d234..2f6707c5a 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1776,14 +1776,6 @@ instance ToField (Version v) where toField (Version v) = toField v instance FromField (Version v) where fromField f = Version <$> fromField f -instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField pqEnc - -instance FromField PQEncryption where fromField f = PQEncryption <$> fromField f - -instance ToField PQSupport where toField (PQSupport pqEnc) = toField pqEnc - -instance FromField PQSupport where fromField f = PQSupport <$> fromField f - listToEither :: e -> [a] -> Either e a listToEither _ (x : _) = Right x listToEither e _ = Left e diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 38ada0f01..47ca1c0f2 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -58,6 +58,8 @@ module Simplex.Messaging.Crypto.Ratchet replyKEM_, pqSupportToEnc, pqEncToSupport, + pqSupportAnd, + pqSupportOrEnc, pqX3dhSnd, pqX3dhRcv, initSndRatchet, @@ -788,8 +790,11 @@ pqSupportToEnc (PQSupport pq) = PQEncryption pq pqEncToSupport :: PQEncryption -> PQSupport pqEncToSupport (PQEncryption pq) = PQSupport pq -supportOrEnc :: PQSupport -> PQEncryption -> PQSupport -supportOrEnc (PQSupport sup) (PQEncryption enc) = PQSupport $ sup || enc +pqSupportAnd :: PQSupport -> PQSupport -> PQSupport +pqSupportAnd (PQSupport s1) (PQSupport s2) = PQSupport $ s1 && s2 + +pqSupportOrEnc :: PQSupport -> PQEncryption -> PQSupport +pqSupportOrEnc (PQSupport sup) (PQEncryption enc) = PQSupport $ sup || enc replyKEM_ :: Maybe (RKEMParams 'RKSProposed) -> PQSupport -> Maybe AUseKEM replyKEM_ kem_ = \case @@ -858,7 +863,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, -- PQ encryption can be enabled or disabled rcEnableKEM' = fromMaybe rcEnableKEM pqEnc_ -- support for PQ encryption (and therefore large headers/small envelopes) can only be enabled, it cannot be disabled - rcSupportKEM' = rcSupportKEM `supportOrEnc` rcEnableKEM' + rcSupportKEM' = rcSupportKEM `pqSupportOrEnc` rcEnableKEM' -- enc_header = HENCRYPT(state.HKs, header) (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen rcSupportKEM') rcAD (msgHeader v) -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) @@ -992,7 +997,7 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do rc' { rcDHRs = rcDHRs', rcKEM = rcKEM', - rcSupportKEM = rcSupportKEM `supportOrEnc` rcEnableKEM', + rcSupportKEM = rcSupportKEM `pqSupportOrEnc` rcEnableKEM', rcEnableKEM = rcEnableKEM', rcSndKEM = PQEncryption sndKEM, rcRcvKEM = PQEncryption rcvKEM, @@ -1138,3 +1143,11 @@ instance AlgorithmI a => FromJSON (Ratchet a) where instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict' + +instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField pqEnc + +instance FromField PQEncryption where fromField f = PQEncryption <$> fromField f + +instance ToField PQSupport where toField (PQSupport pqEnc) = toField pqEnc + +instance FromField PQSupport where fromField f = PQSupport <$> fromField f From dd22b5c823eb38310bdb51703b7941752877d27b Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:17:32 +0200 Subject: [PATCH 14/30] core: tweak compress api (#1029) * convert compress to a wrapper with passthrough fallback * add compress1 for non-batched compression * use original size as upper bound for scratch * refactor --------- Co-authored-by: Evgeny Poberezkin --- src/Simplex/Messaging/Compression.hs | 30 +++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/Simplex/Messaging/Compression.hs b/src/Simplex/Messaging/Compression.hs index c6664a179..fec9f8151 100644 --- a/src/Simplex/Messaging/Compression.hs +++ b/src/Simplex/Messaging/Compression.hs @@ -5,9 +5,12 @@ module Simplex.Messaging.Compression where import qualified Codec.Compression.Zstd.FFI as Z import Control.Monad (forM) +import Control.Monad.Except +import Control.Monad.IO.Class import Data.ByteString (ByteString) import qualified Data.ByteString as B import qualified Data.ByteString.Unsafe as B +import Data.Either (fromRight) import Data.List.NonEmpty (NonEmpty) import Foreign import Foreign.C.Types @@ -23,7 +26,7 @@ data Compressed -- | Messages below this length are not encoded to avoid compression overhead. maxLengthPassthrough :: Int -maxLengthPassthrough = 181 -- Sampled from real client data. Messages with length >=181 rapidly gain compression ratio. +maxLengthPassthrough = 180 -- Sampled from real client data. Messages with length > 180 rapidly gain compression ratio. instance Encoding Compressed where smpEncode = \case @@ -35,23 +38,26 @@ instance Encoding Compressed where '1' -> Compressed <$> smpP x -> fail $ "unknown Compressed tag: " <> show x -type CompressCtx = (Ptr Z.CCtx, Ptr CChar, Int) +type CompressCtx = (Ptr Z.CCtx, Ptr CChar, CSize) -withCompressCtx :: Int -> (CompressCtx -> IO a) -> IO a +withCompressCtx :: CSize -> (CompressCtx -> IO a) -> IO a withCompressCtx scratchSize action = bracket Z.createCCtx Z.freeCCtx $ \cctx -> - allocaBytes scratchSize $ \scratchPtr -> + allocaBytes (fromIntegral scratchSize) $ \scratchPtr -> action (cctx, scratchPtr, scratchSize) -compress :: CompressCtx -> ByteString -> IO (Either String Compressed) -compress (cctx, scratchPtr, scratchSize) bs - | B.length bs < maxLengthPassthrough = pure . Right $ Passthrough bs +-- | Compress bytes, falling back to Passthrough in case of some internal error. +compress :: CompressCtx -> ByteString -> IO Compressed +compress ctx bs = fromRight (Passthrough bs) <$> compress_ ctx bs + +compress_ :: CompressCtx -> ByteString -> IO (Either String Compressed) +compress_ (cctx, scratchPtr, scratchSize) bs + | B.length bs <= maxLengthPassthrough = pure . Right $ Passthrough bs | otherwise = - B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> do - res <- Z.checkError $ Z.compressCCtx cctx scratchPtr (fromIntegral scratchSize) sourcePtr (fromIntegral sourceSize) 3 - case res of - Left e -> pure $ Left e -- should not happen, unless input buffer is too short - Right dstSize -> Right . Compressed . Large <$> B.packCStringLen (scratchPtr, fromIntegral dstSize) + B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> runExceptT $ do + -- should not fail, unless input buffer is too short + dstSize <- ExceptT $ Z.checkError $ Z.compressCCtx cctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize) 3 + liftIO $ Compressed . Large <$> B.packCStringLen (scratchPtr, fromIntegral dstSize) type DecompressCtx = (Ptr Z.DCtx, Ptr CChar, CSize) From 5e23fa6cfc60c5efd561f9131a9528b9ccb9782d Mon Sep 17 00:00:00 2001 From: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> Date: Thu, 7 Mar 2024 19:44:48 +0400 Subject: [PATCH 15/30] agent pq: connRequestPQSupport api (#1034) --- src/Simplex/Messaging/Agent.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 9e27ca8f1..fd7bc7e24 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -221,6 +221,9 @@ createUser c = withAgentEnv c .: createUser' c deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m () deleteUser c = withAgentEnv c .: deleteUser' c +connRequestPQSupport :: (MonadUnliftIO m) => AgentClient -> PQSupport -> ConnectionRequestUri c -> m (Maybe PQSupport) +connRequestPQSupport c = withAgentEnv c .: connRequestPQSupport' c + -- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs @@ -707,8 +710,8 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup = pure (aVersion, cData, q, rc, e2eSndParams) Nothing -> throwError $ AGENT A_VERSION -connRequestPQSupport :: AgentMonad' m => PQSupport -> ConnectionRequestUri c -> m (Maybe PQSupport) -connRequestPQSupport pqSup cReq = case cReq of +connRequestPQSupport' :: AgentMonad' m => AgentClient -> PQSupport -> ConnectionRequestUri c -> m (Maybe PQSupport) +connRequestPQSupport' _c pqSup cReq = case cReq of CRInvitationUri {} -> invPQSupported <$$> compatibleInvitationUri cReq pqSup where invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = pqSup `CR.pqSupportAnd` versionPQSupport_ agentV (Just e2eV) From 8ff4c628b5b3c35d997535283b4b3de15bb15bd7 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Fri, 8 Mar 2024 08:28:15 +0000 Subject: [PATCH 16/30] pqdr: make envelope sizes dependent on version, test enabling PQ (#1035) --- src/Simplex/Messaging/Agent.hs | 11 ++-- src/Simplex/Messaging/Agent/Protocol.hs | 16 ++--- src/Simplex/Messaging/Crypto/Ratchet.hs | 67 ++++++++++----------- tests/AgentTests/DoubleRatchetTests.hs | 27 ++++----- tests/AgentTests/FunctionalAPITests.hs | 77 ++++++++++++++++++++++++- 5 files changed, 131 insertions(+), 67 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index fd7bc7e24..d0cb7d2fa 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -699,7 +699,7 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup = Just (qInfo, (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), aVersion@(Compatible connAgentVersion)) -> do g <- asks random let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) - (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ kem_ pqSupport) + (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ v kem_ pqSupport) (_, rcDHRs) <- atomically $ C.generateKeyPair g rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams maxSupported <- asks $ maxVersion . ($ pqSup) . e2eEncryptVRange . config @@ -759,7 +759,7 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMod joinConnSrv c userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv = compatibleContactUri cReqUri pqSup >>= \case Just (qInfo, vrsn) -> do - (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.joinContactInitialKeys pqSup) subMode srv + (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv sendInvitation c userId qInfo vrsn cReq cInfo pure connId' Nothing -> throwError $ AGENT A_VERSION @@ -2368,7 +2368,6 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- show connection request even if invitaion via contact address is not compatible. -- in case invitation not compatible, assume there is no PQ encryption support. pqSupport <- maybe PQSupportOff pqSupported <$> compatibleInvitationUri connReq PQSupportOn - liftIO $ print pqSupport g <- asks random let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo} invId <- withStore c $ \db -> createInvitation db g newInv @@ -2554,10 +2553,10 @@ enqueueRatchetKey c cData@ConnData {connId, pqSupport} sq e2eEncryption = do pure internalId -- encoded AgentMessage -> encoded EncAgentMessage -agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (PQSupport -> Int) -> Maybe PQEncryption -> ExceptT StoreError IO (ByteString, PQEncryption) -agentRatchetEncrypt db ConnData {connId, pqSupport} msg getPaddedLen pqEnc_ = do +agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> ExceptT StoreError IO (ByteString, PQEncryption) +agentRatchetEncrypt db ConnData {connId, connAgentVersion = v, pqSupport} msg getPaddedLen pqEnc_ = do rc <- ExceptT $ getRatchet db connId - let paddedLen = getPaddedLen pqSupport + let paddedLen = getPaddedLen v pqSupport (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ liftIO $ updateRatchet db connId rc' CR.SMDNoChange pure (encMsg, CR.rcSndKEM rc') diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 465b19a01..0c6bfd1ba 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -284,17 +284,17 @@ supportedSMPAgentVRange pq = -- it is shorter to allow all handshake headers, -- including E2E (double-ratchet) parameters and -- signing key of the sender for the server -e2eEncConnInfoLength :: PQSupport -> Int -e2eEncConnInfoLength = \case +e2eEncConnInfoLength :: VersionSMPA -> PQSupport -> Int +e2eEncConnInfoLength v = \case -- reduced by 3700 (roughly the increase of message ratchet header size + key and ciphertext in reply link) - PQSupportOn -> 11148 - PQSupportOff -> 14848 + PQSupportOn | v >= pqdrSMPAgentVersion -> 11148 + _ -> 14848 -e2eEncUserMsgLength :: PQSupport -> Int -e2eEncUserMsgLength = \case +e2eEncUserMsgLength :: VersionSMPA -> PQSupport -> Int +e2eEncUserMsgLength v = \case -- reduced by 2200 (roughly the increase of message ratchet header size) - PQSupportOn -> 13656 - PQSupportOff -> 15856 + PQSupportOn | v >= pqdrSMPAgentVersion -> 13656 + _ -> 15856 -- | Raw (unparsed) SMP agent protocol transmission. type ARawTransmission = (ByteString, ByteString, ByteString) diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 47ca1c0f2..534fa7c07 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -54,12 +54,11 @@ module Simplex.Messaging.Crypto.Ratchet generateSndE2EParams, initialPQEncryption, connPQEncryption, - joinContactInitialKeys, replyKEM_, pqSupportToEnc, pqEncToSupport, pqSupportAnd, - pqSupportOrEnc, + pqEnableSupport, pqX3dhSnd, pqX3dhRcv, initSndRatchet, @@ -672,15 +671,15 @@ data MsgHeader a = MsgHeader -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 -- TODO PQ this must be version-dependent -- TODO this is the exact size, some reserve should be added -paddedHeaderLen :: PQSupport -> Int -paddedHeaderLen = \case - PQSupportOn -> 2288 - PQSupportOff -> 88 +paddedHeaderLen :: VersionE2E -> PQSupport -> Int +paddedHeaderLen v = \case + PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 2288 + _ -> 88 -- only used in tests to validate correct padding -- (2 bytes - version size, 1 byte - header size, not to have it fixed or version-dependent) -fullHeaderLen :: PQSupport -> Int -fullHeaderLen pq = 2 + 1 + paddedHeaderLen pq + authTagSize + ivSize @AES256 +fullHeaderLen :: VersionE2E -> PQSupport -> Int +fullHeaderLen v pq = 2 + 1 + paddedHeaderLen v pq + authTagSize + ivSize @AES256 -- pass the current version, as MsgHeader only includes the max supported version that can be different from the current encodeMsgHeader :: AlgorithmI a => VersionE2E -> MsgHeader a -> ByteString @@ -718,7 +717,8 @@ instance Encoding EncMessageHeader where encodeLarge :: VersionE2E -> ByteString -> ByteString encodeLarge v s -- the condition for length is not necessary, it's here as a fallback. - | v >= pqRatchetE2EEncryptVersion || B.length s > 255 = smpEncode $ Large s + -- | v >= pqRatchetE2EEncryptVersion || B.length s > 255 = smpEncode $ Large s + | v >= pqRatchetE2EEncryptVersion = smpEncode $ Large s | otherwise = smpEncode s -- This parser relies on the fact that header cannot be shorter than 32 bytes (it is ~69 bytes without PQ KEM), @@ -793,15 +793,15 @@ pqEncToSupport (PQEncryption pq) = PQSupport pq pqSupportAnd :: PQSupport -> PQSupport -> PQSupport pqSupportAnd (PQSupport s1) (PQSupport s2) = PQSupport $ s1 && s2 -pqSupportOrEnc :: PQSupport -> PQEncryption -> PQSupport -pqSupportOrEnc (PQSupport sup) (PQEncryption enc) = PQSupport $ sup || enc +pqEnableSupport :: VersionE2E -> PQSupport -> PQEncryption -> PQSupport +pqEnableSupport v (PQSupport sup) (PQEncryption enc) = PQSupport $ sup || (v >= pqRatchetE2EEncryptVersion && enc) -replyKEM_ :: Maybe (RKEMParams 'RKSProposed) -> PQSupport -> Maybe AUseKEM -replyKEM_ kem_ = \case - PQSupportOn -> Just $ case kem_ of +replyKEM_ :: VersionE2E -> Maybe (RKEMParams 'RKSProposed) -> PQSupport -> Maybe AUseKEM +replyKEM_ v kem_ = \case + PQSupportOn | v >= pqRatchetE2EEncryptVersion -> Just $ case kem_ of Just (RKParamsProposed k) -> AUseKEM SRKSAccepted $ AcceptKEM k Nothing -> AUseKEM SRKSProposed ProposeKEM - PQSupportOff -> Nothing + _ -> Nothing instance StrEncoding PQEncryption where strEncode pqMode @@ -848,12 +848,6 @@ connPQEncryption = \case IKUsePQ -> PQSupportOn IKNoPQ pq -> pq -- default for creating connection is IKNoPQ PQEncOn --- determines whether PQ key should be included in invitation link sent to contact address -joinContactInitialKeys :: PQSupport -> InitialKeys -joinContactInitialKeys = \case - PQSupportOn -> IKUsePQ -- default - PQSupportOff -> IKNoPQ PQSupportOff - rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> ExceptT CryptoError IO (ByteString, Ratchet a) rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ = throwE CERatchetState rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcSupportKEM, rcEnableKEM, rcVersion} paddedMsgLen msg pqEnc_ = do @@ -863,9 +857,14 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, -- PQ encryption can be enabled or disabled rcEnableKEM' = fromMaybe rcEnableKEM pqEnc_ -- support for PQ encryption (and therefore large headers/small envelopes) can only be enabled, it cannot be disabled - rcSupportKEM' = rcSupportKEM `pqSupportOrEnc` rcEnableKEM' + rcSupportKEM' = pqEnableSupport v rcSupportKEM rcEnableKEM' + -- This sets max version to support PQ encryption. + -- Current version upgrade happens when peer decrypts the message. + -- TODO v5.7 remove version upgrade here, as it's already upgraded above + maxSupported' = max currentE2EEncryptVersion $ if pqEnc_ == Just PQEncOn then pqRatchetE2EEncryptVersion else v + rcVersion' = rcVersion {maxSupported = maxSupported'} -- enc_header = HENCRYPT(state.HKs, header) - (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen rcSupportKEM') rcAD (msgHeader v) + (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen v rcSupportKEM') rcAD (msgHeader v maxSupported') -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) let emHeader = smpEncode EncMessageHeader {ehVersion = v, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg @@ -883,16 +882,10 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs = rcNs + 1, rcSupportKEM = rcSupportKEM', rcEnableKEM = rcEnableKEM', - rcVersion = rcVersion {maxSupported = max v currentE2EEncryptVersion} + rcVersion = rcVersion', + rcKEM = if pqEnc_ == Just PQEncOff then (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM else rcKEM } - rc'' = case pqEnc_ of - Nothing -> rc' - -- This sets max version to support PQ encryption. - -- Current version upgrade happens when peer decrypts the message. - -- TODO v5.7 remove version upgrade here, as it's already upgraded above - Just PQEncOn -> rc' {rcVersion = rcVersion {maxSupported = max v pqRatchetE2EEncryptVersion}} - Just PQEncOff -> rc' {rcKEM = (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM} - pure (msg', rc'') + pure (msg', rc') where -- header = HEADER_PQ2( -- dh = state.DHRs.public, @@ -901,11 +894,11 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, -- pn = state.PN, -- n = state.Ns -- ) - msgHeader v = + msgHeader v maxSupported' = encodeMsgHeader - v + v MsgHeader - { msgMaxVersion = maxSupported rcVersion, + { msgMaxVersion = maxSupported', msgDHRs = publicKey rcDHRs, msgKEM = msgKEMParams <$> rcKEM, msgPN = rcPN, @@ -982,7 +975,7 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do smkDiff :: SkippedMsgKeys -> SkippedMsgDiff smkDiff smks = if M.null smks then SMDNoChange else SMDAdd smks ratchetStep :: Ratchet a -> MsgHeader a -> ExceptT CryptoError IO (Ratchet a) - ratchetStep rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr, rcSupportKEM} MsgHeader {msgDHRs, msgKEM} = do + ratchetStep rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr, rcSupportKEM, rcVersion = rv} MsgHeader {msgDHRs, msgKEM} = do (kemSS, kemSS', rcKEM') <- pqRatchetStep rc' msgKEM -- state.DHRs = GENERATE_DH() (_, rcDHRs') <- atomically $ generateKeyPair @a g @@ -997,7 +990,7 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do rc' { rcDHRs = rcDHRs', rcKEM = rcKEM', - rcSupportKEM = rcSupportKEM `pqSupportOrEnc` rcEnableKEM', + rcSupportKEM = pqEnableSupport (current rv) rcSupportKEM rcEnableKEM', rcEnableKEM = rcEnableKEM', rcSndKEM = PQEncryption sndKEM, rcRcvKEM = PQEncryption rcvKEM, diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 5b73be893..9490fe2e2 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -90,15 +90,12 @@ paddedMsgLen :: Int paddedMsgLen = 100 fullMsgLen :: Ratchet a -> Int -fullMsgLen Ratchet {rcSupportKEM} = headerLenLength + fullHeaderLen rcSupportKEM + C.authTagSize + paddedMsgLen +fullMsgLen Ratchet {rcSupportKEM, rcVersion} = headerLenLength + fullHeaderLen v rcSupportKEM + C.authTagSize + paddedMsgLen where - -- v = current rcVersion + v = current rcVersion headerLenLength = case rcSupportKEM of - PQSupportOn -> 3 -- two bytes are added because of two Large used in new encoding - PQSupportOff -> 1 - -- TODO PQ below should work too - -- | v >= pqRatchetE2EEncryptVersion = 3 - -- | otherwise = 1 + PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 3 -- two bytes are added because of two Large used in new encoding + _ -> 1 testMessageHeader :: forall a. AlgorithmI a => VersionE2E -> C.SAlgorithm a -> Expectation testMessageHeader v _ = do @@ -308,10 +305,10 @@ testEnableKEM alice bob _ _ _ = do (alice, "accepting KEM") \#>! bob (alice, "KEM not enabled yet here too") \#>! bob (bob, "KEM is still not enabled") \#>! alice - (alice, "KEM still not enabled 2") \#>! bob - (bob, "now KEM is enabled") !#> alice - (alice, "now KEM is enabled for both sides") !#> bob - (bob, "Still enabled for both sides") !#> alice + (alice, "now KEM is enabled") !#>! bob + (bob, "now KEM is enabled for both sides") !#> alice + (alice, "still enabled for both sides") !#> bob + (bob, "still enabled for both sides 2") !#> alice (alice, "disabling KEM") !#>\ bob (bob, "KEM not disabled yet") !#> alice (alice, "KEM disabled") \#> bob @@ -326,10 +323,10 @@ testEnableKEMStrict alice bob _ _ _ = do (alice, "accepting KEM") \#>! bob (alice, "KEM not enabled yet here too") \#>! bob (bob, "KEM is still not enabled") \#>! alice - (alice, "KEM still not enabled 2") \#>! bob - (bob, "now KEM is enabled") !#>! alice - (alice, "now KEM is enabled for both sides") !#>! bob - (bob, "Still enabled for both sides") !#>! alice + (alice, "now KEM is enabled") !#>! bob + (bob, "now KEM is enabled for both sides") !#>! alice + (alice, "still enabled for both sides") !#>! bob + (bob, "still enabled for both sides 2") !#>! alice (alice, "disabling KEM") !#>\ bob (bob, "KEM not disabled yet") !#>! alice (alice, "KEM disabled") \#>\ bob diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 424a681e2..95f3a3db2 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -164,11 +164,14 @@ pattern Msg :: MsgBody -> ACommand 'Agent e pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk, pqEncryption = PQEncOn} _ msgBody pattern Msg' :: AgentMsgId -> PQEncryption -> MsgBody -> ACommand 'Agent e -pattern Msg' aMsgId pqEncryption msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _), pqEncryption} _ msgBody +pattern Msg' aMsgId pq msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _), pqEncryption = pq} _ msgBody pattern MsgErr :: AgentMsgId -> MsgErrorType -> MsgBody -> ACommand 'Agent 'AEConn pattern MsgErr msgId err msgBody <- MSG MsgMeta {recipient = (msgId, _), integrity = MsgError err} _ msgBody +pattern MsgErr' :: AgentMsgId -> MsgErrorType -> PQEncryption -> MsgBody -> ACommand 'Agent 'AEConn +pattern MsgErr' msgId err pq msgBody <- MSG MsgMeta {recipient = (msgId, _), integrity = MsgError err, pqEncryption = pq} _ msgBody + pattern Rcvd :: AgentMsgId -> ACommand 'Agent 'AEConn pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] @@ -256,6 +259,8 @@ functionalAPITests t = do withSmpServer t testServerMultipleIdentities it "should connect with two peers" $ withSmpServer t testAgentClient3 + it "should establish connection without PQ encryption and enable it" $ + withSmpServer t testEnablePQEncryption describe "Establishing duplex connection v2, different Ratchet versions" $ testRatchetMatrix2 t runAgentClientTest describe "Establish duplex connection via contact address" $ @@ -516,6 +521,76 @@ runAgentClientTest pqSupport alice@AgentClient {} bob baseId = where msgId = subtract baseId . fst +testEnablePQEncryption :: HasCallStack => IO () +testEnablePQEncryption = do + ca <- getSMPAgentClient' 1 agentCfg initAgentServers testDB + cb <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + g <- C.newRandom + runRight_ $ do + (aId, bId) <- makeConnection_ PQSupportOff ca cb + let a = (ca, aId) + b = (cb, bId) + (a, 4, "msg 1") \#>\ b + (b, 5, "msg 2") \#>\ a + -- 45 bytes is used by agent message envelope inside double ratchet message envelope + let largeMsg g' pqEnc = atomically $ C.randomBytes (e2eEncUserMsgLength pqdrSMPAgentVersion pqEnc - 45) g' + lrg <- largeMsg g PQSupportOff + (a, 6, lrg) \#>\ b + (b, 7, lrg) \#>\ a + -- enabling PQ encryption + (a, 8, lrg) \#>! b + (b, 9, lrg) \#>! a + -- switched to smaller envelopes (before reporting PQ encryption enabled) + sml <- largeMsg g PQSupportOn + -- fail because of message size + Left (A.CMD LARGE) <- tryError $ A.sendMessage ca bId PQEncOn SMP.noMsgFlags lrg + (11, PQEncOff) <- A.sendMessage ca bId PQEncOn SMP.noMsgFlags sml + get ca =##> \case ("", connId, SENT 11) -> connId == bId; _ -> False + get cb =##> \case ("", connId, MsgErr' 10 MsgSkipped {} PQEncOff msg') -> connId == aId && msg' == sml; _ -> False + ackMessage cb aId 10 Nothing + -- -- fail in reply to sync IDss + Left (A.CMD LARGE) <- tryError $ A.sendMessage cb aId PQEncOn SMP.noMsgFlags lrg + (12, PQEncOn) <- A.sendMessage cb aId PQEncOn SMP.noMsgFlags sml + get cb =##> \case ("", connId, SENT 12) -> connId == aId; _ -> False + get ca =##> \case ("", connId, MsgErr' 12 MsgSkipped {} PQEncOn msg') -> connId == bId && msg' == sml; _ -> False + ackMessage ca bId 12 Nothing + -- PQ encryption now enabled + (a, 13, sml) !#>! b + (b, 14, sml) !#>! a + -- disabling PQ encryption + (a, 15, sml) !#>\ b + (b, 16, sml) !#>\ a + (a, 17, sml) \#>\ b + (b, 18, sml) \#>\ a + -- enabling PQ encryption again + (a, 19, sml) \#>! b + (b, 20, sml) \#>! a + (a, 21, sml) \#>! b + (b, 22, sml) !#>! a + (a, 23, sml) !#>! b + -- disabling PQ encryption again + (b, 24, sml) !#>\ a + (a, 25, sml) !#>\ b + (b, 26, sml) \#>\ a + (a, 27, sml) \#>\ b + -- PQ encryption is now disabled, but support remained enabled, so we still cannot send larger messages + Left (A.CMD LARGE) <- tryError $ A.sendMessage ca bId PQEncOff SMP.noMsgFlags (sml <> "123456") + Left (A.CMD LARGE) <- tryError $ A.sendMessage cb aId PQEncOff SMP.noMsgFlags (sml <> "123456") + pure () + where + (\#>\) = PQEncOff `sndRcv` PQEncOff + (\#>!) = PQEncOff `sndRcv` PQEncOn + (!#>!) = PQEncOn `sndRcv` PQEncOn + (!#>\) = PQEncOn `sndRcv` PQEncOff + +sndRcv :: PQEncryption -> PQEncryption -> ((AgentClient, ConnId), AgentMsgId, MsgBody) -> (AgentClient, ConnId) -> ExceptT AgentErrorType IO () +sndRcv pqEnc pqEnc' ((c1, id1), mId, msg) (c2, id2) = do + r <- A.sendMessage c1 id2 pqEnc' SMP.noMsgFlags msg + liftIO $ r `shouldBe` (mId, pqEnc) + get c1 =##> \case ("", connId, SENT mId') -> connId == id2 && mId' == mId; _ -> False + get c2 =##> \case ("", connId, Msg' mId' pq msg') -> connId == id1 && mId' == mId && msg' == msg && pq == pqEnc; _ -> False + ackMessage c2 id1 mId Nothing + testAgentClient3 :: HasCallStack => IO () testAgentClient3 = do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB From 03d73f442f85537ec3489ffcd339dabc8a3ffb07 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Fri, 8 Mar 2024 10:01:58 +0000 Subject: [PATCH 17/30] JSON instance for VersionRange (#1036) --- src/Simplex/Messaging/Version.hs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Simplex/Messaging/Version.hs b/src/Simplex/Messaging/Version.hs index 78d290687..21c7c3ac3 100644 --- a/src/Simplex/Messaging/Version.hs +++ b/src/Simplex/Messaging/Version.hs @@ -4,6 +4,7 @@ {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} module Simplex.Messaging.Version @@ -26,9 +27,11 @@ module Simplex.Messaging.Version where import Control.Applicative (optional) +import qualified Data.Aeson.TH as J import qualified Data.Attoparsec.ByteString.Char8 as A import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.Version.Internal (Version (..)) pattern VersionRange :: Version v -> Version v -> VersionRange v @@ -120,3 +123,5 @@ compatibleVersion x vr = mkCompatibleIf :: a -> Bool -> Maybe (Compatible a) x `mkCompatibleIf` cond = if cond then Just $ Compatible_ x else Nothing + +$(J.deriveJSON defaultJSON ''VersionRange) From b4e55146b8a910add95d0756734ca5ba3f0850fc Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Fri, 8 Mar 2024 12:58:45 +0200 Subject: [PATCH 18/30] core: fix VersionRange JSON instances (#1037) --- src/Simplex/Messaging/Version.hs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/Simplex/Messaging/Version.hs b/src/Simplex/Messaging/Version.hs index 21c7c3ac3..25f7368d1 100644 --- a/src/Simplex/Messaging/Version.hs +++ b/src/Simplex/Messaging/Version.hs @@ -2,9 +2,9 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} module Simplex.Messaging.Version @@ -27,11 +27,13 @@ module Simplex.Messaging.Version where import Control.Applicative (optional) -import qualified Data.Aeson.TH as J +import qualified Data.Aeson as J +import qualified Data.Aeson.Encoding as JE +import Data.Aeson.Types ((.:), (.=)) +import qualified Data.Aeson.Types as JT import qualified Data.Attoparsec.ByteString.Char8 as A import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.Version.Internal (Version (..)) pattern VersionRange :: Version v -> Version v -> VersionRange v @@ -45,6 +47,18 @@ data VersionRange v = VRange } deriving (Eq, Show) +instance J.FromJSON (VersionRange v) where + parseJSON (J.Object v) = do + minVersion <- v .: "minVersion" + maxVersion <- v .: "maxVersion" + pure VRange {minVersion, maxVersion} + parseJSON invalid = + JT.prependFailure "bad VersionRange, " (JT.typeMismatch "Object" invalid) + +instance J.ToJSON (VersionRange v) where + toEncoding VRange {minVersion, maxVersion} = JE.pairs $ ("minVersion" .= minVersion) <> ("maxVersion" .= maxVersion) + toJSON VRange {minVersion, maxVersion} = J.object ["minVersion" .= minVersion, "maxVersion" .= maxVersion] + class VersionScope v -- | construct valid version range, to be used in constants @@ -123,5 +137,3 @@ compatibleVersion x vr = mkCompatibleIf :: a -> Bool -> Maybe (Compatible a) x `mkCompatibleIf` cond = if cond then Just $ Compatible_ x else Nothing - -$(J.deriveJSON defaultJSON ''VersionRange) From 8cdd49b91256aee56427f8b8e351cf415045e9c7 Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Fri, 8 Mar 2024 15:43:33 +0200 Subject: [PATCH 19/30] core: restore Eq instances (#1038) * core: restore Eq instances * remove duplicates from tests --- src/Simplex/Messaging/Agent/Protocol.hs | 28 +++++++++++++++++++++- src/Simplex/Messaging/Crypto/Ratchet.hs | 4 +++- src/Simplex/Messaging/Protocol.hs | 5 +++- tests/AgentTests/ConnectionRequestTests.hs | 13 ---------- tests/AgentTests/DoubleRatchetTests.hs | 2 -- tests/AgentTests/FunctionalAPITests.hs | 12 ++-------- 6 files changed, 36 insertions(+), 28 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 0c6bfd1ba..660951bfc 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -171,7 +171,7 @@ import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map (Map) import qualified Data.Map as M -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, isJust) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -365,6 +365,11 @@ deriving instance Show ACmd data APartyCmd p = forall e. AEntityI e => APC (SAEntity e) (ACommand p e) +instance Eq (APartyCmd p) where + APC e cmd == APC e' cmd' = case testEquality e e' of + Just Refl -> cmd == cmd' + Nothing -> False + deriving instance Show (APartyCmd p) type ConnInfo = ByteString @@ -417,12 +422,19 @@ data ACommand (p :: AParty) (e :: AEntity) where SFDONE :: ValidFileDescription 'FSender -> [ValidFileDescription 'FRecipient] -> ACommand Agent AESndFile SFERR :: AgentErrorType -> ACommand Agent AESndFile +deriving instance Eq (ACommand p e) + deriving instance Show (ACommand p e) data ACmdTag = forall p e. (APartyI p, AEntityI e) => ACmdTag (SAParty p) (SAEntity e) (ACommandTag p e) data APartyCmdTag p = forall e. AEntityI e => APCT (SAEntity e) (ACommandTag p e) +instance Eq (APartyCmdTag p) where + APCT e cmd == APCT e' cmd' = case testEquality e e' of + Just Refl -> cmd == cmd' + Nothing -> False + deriving instance Show (APartyCmdTag p) data ACommandTag (p :: AParty) (e :: AEntity) where @@ -472,6 +484,8 @@ data ACommandTag (p :: AParty) (e :: AEntity) where SFDONE_ :: ACommandTag Agent AESndFile SFERR_ :: ACommandTag Agent AESndFile +deriving instance Eq (ACommandTag p e) + deriving instance Show (ACommandTag p e) aPartyCmdTag :: APartyCmd p -> APartyCmdTag p @@ -755,6 +769,8 @@ data SConnectionMode (m :: ConnectionMode) where SCMInvitation :: SConnectionMode CMInvitation SCMContact :: SConnectionMode CMContact +deriving instance Eq (SConnectionMode m) + deriving instance Show (SConnectionMode m) instance TestEquality SConnectionMode where @@ -764,6 +780,9 @@ instance TestEquality SConnectionMode where data AConnectionMode = forall m. ConnectionModeI m => ACM (SConnectionMode m) +instance Eq AConnectionMode where + ACM m == ACM m' = isJust $ testEquality m m' + cmInvitation :: AConnectionMode cmInvitation = ACM SCMInvitation @@ -1356,10 +1375,17 @@ data ConnectionRequestUri (m :: ConnectionMode) where -- they are passed in AgentInvitation message CRContactUri :: ConnReqUriData -> ConnectionRequestUri CMContact +deriving instance Eq (ConnectionRequestUri m) + deriving instance Show (ConnectionRequestUri m) data AConnectionRequestUri = forall m. ConnectionModeI m => ACR (SConnectionMode m) (ConnectionRequestUri m) +instance Eq AConnectionRequestUri where + ACR m cr == ACR m' cr' = case testEquality m m' of + Just Refl -> cr == cr' + _ -> False + deriving instance Show AConnectionRequestUri data ConnReqUriData = ConnReqUriData diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 534fa7c07..4b15e5766 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -190,6 +190,8 @@ data RKEMParams (s :: RatchetKEMState) where RKParamsProposed :: KEMPublicKey -> RKEMParams 'RKSProposed RKParamsAccepted :: KEMCiphertext -> KEMPublicKey -> RKEMParams 'RKSAccepted +deriving instance Eq (RKEMParams s) + deriving instance Show (RKEMParams s) data ARKEMParams = forall s. RatchetKEMStateI s => ARKP (SRatchetKEMState s) (RKEMParams s) @@ -275,7 +277,7 @@ type RcvE2ERatchetParamsUri a = E2ERatchetParamsUri 'RKSProposed a data E2ERatchetParamsUri (s :: RatchetKEMState) (a :: Algorithm) = E2ERatchetParamsUri VersionRangeE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) - deriving (Show) + deriving (Eq, Show) data AE2ERatchetParamsUri (a :: Algorithm) = forall s. diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 3aef08622..2c593fc6f 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -183,7 +183,7 @@ import Data.Functor (($>)) import Data.Kind import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Maybe (isNothing) +import Data.Maybe (isJust, isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality @@ -766,6 +766,9 @@ deriving instance Show (SProtocolType p) data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p) +instance Eq AProtocolType where + AProtocolType p == AProtocolType p' = isJust $ testEquality p p' + deriving instance Show AProtocolType instance TestEquality SProtocolType where diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index f5e689e96..7ab234887 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -3,7 +3,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE PatternSynonyms #-} {-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -11,7 +10,6 @@ module AgentTests.ConnectionRequestTests where import Data.ByteString (ByteString) -import Data.Type.Equality import Network.HTTP.Types (urlEncode) import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C @@ -22,17 +20,6 @@ import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.Version import Test.Hspec -deriving instance Eq (ConnectionRequestUri m) - -deriving instance Eq (E2ERatchetParamsUri s a) - -deriving instance Eq (RKEMParams s) - -instance Eq AConnectionRequestUri where - ACR m cr == ACR m' cr' = case testEquality m m' of - Just Refl -> cr == cr' - _ -> False - uri :: String uri = "smp.simplex.im" diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 9490fe2e2..c64f5fb96 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -160,8 +160,6 @@ deriving instance Eq RatchetInitParams deriving instance Eq RatchetKey -deriving instance Eq (RKEMParams s) - instance Eq ARKEMParams where (ARKP s ps) == (ARKP s' ps') = case testEquality s s' of Just Refl -> ps == ps' diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 95f3a3db2..9c0355852 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -57,7 +57,7 @@ import Data.Int (Int64) import Data.List (nub) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as M -import Data.Maybe (isJust, isNothing) +import Data.Maybe (isNothing) import qualified Data.Set as S import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) @@ -80,7 +80,7 @@ import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQ import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF, authBatchCmdsNTFVersion) -import Simplex.Messaging.Protocol (AProtocolType (..), BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), supportedSMPClientVRange) +import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), supportedSMPClientVRange) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Server.Expiration @@ -96,14 +96,6 @@ import XFTPClient (testXFTPServer) type AEntityTransmission e = (ACorrId, ConnId, ACommand 'Agent e) -deriving instance Eq (ACommand p e) - -instance Eq AConnectionMode where - ACM m == ACM m' = isJust $ testEquality m m' - -instance Eq AProtocolType where - AProtocolType p == AProtocolType p' = isJust $ testEquality p p' - -- deriving instance Eq (ValidFileDescription p) (##>) :: (HasCallStack, MonadUnliftIO m) => m (AEntityTransmission e) -> AEntityTransmission e -> m () From dab55e0a9b03577f643af7922afa061801d82ed5 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Fri, 8 Mar 2024 23:13:21 +0000 Subject: [PATCH 20/30] pqdr: return agent version from connection request version & PQ support check api --- src/Simplex/Messaging/Agent.hs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index d0cb7d2fa..df79181e9 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -221,9 +221,6 @@ createUser c = withAgentEnv c .: createUser' c deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m () deleteUser c = withAgentEnv c .: deleteUser' c -connRequestPQSupport :: (MonadUnliftIO m) => AgentClient -> PQSupport -> ConnectionRequestUri c -> m (Maybe PQSupport) -connRequestPQSupport c = withAgentEnv c .: connRequestPQSupport' c - -- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs @@ -710,14 +707,14 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup = pure (aVersion, cData, q, rc, e2eSndParams) Nothing -> throwError $ AGENT A_VERSION -connRequestPQSupport' :: AgentMonad' m => AgentClient -> PQSupport -> ConnectionRequestUri c -> m (Maybe PQSupport) -connRequestPQSupport' _c pqSup cReq = case cReq of +connRequestPQSupport :: MonadUnliftIO m => AgentClient -> PQSupport -> ConnectionRequestUri c -> m (Maybe (VersionSMPA, PQSupport)) +connRequestPQSupport c pqSup cReq = withAgentEnv c $ case cReq of CRInvitationUri {} -> invPQSupported <$$> compatibleInvitationUri cReq pqSup where - invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = pqSup `CR.pqSupportAnd` versionPQSupport_ agentV (Just e2eV) + invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = (agentV, pqSup `CR.pqSupportAnd` versionPQSupport_ agentV (Just e2eV)) CRContactUri {} -> ctPQSupported <$$> compatibleContactUri cReq pqSup where - ctPQSupported (_, Compatible agentV) = pqSup `CR.pqSupportAnd` versionPQSupport_ agentV Nothing + ctPQSupported (_, Compatible agentV) = (agentV, pqSup `CR.pqSupportAnd` versionPQSupport_ agentV Nothing) compatibleInvitationUri :: AgentMonad' m => ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible (CR.RcvE2ERatchetParams 'C.X448), Compatible VersionSMPA)) compatibleInvitationUri (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqSup = do @@ -737,8 +734,7 @@ compatibleContactUri (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = <*> (crAgentVRange `compatibleVersion` smpAgentVRange pqSup) versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport -versionPQSupport_ agentV e2eV_ = - PQSupport $ pqdrSMPAgentVersion <= agentV && maybe True (CR.pqRatchetE2EEncryptVersion <=) e2eV_ +versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m ConnId joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = From 054b6edb14c323d4e83c6aea6c22deaffd681147 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sun, 10 Mar 2024 11:27:31 +0000 Subject: [PATCH 21/30] pqdr: clean up (#1039) --- src/Simplex/Messaging/Agent.hs | 9 +++------ src/Simplex/Messaging/Crypto/Ratchet.hs | 22 ++++++++-------------- tests/AgentTests/DoubleRatchetTests.hs | 2 +- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index df79181e9..08678140e 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -700,7 +700,7 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup = (_, rcDHRs) <- atomically $ C.generateKeyPair g rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams maxSupported <- asks $ maxVersion . ($ pqSup) . e2eEncryptVRange . config - let rcVs = CR.RVersions {current = v, maxSupported} + let rcVs = CR.RatchetVersions {current = v, maxSupported} rc = CR.initSndRatchet rcVs rcDHRr rcDHRs rcParams q <- newSndQueue userId "" qInfo let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} @@ -2183,8 +2183,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, unless (e2eVersion `isCompatible` e2eVRange) (throwError $ AGENT A_VERSION) (pk1, rcDHRs, pKem) <- withStore c (`getRatchetX3dhKeys` connId) rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams - -- TODO PQ combine isCompatible check and construction in one call - let rcVs = CR.RVersions {current = e2eVersion, maxSupported = maxVersion e2eVRange} + let rcVs = CR.RatchetVersions {current = e2eVersion, maxSupported = maxVersion e2eVRange} pqSupport' = pqSupport `CR.pqSupportAnd` versionPQSupport_ agentVersion (Just e2eVersion) rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqSupport' g <- asks random @@ -2379,7 +2378,6 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, DuplexConnection {} -> action conn' _ -> qError $ name <> ": message must be sent to duplex connection" - -- TODO PQ make sure pqEncryption is set correctly here newRatchetKey :: CR.RcvE2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv _) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId, pqSupport} _ sqs) = unlessM ratchetExists $ do @@ -2387,8 +2385,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, let connE2EVRange = e2eEncryptVRange pqSupport unless (e2eVersion `isCompatible` connE2EVRange) (throwError $ AGENT A_VERSION) keys <- getSendRatchetKeys - -- TODO PQ combine with `isCompatible` check above - let rcVs = CR.RVersions {current = e2eVersion, maxSupported = maxVersion connE2EVRange} + let rcVs = CR.RatchetVersions {current = e2eVersion, maxSupported = maxVersion connE2EVRange} initRatchet rcVs keys notifyAgreed where diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 4b15e5766..a3251c26c 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -488,7 +488,7 @@ data Ratchet a = Ratchet } deriving (Show) -data RatchetVersions = RVersions +data RatchetVersions = RatchetVersions { current :: VersionE2E, maxSupported :: VersionE2E } @@ -496,8 +496,8 @@ data RatchetVersions = RVersions instance ToJSON RatchetVersions where -- TODO v5.7 or v5.8 change to the default record encoding - toJSON (RVersions v1 v2) = toJSON (v1, v2) - toEncoding (RVersions v1 v2) = toEncoding (v1, v2) + toJSON (RatchetVersions v1 v2) = toJSON (v1, v2) + toEncoding (RatchetVersions v1 v2) = toEncoding (v1, v2) instance FromJSON RatchetVersions where -- TODO v6.0 replace with the default record parser @@ -509,7 +509,7 @@ instance FromJSON RatchetVersions where toRV (v1, v2) = maybe (fail "bad version range") (pure . ratchetVersions) $ safeVersionRange v1 v2 ratchetVersions :: VersionRangeE2E -> RatchetVersions -ratchetVersions (VersionRange v1 v2) = RVersions {current = v1, maxSupported = v2} +ratchetVersions (VersionRange v1 v2) = RatchetVersions {current = v1, maxSupported = v2} data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, @@ -671,7 +671,6 @@ data MsgHeader a = MsgHeader -- to allow extension without increasing the size, the actual header length is: -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 --- TODO PQ this must be version-dependent -- TODO this is the exact size, some reserve should be added paddedHeaderLen :: VersionE2E -> PQSupport -> Int paddedHeaderLen v = \case @@ -679,7 +678,7 @@ paddedHeaderLen v = \case _ -> 88 -- only used in tests to validate correct padding --- (2 bytes - version size, 1 byte - header size, not to have it fixed or version-dependent) +-- (2 bytes - version size, 1 byte - header size) fullHeaderLen :: VersionE2E -> PQSupport -> Int fullHeaderLen v pq = 2 + 1 + paddedHeaderLen v pq + authTagSize + ivSize @AES256 @@ -862,7 +861,8 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcSupportKEM' = pqEnableSupport v rcSupportKEM rcEnableKEM' -- This sets max version to support PQ encryption. -- Current version upgrade happens when peer decrypts the message. - -- TODO v5.7 remove version upgrade here, as it's already upgraded above + -- TODO note that maxSupported will not downgrade here below current (v). + -- TODO PQ currentE2EEncryptVersion should be passed via config maxSupported' = max currentE2EEncryptVersion $ if pqEnc_ == Just PQEncOn then pqRatchetE2EEncryptVersion else v rcVersion' = rcVersion {maxSupported = maxSupported'} -- enc_header = HENCRYPT(state.HKs, header) @@ -871,12 +871,6 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, let emHeader = smpEncode EncMessageHeader {ehVersion = v, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg let msg' = encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} - -- TODO v5.8 remove comments below - -- Note that maxSupported will not downgrade here below current. - -- TODO v5.7 remove comments below - -- TODO PQ It will downgrade when decrypting the message when the current version downgrades to remove support for PQ encryption. - -- TODO v5.8 possibly, replace `max v currentE2EEncryptVersion` with `v` (to allow downgrade when app downgraded)? - -- -- state.Ns += 1 rc' = rc @@ -973,7 +967,7 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do | msgMaxVersion > current = rc {rcVersion = rcVersion {current = max current $ min msgMaxVersion maxSupported}} | otherwise = rc where - RVersions {current, maxSupported} = rcVersion + RatchetVersions {current, maxSupported} = rcVersion smkDiff :: SkippedMsgKeys -> SkippedMsgDiff smkDiff smks = if M.null smks then SMDNoChange else SMDAdd smks ratchetStep :: Ratchet a -> MsgHeader a -> ExceptT CryptoError IO (Ratchet a) diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index c64f5fb96..e6057fc02 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -580,7 +580,7 @@ initRatchetsKEMProposedAgain = do testRatchetVersions :: PQSupport -> RatchetVersions testRatchetVersions pq = let v = maxVersion $ supportedE2EEncryptVRange pq - in RVersions v v + in RatchetVersions v v encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) encrypt_ pqEnc_ (_, rc, _) msg = From 851ed2d02e2a78c15893ad8bc9c5a4d917eb6a35 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sun, 10 Mar 2024 13:29:03 +0000 Subject: [PATCH 22/30] pqdr: more tests, pass e2e version to rcEncrypt from config (#1040) * pqdr: more tests, pass e2e version to rcEncrypt from config * fix --- src/Simplex/Messaging/Agent.hs | 60 ++++++++++++++----------- src/Simplex/Messaging/Agent/Protocol.hs | 2 + src/Simplex/Messaging/Crypto/Ratchet.hs | 23 ++++------ tests/AgentTests/DoubleRatchetTests.hs | 4 +- tests/AgentTests/FunctionalAPITests.hs | 23 +++++----- 5 files changed, 59 insertions(+), 53 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 08678140e..56432b947 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -1143,24 +1143,26 @@ enqueueMessage c cData sq msgFlags aMessage = -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) enqueueMessageB c reqs = do - getAVRange <- asks $ smpAgentVRange . config - reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db getAVRange) reqs + cfg <- asks config + reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db cfg) reqs forME reqMids $ \((cData, sq :| sqs, _, _, _), InternalId msgId, pqSecr) -> do submitPendingMsg c cData sq let sqs' = filter isActiveSndQ sqs pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> (PQSupport -> VersionRangeSMPA) -> (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage), InternalId, PQEncryption)) - storeSentMsg db getAVRange req@(cData@ConnData {connId, pqSupport}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + storeSentMsg :: DB.Connection -> AgentConfig -> (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage), InternalId, PQEncryption)) + storeSentMsg db cfg req@(cData@ConnData {connId, pqSupport}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + let AgentConfig {smpAgentVRange, e2eEncryptVRange} = cfg internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash agentMsg = AgentMessage privHeader aMessage agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - (encAgentMessage, pqEnc) <- agentRatchetEncrypt db cData agentMsgStr e2eEncUserMsgLength pqEnc_ + currentE2EVersion = maxVersion $ e2eEncryptVRange PQSupportOff + (encAgentMessage, pqEnc) <- agentRatchetEncrypt db cData agentMsgStr e2eEncUserMsgLength pqEnc_ currentE2EVersion -- agent version range is determined by the connection suppport of PQ encryption, that is may be enabled when message is sent - let agentVersion = maxVersion $ getAVRange pqSupport + let agentVersion = maxVersion $ smpAgentVRange pqSupport msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} msgType = agentMessageType agentMsg msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, pqEncryption = pqEnc, internalHash, prevMsgHash} @@ -2488,11 +2490,14 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId, pqSupport} sq s withStore' c $ \db -> setSndQueueStatus db sq Confirmed where mkConfirmation :: AgentMessage -> m MsgBody - mkConfirmation aMessage = withStore c $ \db -> runExceptT $ do - void . liftIO $ updateSndIds db connId - let pqEnc = CR.pqSupportToEnc pqSupport - (encConnInfo, _) <- agentRatchetEncrypt db cData (smpEncode aMessage) e2eEncConnInfoLength (Just pqEnc) - pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} + mkConfirmation aMessage = do + -- the version to be used when PQSupport is disabled + currentE2EVersion <- asks $ maxVersion . ($ PQSupportOff) . e2eEncryptVRange . config + withStore c $ \db -> runExceptT $ do + void . liftIO $ updateSndIds db connId + let pqEnc = CR.pqSupportToEnc pqSupport + (encConnInfo, _) <- agentRatchetEncrypt db cData (smpEncode aMessage) e2eEncConnInfoLength (Just pqEnc) currentE2EVersion + pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} mkAgentConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> m AgentMessage mkAgentConfirmation c cData sq srv connInfo subMode = do @@ -2505,18 +2510,21 @@ enqueueConfirmation c cData sq connInfo e2eEncryption_ = do submitPendingMsg c cData sq storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AgentMessage -> m () -storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq e2eEncryption_ agentMsg = withStore c $ \db -> runExceptT $ do - internalTs <- liftIO getCurrentTime - (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId - let agentMsgStr = smpEncode agentMsg - internalHash = C.sha256Hash agentMsgStr - pqEnc = CR.pqSupportToEnc pqSupport - (encConnInfo, pqEncryption) <- agentRatchetEncrypt db cData agentMsgStr e2eEncConnInfoLength (Just pqEnc) - let msgBody = smpEncode $ AgentConfirmation {agentVersion = v, e2eEncryption_, encConnInfo} - msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} - liftIO $ createSndMsg db connId msgData - liftIO $ createSndMsgDelivery db connId sq internalId +storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq e2eEncryption_ agentMsg = do + -- the version to be used when PQSupport is disabled + currentE2EVersion <- asks $ maxVersion . ($ PQSupportOff) . e2eEncryptVRange . config + withStore c $ \db -> runExceptT $ do + internalTs <- liftIO getCurrentTime + (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId + let agentMsgStr = smpEncode agentMsg + internalHash = C.sha256Hash agentMsgStr + pqEnc = CR.pqSupportToEnc pqSupport + (encConnInfo, pqEncryption) <- agentRatchetEncrypt db cData agentMsgStr e2eEncConnInfoLength (Just pqEnc) currentE2EVersion + let msgBody = smpEncode $ AgentConfirmation {agentVersion = v, e2eEncryption_, encConnInfo} + msgType = agentMessageType agentMsg + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} + liftIO $ createSndMsg db connId msgData + liftIO $ createSndMsgDelivery db connId sq internalId enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m () enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do @@ -2546,11 +2554,11 @@ enqueueRatchetKey c cData@ConnData {connId, pqSupport} sq e2eEncryption = do pure internalId -- encoded AgentMessage -> encoded EncAgentMessage -agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> ExceptT StoreError IO (ByteString, PQEncryption) -agentRatchetEncrypt db ConnData {connId, connAgentVersion = v, pqSupport} msg getPaddedLen pqEnc_ = do +agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> CR.VersionE2E -> ExceptT StoreError IO (ByteString, PQEncryption) +agentRatchetEncrypt db ConnData {connId, connAgentVersion = v, pqSupport} msg getPaddedLen pqEnc_ currentE2EVersion = do rc <- ExceptT $ getRatchet db connId let paddedLen = getPaddedLen v pqSupport - (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ + (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ currentE2EVersion liftIO $ updateRatchet db connId rc' CR.SMDNoChange pure (encMsg, CR.rcSndKEM rc') diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 660951bfc..1005b7195 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -37,9 +37,11 @@ module Simplex.Messaging.Agent.Protocol VersionSMPA, VersionRangeSMPA, pattern VersionSMPA, + duplexHandshakeSMPAgentVersion, ratchetSyncSMPAgentVersion, deliveryRcptsSMPAgentVersion, pqdrSMPAgentVersion, + currentSMPAgentVersion, supportedSMPAgentVRange, e2eEncConnInfoLength, e2eEncUserMsgLength, diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index a3251c26c..a0a1d4d10 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -76,7 +76,6 @@ module Simplex.Messaging.Crypto.Ratchet RatchetKEM (..), RatchetKEMAccepted (..), RatchetKey (..), - ratchetVersions, fullHeaderLen, applySMDiff, encodeMsgHeader, @@ -496,20 +495,17 @@ data RatchetVersions = RatchetVersions instance ToJSON RatchetVersions where -- TODO v5.7 or v5.8 change to the default record encoding - toJSON (RatchetVersions v1 v2) = toJSON (v1, v2) - toEncoding (RatchetVersions v1 v2) = toEncoding (v1, v2) + toJSON RatchetVersions {current, maxSupported} = toJSON (current, maxSupported) + toEncoding RatchetVersions {current, maxSupported} = toEncoding (current, maxSupported) instance FromJSON RatchetVersions where - -- TODO v6.0 replace with the default record parser + -- TODO v5.7 or v5.8 replace comment below with "tuple for backward" -- this parser supports JSON record encoding for forward compatibility - parseJSON v = (tupleP <|> recordP v) >>= toRV + parseJSON v = toRV <$> (tupleP <|> recordP v) where tupleP = parseJSON v recordP = J.withObject "RatchetVersions" $ \o -> (,) <$> o J..: "current" <*> o J..: "maxSupported" - toRV (v1, v2) = maybe (fail "bad version range") (pure . ratchetVersions) $ safeVersionRange v1 v2 - -ratchetVersions :: VersionRangeE2E -> RatchetVersions -ratchetVersions (VersionRange v1 v2) = RatchetVersions {current = v1, maxSupported = v2} + toRV (current, maxSupported) = RatchetVersions {current, maxSupported} data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, @@ -849,9 +845,9 @@ connPQEncryption = \case IKUsePQ -> PQSupportOn IKNoPQ pq -> pq -- default for creating connection is IKNoPQ PQEncOn -rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> ExceptT CryptoError IO (ByteString, Ratchet a) -rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ = throwE CERatchetState -rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcSupportKEM, rcEnableKEM, rcVersion} paddedMsgLen msg pqEnc_ = do +rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> VersionE2E -> ExceptT CryptoError IO (ByteString, Ratchet a) +rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ _ = throwE CERatchetState +rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcSupportKEM, rcEnableKEM, rcVersion} paddedMsgLen msg pqEnc_ supportedE2EVersion = do -- state.CKs, mk = KDF_CK(state.CKs) let (ck', mk, iv, ehIV) = chainKdf rcCKs v = current rcVersion @@ -862,8 +858,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, -- This sets max version to support PQ encryption. -- Current version upgrade happens when peer decrypts the message. -- TODO note that maxSupported will not downgrade here below current (v). - -- TODO PQ currentE2EEncryptVersion should be passed via config - maxSupported' = max currentE2EEncryptVersion $ if pqEnc_ == Just PQEncOn then pqRatchetE2EEncryptVersion else v + maxSupported' = max supportedE2EVersion $ if pqEnc_ == Just PQEncOn then pqRatchetE2EEncryptVersion else v rcVersion' = rcVersion {maxSupported = maxSupported'} -- enc_header = HENCRYPT(state.HKs, header) (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen v rcSupportKEM') rcAD (msgHeader v maxSupported') diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index e6057fc02..f95f07029 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -358,7 +358,7 @@ testVersionJSON = do testDecodeRV $ (1 :: Int, 2 :: Int) testDecodeRV $ J.object ["current" .= (1 :: Int), "maxSupported" .= (2 :: Int)] where - rv v1 v2 = ratchetVersions $ mkVersionRange (VersionE2E v1) (VersionE2E v2) + rv v1 v2 = RatchetVersions (VersionE2E v1) (VersionE2E v2) testDecodeRV :: ToJSON a => a -> Expectation testDecodeRV a = J.eitherDecode' (J.encode a) `shouldBe` Right (rv 1 2) @@ -585,7 +585,7 @@ testRatchetVersions pq = encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) encrypt_ pqEnc_ (_, rc, _) msg = -- print msg >> - runExceptT (rcEncrypt rc paddedMsgLen msg pqEnc_) + runExceptT (rcEncrypt rc paddedMsgLen msg pqEnc_ currentE2EEncryptVersion) >>= either (pure . Left) checkLength where checkLength (msg', rc') = do diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 9c0355852..e17f44df3 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -176,7 +176,6 @@ smpCfgV7 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange batchCmdsSMPVersio ntfCfgV2 :: ProtocolClientConfig NTFVersion ntfCfgV2 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange (VersionNTF 1) authBatchCmdsNTFVersion} --- TODO PQ test next version with PQ agentCfgVPrev :: AgentConfig agentCfgVPrev = agentCfg @@ -187,10 +186,13 @@ agentCfgVPrev = smpCfg = smpCfgVPrev } +-- agent config for the next client version agentCfgV7 :: AgentConfig agentCfgV7 = agentCfg { sndAuthAlg = C.AuthAlg C.SX25519, + smpAgentVRange = \_ -> V.mkVersionRange duplexHandshakeSMPAgentVersion $ max pqdrSMPAgentVersion currentSMPAgentVersion, + e2eEncryptVRange = \_ -> V.mkVersionRange CR.kdfX3DHE2EEncryptVersion $ max CR.pqRatchetE2EEncryptVersion CR.currentE2EEncryptVersion, smpCfg = smpCfgV7, ntfCfg = ntfCfgV2 } @@ -436,7 +438,6 @@ canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = let v = basicAuthSMPVersion in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) --- TODO PQ test next version with PQ testMatrix2 :: ATransport -> (PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2 t runTest = do it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQSupportOn @@ -448,9 +449,11 @@ testMatrix2 t runTest = do it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff --- TODO PQ test next version with PQ testRatchetMatrix2 :: ATransport -> (PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testRatchetMatrix2 t runTest = do + it "ratchet next" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQSupportOn + it "ratchet next to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 $ runTest PQSupportOn + it "ratchet current to next" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 $ runTest PQSupportOn it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQSupportOn it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 $ runTest PQSupportOff it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 $ runTest PQSupportOff @@ -2478,14 +2481,12 @@ testDeliveryReceiptsVersion t = do ackMessage a' bId 10 $ Just "" get b' =##> \case ("", c, Rcvd 10) -> c == aId; _ -> False ackMessage b' aId 11 Nothing - -- TODO PQ this part hangs when waiting for Rcvd, because connection tries to upgrade to PQ encryption. - -- replacing 2 PQSupportOn with PQEncOff above prevents hanging. - -- (12, _) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello 2" - -- get a' ##> ("", bId, SENT 12) - -- get b' =##> \case ("", c, Msg' 12 PQEncOff "hello 2") -> c == aId; _ -> False - -- ackMessage b' aId 12 $ Just "" - -- get a' =##> \case ("", c, Rcvd 12) -> c == bId; _ -> False - -- ackMessage a' bId 13 Nothing + (12, _) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello 2" + get a' ##> ("", bId, SENT 12) + get b' =##> \case ("", c, Msg' 12 PQEncOff "hello 2") -> c == aId; _ -> False + ackMessage b' aId 12 $ Just "" + get a' =##> \case ("", c, Rcvd 12) -> c == bId; _ -> False + ackMessage a' bId 13 Nothing disconnectAgentClient a' disconnectAgentClient b' From b4c90781bba8cca3a8f7bea9e0c2b6707ff923af Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sun, 10 Mar 2024 17:53:57 +0000 Subject: [PATCH 23/30] pqdr: update envelope sizes --- src/Simplex/Messaging/Agent/Protocol.hs | 8 ++++---- src/Simplex/Messaging/Crypto/Ratchet.hs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 1005b7195..2c06e0279 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -288,14 +288,14 @@ supportedSMPAgentVRange pq = -- signing key of the sender for the server e2eEncConnInfoLength :: VersionSMPA -> PQSupport -> Int e2eEncConnInfoLength v = \case - -- reduced by 3700 (roughly the increase of message ratchet header size + key and ciphertext in reply link) - PQSupportOn | v >= pqdrSMPAgentVersion -> 11148 + -- reduced by 3726 (roughly the increase of message ratchet header size + key and ciphertext in reply link) + PQSupportOn | v >= pqdrSMPAgentVersion -> 11122 _ -> 14848 e2eEncUserMsgLength :: VersionSMPA -> PQSupport -> Int e2eEncUserMsgLength v = \case - -- reduced by 2200 (roughly the increase of message ratchet header size) - PQSupportOn | v >= pqdrSMPAgentVersion -> 13656 + -- reduced by 2222 (the increase of message ratchet header size) + PQSupportOn | v >= pqdrSMPAgentVersion -> 13634 _ -> 15856 -- | Raw (unparsed) SMP agent protocol transmission. diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index a0a1d4d10..068f62776 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -667,10 +667,10 @@ data MsgHeader a = MsgHeader -- to allow extension without increasing the size, the actual header length is: -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 --- TODO this is the exact size, some reserve should be added +-- The exact size is 2288, added reserve paddedHeaderLen :: VersionE2E -> PQSupport -> Int paddedHeaderLen v = \case - PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 2288 + PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 2310 _ -> 88 -- only used in tests to validate correct padding From 78eb4f764fd52385a8687d2605a0e6edc1808431 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sun, 10 Mar 2024 19:41:06 +0000 Subject: [PATCH 24/30] v5.6.0-beta.1 --- package.yaml | 2 +- simplexmq.cabal | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/package.yaml b/package.yaml index 76de72ac6..8b20e8033 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 5.6.0.0 +version: 5.6.0.1 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, diff --git a/simplexmq.cabal b/simplexmq.cabal index 35f916cc8..eedbcc313 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 5.6.0.0 +version: 5.6.0.1 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and From 2cad0cb2016382f951442ca0d1f08f0adb24cbd1 Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Mon, 11 Mar 2024 21:06:53 +0200 Subject: [PATCH 25/30] core: check ACK handling with return type (#1041) * core: check ACK handling with return type * fix ratchet sync * add SQL Locked to dbBusyLoop * rename --------- Co-authored-by: Evgeny Poberezkin --- src/Simplex/Messaging/Agent.hs | 38 +++++++++++-------- .../Messaging/Agent/Store/SQLite/Common.hs | 3 +- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 56432b947..5b77fcac0 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -1962,6 +1962,8 @@ cleanupManager c@AgentClient {subQ} = do notify :: forall e. AEntityI e => EntityId -> ACommand 'Agent e -> ExceptT AgentErrorType m () notify entId cmd = atomically $ writeTBQueue subQ ("", entId, APC (sAEntity @e) cmd) +data ACKd = ACKd | ACKPending + -- | make sure to ACK or throw in each message processing branch -- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission SMPVersion BrokerMsg -> m () @@ -1976,13 +1978,14 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, cData@ConnData {userId, connId, connAgentVersion, ratchetSyncState = rss} = withConnLock c connId "processSMP" $ case cmd of SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> - handleNotifyAck $ do + void . handleNotifyAck $ do msg' <- decryptSMPMessage rq msg - handleNotifyAck $ case msg' of + ack' <- handleNotifyAck $ case msg' of SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody SMP.ClientRcvMsgQuota {} -> queueDrained >> ack whenM (atomically $ hasGetLock c rq) $ notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') + pure ack' where queueDrained = case conn of DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq) @@ -2005,8 +2008,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, decryptClientMessage e2eDh clientMsg >>= \case (SMP.PHEmpty, AgentRatchetKey {agentVersion, e2eEncryption}) -> do conn' <- updateConnVersion conn cData agentVersion - qDuplex conn' "AgentRatchetKey" $ newRatchetKey e2eEncryption - ack + qDuplex conn' "AgentRatchetKey" $ \a -> newRatchetKey e2eEncryption a >> ack (SMP.PHEmpty, AgentMsgEnvelope {agentVersion, encAgentMessage}) -> do conn' <- updateConnVersion conn cData agentVersion -- primary queue is set as Active in helloMsg, below is to set additional queues Active @@ -2033,6 +2035,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, A_MSG body -> do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId notify $ MSG msgMeta msgFlags body + pure ACKPending A_RCVD rcpts -> qDuplex conn'' "RCVD" $ messagesRcvd rcpts msgMeta QCONT addr -> qDuplexAckDel conn'' "QCONT" $ continueSending srvMsgId addr QADD qs -> qDuplexAckDel conn'' "QADD" $ qAddMsg srvMsgId qs @@ -2043,7 +2046,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, QTEST _ -> logServer "<--" c srv rId ("MSG :" <> logSecret srvMsgId) >> ackDel msgId EREADY _ -> qDuplexAckDel conn'' "EREADY" $ ereadyMsg rcPrev where - qDuplexAckDel :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m () + qDuplexAckDel :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m ACKd qDuplexAckDel conn'' name a = qDuplex conn'' name a >> ackDel msgId resetRatchetSync :: m (Connection c) resetRatchetSync @@ -2064,7 +2067,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, AgentMessage _ (A_MSG body) -> do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId notify $ MSG msgMeta msgFlags body - _ -> pure () + pure ACKPending + _ -> ack _ -> checkDuplicateHash e encryptedMsgHash >> ack Left (AGENT (A_CRYPTO e)) -> do exists <- withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash @@ -2117,11 +2121,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, pure $ updateConnection cData'' conn' | otherwise -> pure conn' Nothing -> pure conn' - ack :: m () - ack = enqueueCmd $ ICAck rId srvMsgId - ackDel :: InternalId -> m () - ackDel = enqueueCmd . ICAckDel rId srvMsgId - handleNotifyAck :: m () -> m () + ack :: m ACKd + ack = enqueueCmd (ICAck rId srvMsgId) $> ACKd + ackDel :: InternalId -> m ACKd + ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd + handleNotifyAck :: m ACKd -> m ACKd handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack SMP.END -> atomically (TM.lookup tSess smpClients $>>= (tryReadTMVar . sessionVar) >>= processEND) @@ -2249,14 +2253,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, >>= mapM_ (\(_, retryLock) -> tryPutTMVar retryLock ()) Nothing -> qError "QCONT: queue address not found" - messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> m () + messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> m ACKd messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing case L.nonEmpty . catMaybes $ L.toList rs of - Just rs' -> notify $ RCVD msgMeta rs' -- client must ACK once processed - Nothing -> enqueueCmd $ ICAck rId srvMsgId + Just rs' -> notify (RCVD msgMeta rs') $> ACKPending + Nothing -> ack where + ack :: m ACKd + ack = enqueueCmd (ICAck rId srvMsgId) $> ACKd clientReceipt :: AMessageReceipt -> m (Maybe MsgReceipt) clientReceipt AMessageReceipt {agentMsgId, msgHash} = do let sndMsgId = InternalSndId agentMsgId @@ -2347,7 +2353,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> qError "QUSE: switching SndQueue not found in connection" _ -> qError "QUSE: switched queue address not found in connection" - qError :: String -> m () + qError :: String -> m a qError = throwError . AGENT . A_QUEUE ereadyMsg :: CR.RatchetX448 -> Connection 'CDuplex -> m () @@ -2375,7 +2381,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, pqSupported (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible agentVersion) = PQSupportOn `CR.pqSupportAnd` versionPQSupport_ agentVersion (Just v) - qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m () + qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m a) -> m a qDuplex conn' name action = case conn' of DuplexConnection {} -> action conn' _ -> qError $ name <> ": message must be sent to duplex connection" diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 9ba6cd08f..18c16cc8b 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -71,7 +71,8 @@ dbBusyLoop action = loop 500 3000000 loop :: Int -> Int -> IO a loop t tLim = action `E.catch` \(e :: SQLError) -> - if tLim > t && SQL.sqlError e == SQL.ErrorBusy + let se = SQL.sqlError e in + if tLim > t && (se == SQL.ErrorBusy || se == SQL.ErrorLocked) then do threadDelay t loop (t * 9 `div` 8) (tLim - t) From 0aa4ae72286237d066c3ce2bff355638523c7095 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Tue, 12 Mar 2024 14:31:30 +0000 Subject: [PATCH 26/30] v5.6.0-beta.2 --- package.yaml | 2 +- simplexmq.cabal | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/package.yaml b/package.yaml index 8b20e8033..7bac7a4f4 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 5.6.0.1 +version: 5.6.0.2 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, diff --git a/simplexmq.cabal b/simplexmq.cabal index eedbcc313..9cb5fd3be 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 5.6.0.1 +version: 5.6.0.2 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and From e93ea6df714b938e0dcff7098bd943f5643e1eef Mon Sep 17 00:00:00 2001 From: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> Date: Wed, 13 Mar 2024 13:33:43 +0400 Subject: [PATCH 27/30] xftp: fix sending large files (#1043) --- src/Simplex/FileTransfer/Description.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Simplex/FileTransfer/Description.hs b/src/Simplex/FileTransfer/Description.hs index 58bcb9df3..d5b5e5105 100644 --- a/src/Simplex/FileTransfer/Description.hs +++ b/src/Simplex/FileTransfer/Description.hs @@ -227,7 +227,7 @@ validateFileDescription fd@FileDescription {size, chunks} | otherwise = Right $ ValidFD fd where chunkNos = map (\FileChunk {chunkNo} -> chunkNo) chunks - chunksSize = fromIntegral . foldl' (\s FileChunk {chunkSize} -> s + unFileSize chunkSize) 0 + chunksSize = foldl' (\(s :: Int64) FileChunk {chunkSize} -> s + fromIntegral (unFileSize chunkSize)) 0 encodeFileDescription :: FileDescription p -> YAMLFileDescription encodeFileDescription FileDescription {party, size, digest, key, nonce, chunkSize, chunks, redirect} = From 293a2ca3f10232fa8a5221388344acf68643ad92 Mon Sep 17 00:00:00 2001 From: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> Date: Wed, 13 Mar 2024 13:33:59 +0400 Subject: [PATCH 28/30] agent: remove withStoreCtx (#1044) --- src/Simplex/Messaging/Agent.hs | 16 +++++++------- src/Simplex/Messaging/Agent/Client.hs | 31 ++++----------------------- 2 files changed, 12 insertions(+), 35 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 5b77fcac0..2b2db1dd5 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -612,9 +612,9 @@ ackMessageAsync' c corrId connId msgId rcptInfo_ = do enqueueAck :: m () enqueueAck = do let mId = InternalId msgId - RcvMsg {msgType} <- withStoreCtx "ackMessageAsync': getRcvMsg" c $ \db -> getRcvMsg db connId mId + RcvMsg {msgType} <- withStore c $ \db -> getRcvMsg db connId mId when (isJust rcptInfo_ && msgType /= AM_A_MSG_) $ throwError $ CMD PROHIBITED - (RcvQueue {server}, _) <- withStoreCtx "ackMessageAsync': setMsgUserAck" c $ \db -> setMsgUserAck db connId mId + (RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId mId enqueueCommand c corrId connId (Just server) . AClientCommand $ APC SAEConn $ ACK msgId rcptInfo_ deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> Bool -> ConnId -> m () @@ -1367,13 +1367,13 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do ack :: m () ack = do -- the stored message was delivered via a specific queue, the rest failed to decrypt and were already acknowledged - (rq, srvMsgId) <- withStoreCtx "ackMessage': setMsgUserAck" c $ \db -> setMsgUserAck db connId $ InternalId msgId + (rq, srvMsgId) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId ackQueueMessage c rq srvMsgId del :: m () - del = withStoreCtx' "ackMessage': deleteMsg" c $ \db -> deleteMsg db connId $ InternalId msgId + del = withStore' c $ \db -> deleteMsg db connId $ InternalId msgId sendRcpt :: Connection 'CDuplex -> m () sendRcpt (DuplexConnection cData@ConnData {connAgentVersion} _ sqs) = do - msg@RcvMsg {msgType, msgReceipt} <- withStoreCtx "ackMessage': getRcvMsg" c $ \db -> getRcvMsg db connId $ InternalId msgId + msg@RcvMsg {msgType, msgReceipt} <- withStore c $ \db -> getRcvMsg db connId $ InternalId msgId case rcptInfo_ of Just rcptInfo -> do unless (msgType == AM_A_MSG_) $ throwError (CMD PROHIBITED) @@ -1384,7 +1384,7 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do Nothing -> case (msgType, msgReceipt) of -- only remove sent message if receipt hash was Ok, both to debug and for future redundancy (AM_A_RCVD_, Just MsgReceipt {agentMsgId = sndMsgId, msgRcptStatus = MROk}) -> - withStoreCtx' "ackMessage': deleteDeliveredSndMsg" c $ \db -> deleteDeliveredSndMsg db connId $ InternalId sndMsgId + withStore' c $ \db -> deleteDeliveredSndMsg db connId $ InternalId sndMsgId _ -> pure () switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats @@ -2059,7 +2059,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, | otherwise = pure conn' Right _ -> prohibited >> ack Left e@(AGENT A_DUPLICATE) -> do - withStoreCtx' "processSMP: getLastMsg" c (\db -> getLastMsg db connId srvMsgId) >>= \case + withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} | userAck -> ackDel internalId | otherwise -> do @@ -2266,7 +2266,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, clientReceipt :: AMessageReceipt -> m (Maybe MsgReceipt) clientReceipt AMessageReceipt {agentMsgId, msgHash} = do let sndMsgId = InternalSndId agentMsgId - SndMsg {internalId = InternalId msgId, msgType, internalHash, msgReceipt} <- withStoreCtx "messagesRcvd: getSndMsgViaRcpt" c $ \db -> getSndMsgViaRcpt db connId sndMsgId + SndMsg {internalId = InternalId msgId, msgType, internalHash, msgReceipt} <- withStore c $ \db -> getSndMsgViaRcpt db connId sndMsgId if msgType /= AM_A_MSG_ then notify (ERR $ AGENT A_PROHIBITED) $> Nothing -- unexpected message type for receipt else case msgReceipt of diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index f60ddea26..8b3b87122 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -110,8 +110,6 @@ module Simplex.Messaging.Agent.Client whenSuspending, withStore, withStore', - withStoreCtx, - withStoreCtx', withStoreBatch, withStoreBatch', storeError, @@ -1457,34 +1455,13 @@ waitUntilForeground :: AgentClient -> STM () waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a -withStore' = withStoreCtx_' Nothing +withStore' c action = withStore c $ fmap Right . action withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a -withStore = withStoreCtx_ Nothing - -withStoreCtx' :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO a) -> m a -withStoreCtx' = withStoreCtx_' . Just - -withStoreCtx :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a -withStoreCtx = withStoreCtx_ . Just - -withStoreCtx_' :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO a) -> m a -withStoreCtx_' ctx_ c action = withStoreCtx_ ctx_ c $ fmap Right . action - -withStoreCtx_ :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a -withStoreCtx_ ctx_ c action = do +withStore c action = do st <- asks store - liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ case ctx_ of - Nothing -> withTransaction st action `E.catch` handleInternal "" - -- uncomment to debug store performance - -- Just ctx -> do - -- t1 <- liftIO getCurrentTime - -- putStrLn $ "agent withStoreCtx start :: " <> show t1 <> " :: " <> ctx - -- r <- withTransaction st action `E.catch` handleInternal (" (" <> ctx <> ")") - -- t2 <- liftIO getCurrentTime - -- putStrLn $ "agent withStoreCtx end :: " <> show t2 <> " :: " <> ctx <> " :: duration=" <> show (diffToMilliseconds $ diffUTCTime t2 t1) - -- pure r - Just _ -> withTransaction st action `E.catch` handleInternal "" + liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ + withTransaction st action `E.catch` handleInternal "" where handleInternal :: String -> E.SomeException -> IO (Either StoreError a) handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr From ace09cc07dbede7e1f73c260241795077957cc72 Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Fri, 15 Mar 2024 10:08:52 +0200 Subject: [PATCH 29/30] xftp: force single chunk for redirect descriptions (#1050) * xftp: force single chunk for redirect descriptions * Update src/Simplex/FileTransfer/Client/Main.hs Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> --------- Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> --- src/Simplex/FileTransfer/Agent.hs | 12 ++++++++---- src/Simplex/FileTransfer/Client/Main.hs | 9 ++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 5666b63ff..f2138d840 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -389,16 +389,20 @@ runXFTPSndPrepareWorker c Worker {doWork} = do where AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)]) - encryptFileForUpload SndFile {key, nonce, srcFile} fsEncPath = do + encryptFileForUpload SndFile {key, nonce, srcFile, redirect} fsEncPath = do let CryptoFile {filePath} = srcFile fileName = takeFileName filePath fileSize <- liftIO $ fromInteger <$> CF.getFileContentsSize srcFile when (fileSize > maxFileSizeHard) $ throwError $ INTERNAL "max file size exceeded" let fileHdr = smpEncode FileHeader {fileName, fileExtra = Nothing} fileSize' = fromIntegral (B.length fileHdr) + fileSize - chunkSizes = prepareChunkSizes $ fileSize' + fileSizeLen + authTagSize - chunkSizes' = map fromIntegral chunkSizes - encSize = sum chunkSizes' + payloadSize = fileSize' + fileSizeLen + authTagSize + chunkSizes <- case redirect of + Nothing -> pure $ prepareChunkSizes payloadSize + Just _ -> case singleChunkSize payloadSize of + Nothing -> throwError $ INTERNAL "max file size exceeded for redirect" + Just chunkSize -> pure [chunkSize] + let encSize = sum $ map fromIntegral chunkSizes void $ liftError (INTERNAL . show) $ encryptFile srcFile fileHdr key nonce fileSize' encSize fsEncPath digest <- liftIO $ LC.sha512Hash <$> LB.readFile fsEncPath let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index d0c867b27..90348099a 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -16,6 +16,7 @@ module Simplex.FileTransfer.Client.Main xftpClientCLI, cliSendFile, cliSendFileOpts, + singleChunkSize, prepareChunkSizes, prepareChunkSpecs, maxFileSize, @@ -42,7 +43,7 @@ import Data.List.NonEmpty (NonEmpty (..), nonEmpty) import qualified Data.List.NonEmpty as L import Data.Map (Map) import qualified Data.Map as M -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, listToMaybe) import qualified Data.Text as T import Data.Word (Word32) import GHC.Records (HasField (getField)) @@ -528,6 +529,12 @@ getFileDescription' path = getFileDescription path >>= \case AVFD fd -> either (throwError . CLIError) pure $ checkParty fd +singleChunkSize :: Int64 -> Maybe Word32 +singleChunkSize size' = + listToMaybe $ dropWhile (< chunkSize) serverChunkSizes + where + chunkSize = fromIntegral size' + prepareChunkSizes :: Int64 -> [Word32] prepareChunkSizes size' = prepareSizes size' where From ca68eca86ef92ae266a4005ab1ad57b589f83933 Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Fri, 15 Mar 2024 10:18:15 +0200 Subject: [PATCH 30/30] agent: fix leak in getChunkDigest (#1051) --- src/Simplex/FileTransfer/Agent.hs | 5 +++-- src/Simplex/FileTransfer/Client/Main.hs | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index f2138d840..2abf8e3dc 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -35,6 +35,7 @@ import Control.Monad.Reader import Data.Bifunctor (first) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB +import Data.Coerce (coerce) import Data.Composition ((.:)) import Data.Either (rights) import Data.Int (Int64) @@ -406,8 +407,8 @@ runXFTPSndPrepareWorker c Worker {doWork} = do void $ liftError (INTERNAL . show) $ encryptFile srcFile fileHdr key nonce fileSize' encSize fsEncPath digest <- liftIO $ LC.sha512Hash <$> LB.readFile fsEncPath let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes - chunkDigests <- map FileDigest <$> mapM (liftIO . getChunkDigest) chunkSpecs - pure (FileDigest digest, zip chunkSpecs chunkDigests) + chunkDigests <- liftIO $ mapM getChunkDigest chunkSpecs + pure (FileDigest digest, zip chunkSpecs $ coerce chunkDigests) chunkCreated :: SndFileChunk -> Bool chunkCreated SndFileChunk {replicas} = any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index 90348099a..bca41cea8 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -414,7 +414,8 @@ getChunkDigest :: XFTPChunkSpec -> IO ByteString getChunkDigest XFTPChunkSpec {filePath = chunkPath, chunkOffset, chunkSize} = withFile chunkPath ReadMode $ \h -> do hSeek h AbsoluteSeek $ fromIntegral chunkOffset - LC.sha256Hash <$> LB.hGet h (fromIntegral chunkSize) + chunk <- LB.hGet h (fromIntegral chunkSize) + pure $! LC.sha256Hash chunk cliReceiveFile :: ReceiveOptions -> ExceptT CLIError IO () cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath, verbose, yes} =