diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 56a1887e2..e74d7b57d 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -305,7 +305,7 @@ newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> SConnectionMode c -> newConn c connId enableNtfs cMode = do srv <- getSMPServer c clientVRange <- asks $ smpClientVRange . config - (rq, qUri) <- newRcvQueue c srv clientVRange + (rq, qUri) <- newRcvQueue c srv clientVRange False g <- asks idsDrg connAgentVersion <- asks $ maxVersion . smpAgentVRange . config let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent @@ -368,7 +368,7 @@ joinConn c connId enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri : createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} = do srv <- getSMPServer c - (rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion + (rq, qUri) <- newRcvQueue c srv (versionToRange smpClientVersion) False let qInfo = toVersionT qUri smpClientVersion addSubscription c rq connId withStore c $ \db -> upgradeSndConnToDuplex db connId rq @@ -459,7 +459,7 @@ createNextRcvQueue c cData rq@RcvQueue {server, sndId} sq = do pure SMPQueueUri {clientVRange, queueAddress} _ -> do srv <- getSMPServer c - (rq', qUri) <- newRcvQueue c srv clientVRange + (rq', qUri) <- newRcvQueue c srv clientVRange True withStore' c $ \db -> dbCreateNextRcvQueue db rq rq' pure qUri void $ enqueueMessage c cData sq SMP.noMsgFlags QNEW {currentAddress = (server, sndId), nextQueueUri} diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 7a36097df..9043bcb40 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -471,10 +471,10 @@ protocolClientError protocolError_ = \case e@PCESignatureError {} -> INTERNAL $ show e e@PCEIOError {} -> INTERNAL $ show e -newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri) -newRcvQueue c srv vRange = +newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> VersionRange -> Bool -> m (RcvQueue, SMPQueueUri) +newRcvQueue c srv vRange next = asks (cmdSignAlg . config) >>= \case - C.SignAlg a -> newRcvQueue_ a c srv vRange + C.SignAlg a -> newRcvQueue_ a c srv vRange next newRcvQueue_ :: (C.SignatureAlgorithm a, C.AlgorithmI a, AgentMonad m) => @@ -482,8 +482,9 @@ newRcvQueue_ :: AgentClient -> SMPServer -> VersionRange -> + Bool -> m (RcvQueue, SMPQueueUri) -newRcvQueue_ a c srv vRange = do +newRcvQueue_ a c srv vRange next = do (recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a (dhKey, privDhKey) <- liftIO C.generateKeyPair' (e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair' @@ -504,6 +505,7 @@ newRcvQueue_ a c srv vRange = do sndPublicKey = Nothing, status = New, rcvQueueAction = Nothing, + nextRcvQueue = next, dbNextRcvQueueId = Nothing, clientNtfCreds = Nothing, smpClientVersion = maxVersion vRange, diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index aea0a7122..ad78e5cb3 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -61,6 +61,8 @@ data RcvQueue = RcvQueue status :: QueueStatus, -- | action to perform, to be done on connection subscription, if it fails and not reset rcvQueueAction :: Maybe (RcvQueueAction, UTCTime), + -- | True if this is the queue the connection is switching to, rather than the current queue + nextRcvQueue :: Bool, -- | database ID of the new queue created for this queue to switch to (queue rotation) dbNextRcvQueueId :: Maybe Int64, -- | credentials used in context of notifications diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 6eca35eda..fb42d77bd 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -15,6 +15,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -387,13 +388,13 @@ setRcvQueueNtfCreds db connId clientNtfCreds = getNextRcvQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe RcvQueue) getNextRcvQueue db = \case Just rqId -> - listToMaybe . map rcvQueue - <$> DB.query + maybeFirstRow toRcvQueue $ + DB.query db [sql| SELECT q.host, q.port, s.key_hash, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.snd_key, q.status, - q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue_id, + q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue, q.next_rcv_queue_id, q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret, q.smp_client_version, q.created_at, q.updated_at FROM rcv_queues q @@ -401,15 +402,6 @@ getNextRcvQueue db = \case WHERE q.rcv_queue_id = ? AND q.next_rcv_queue = ? |] (rqId, True) - where - rcvQueue (srvRow :. (rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status) :. (rqAction_, rqActionTs_, dbNextRcvQueueId) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) :. (smpClientVersion_, createdAt, updatedAt)) = - let server = toSMPServer srvRow - smpClientVersion = fromMaybe 1 smpClientVersion_ - rcvQueueAction = (,) <$> rqAction_ <*> rqActionTs_ - clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of - (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} - _ -> Nothing - in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status, rcvQueueAction, dbNextRcvQueueId, smpClientVersion, clientNtfCreds, createdAt, updatedAt} _ -> pure Nothing getNextSndQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe SndQueue) @@ -1182,38 +1174,58 @@ getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getConn dbConn connId = getConnData dbConn connId >>= \case Nothing -> pure $ Left SEConnNotFound - Just (connData, cMode) -> do + Just (cData, cMode) -> do rQ <- getRcvQueueByConnId_ dbConn connId sQ <- getSndQueueByConnId_ dbConn connId pure $ case (rQ, sQ, cMode) of - (Just rcvQ, Just sndQ, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ) - (Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ) - (Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection connData sndQ) - (Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection connData rcvQ) + (Just rcvQ, Just sndQ, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rcvQ sndQ) + (Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rcvQ) + (Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sndQ) + (Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rcvQ) _ -> Left SEConnNotFound getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) getConnData dbConn connId' = - connData - <$> DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId') + maybeFirstRow toConnData $ + DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId') where - connData [(connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake)] = Just (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode) - connData _ = Nothing + toConnData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode) + +type RcvQueueRow = + ServerRow + :. (SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe C.APublicVerifyKey, QueueStatus) + :. (Maybe RcvQueueAction, Maybe UTCTime, Bool, Maybe Int64) + :. NtfCredsRow + :. (Maybe Version, UTCTime, UTCTime) type ServerRow = (NonEmpty TransportHost, String, C.KeyHash) +type NtfCredsRow = (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) + +toRcvQueue :: RcvQueueRow -> RcvQueue +toRcvQueue (srvRow :. (rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status) :. (rqAction_, rqActionTs_, nextRcvQueue, dbNextRcvQueueId) :. ntfCredsRow :. (smpClientVersion_, createdAt, updatedAt)) = + let server = toSMPServer srvRow + smpClientVersion = fromMaybe 1 smpClientVersion_ + rcvQueueAction = (,) <$> rqAction_ <*> rqActionTs_ + clientNtfCreds = toNtfCreds ntfCredsRow + in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status, rcvQueueAction, nextRcvQueue, dbNextRcvQueueId, smpClientVersion, clientNtfCreds, createdAt, updatedAt} + toSMPServer :: ServerRow -> SMPServer toSMPServer (host, port, keyHash) = SMPServer host port keyHash +toNtfCreds :: NtfCredsRow -> Maybe ClientNtfCreds +toNtfCreds (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) = Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} +toNtfCreds _ = Nothing + getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue) getRcvQueueByConnId_ dbConn connId = - listToMaybe . map rcvQueue + listToMaybe . map toRcvQueue <$> DB.query dbConn [sql| SELECT q.host, q.port, s.key_hash, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.snd_key, q.status, - q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue_id, + q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue, q.next_rcv_queue_id, q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret, q.smp_client_version, q.created_at, q.updated_at FROM rcv_queues q @@ -1221,15 +1233,6 @@ getRcvQueueByConnId_ dbConn connId = WHERE q.conn_id = ? AND q.next_rcv_queue = ? |] (connId, False) - where - rcvQueue (srvRow :. (rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status) :. (rqAction_, rqActionTs_, dbNextRcvQueueId) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) :. (smpClientVersion_, createdAt, updatedAt)) = - let server = toSMPServer srvRow - smpClientVersion = fromMaybe 1 smpClientVersion_ - rcvQueueAction = (,) <$> rqAction_ <*> rqActionTs_ - clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of - (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} - _ -> Nothing - in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status, rcvQueueAction, dbNextRcvQueueId, smpClientVersion, clientNtfCreds, createdAt, updatedAt} getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue) getSndQueueByConnId_ dbConn connId = diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index b2d7f126c..cd6b97255 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -166,6 +166,7 @@ rcvQueue1 = sndId = "2345", sndPublicKey = Nothing, status = New, + nextRcvQueue = False, dbNextRcvQueueId = Nothing, rcvQueueAction = Nothing, clientNtfCreds = Nothing, @@ -349,6 +350,7 @@ testUpgradeSndConnToDuplex = sndId = "4567", sndPublicKey = Nothing, status = New, + nextRcvQueue = False, dbNextRcvQueueId = Nothing, rcvQueueAction = Nothing, clientNtfCreds = Nothing,