From a8caab810a6b8acd6ce23dba5d022c9a5d21cbf3 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sat, 27 Aug 2022 17:35:54 +0100 Subject: [PATCH] store methods --- src/Simplex/Messaging/Agent.hs | 18 +-- src/Simplex/Messaging/Agent/Store/SQLite.hs | 127 +++++++++++++------- 2 files changed, 95 insertions(+), 50 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f21399e24..3ed17ab39 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -452,24 +452,24 @@ doRcvQueueAction c cData rq@RcvQueue {rcvQueueAction} sq = where withNextRcvQueue :: AgentMonad m => (AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m ()) -> m () withNextRcvQueue action = do - withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case + withStore' c (`getNextRcvQueue` rq) >>= \case Just rq' -> action c cData rq sq rq' _ -> do -- notify agent internal error pure () createNextRcvQueue :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> m () -createNextRcvQueue c cData rq@RcvQueue {server, sndId} sq = do +createNextRcvQueue c cData@ConnData {connId} rq@RcvQueue {server, sndId} sq = do clientVRange <- asks $ smpClientVRange . config nextQueueUri <- - withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case + withStore' c (`getNextRcvQueue` rq) >>= \case Just RcvQueue {server = smpServer, sndId = senderId, e2ePrivKey} -> do let queueAddress = SMPQueueAddress {smpServer, senderId, dhPublicKey = C.publicKey e2ePrivKey} pure SMPQueueUri {clientVRange, queueAddress} _ -> do srv <- getSMPServer c server (rq', qUri) <- newRcvQueue c srv clientVRange False - withStore' c $ \db -> dbCreateNextRcvQueue db rq rq' + withStore' c $ \db -> dbCreateNextRcvQueue db connId rq rq' pure qUri void $ enqueueMessage c cData sq SMP.noMsgFlags QNEW {currentAddress = (server, sndId), nextQueueUri} withStore' c $ \db -> setRcvQueueAction db rq Nothing @@ -1341,13 +1341,13 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se case (nextQUri `compatibleVersion` clientVRange) of Just qInfo@(Compatible nextQInfo) -> do sq'@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue qInfo False - withStore' c $ \db -> dbCreateNextSndQueue db sq sq' + withStore' c $ \db -> dbCreateNextSndQueue db connId sq sq' case (sndPublicKey, e2ePubKey) of (Just nextSenderKey, Just dhPublicKey) -> do let qAddr = (queueAddress (nextQInfo :: SMPQueueInfo)) {dhPublicKey} nextQueueInfo = (nextQInfo :: SMPQueueInfo) {queueAddress = qAddr} void . enqueueMessage c cData sq SMP.noMsgFlags $ QKEYS {nextSenderKey, nextQueueInfo} - rq' <- withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) + rq' <- withStore' c (`getNextRcvQueue` rq) notify . SWITCH SPStarted $ connectionStats conn rq' (Just sq') _ -> throwError $ INTERNAL "absent sender keys" _ -> throwError $ AGENT A_VERSION @@ -1361,7 +1361,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se DuplexConnection _ _ sq -> do clientVRange <- asks $ smpClientVRange . config unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION - withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case + withStore' c (`getNextRcvQueue` rq) >>= \case Just rq'@RcvQueue {server, sndId, e2ePrivKey = dhPrivKey, smpClientVersion = clntVer} -> do unless (smpServer == server && senderId == sndId) . throwError $ INTERNAL "incorrect queue address" let dhSecret = C.dh' dhPublicKey dhPrivKey @@ -1377,7 +1377,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se rqReady :: (SMPServer, SMP.SenderId) -> m () rqReady (smpServer, senderId) = case conn of DuplexConnection _ _ sq -> - withStore' c (`getNextSndQueue` dbNextSndQueueId sq) >>= \case + withStore' c (`getNextSndQueue` sq) >>= \case Just sq'@SndQueue {server, sndId} -> do unless (smpServer == server && senderId == sndId) . throwError $ INTERNAL "incorrect queue address" void $ enqueueMessage c cData sq' SMP.noMsgFlags QTEST @@ -1398,7 +1398,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se rqSwitch :: (SMPServer, SMP.SenderId) -> m () rqSwitch (smpServer, senderId) = case conn of DuplexConnection _ _ sq@SndQueue {server, sndId} -> do - withStore' c (`getNextSndQueue` dbNextSndQueueId sq) >>= \case + withStore' c (`getNextSndQueue` sq) >>= \case Just sq'@SndQueue {server = server', sndId = sndId'} -> do unless (smpServer == server' && senderId == sndId') . throwError $ INTERNAL "incorrect queue address" let qKey = (connId, server, sndId) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 2da08521f..12a2cdef2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -272,7 +272,7 @@ createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh createConn_ gVar cData $ \connId -> do upsertServer_ db server DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake) - insertRcvQueue_ db connId q + void $ insertRcvQueue_ db connId q createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId) createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} = do @@ -280,7 +280,7 @@ createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh upsertServer_ db server DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake) -- TODO add queue ID in insertSndQueue_ - insertSndQueue_ db connId q + void $ insertSndQueue_ db connId q getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do @@ -316,7 +316,7 @@ upgradeRcvConnToDuplex db connId sq@SndQueue {server} = (SomeConn _ RcvConnection {}) -> do upsertServer_ db server -- TODO save with queue ID - insertSndQueue_ db connId sq + void $ insertSndQueue_ db connId sq pure $ Right () (SomeConn c _) -> pure . Left . SEBadConnType $ connType c @@ -325,7 +325,7 @@ upgradeSndConnToDuplex db connId rq@RcvQueue {server} = getConn db connId $>>= \case SomeConn _ SndConnection {} -> do upsertServer_ db server - insertRcvQueue_ db connId rq + void $ insertRcvQueue_ db connId rq pure $ Right () SomeConn c _ -> pure . Left . SEBadConnType $ connType c @@ -393,8 +393,8 @@ setRcvQueueNtfCreds db connId clientNtfCreds = Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) Nothing -> (Nothing, Nothing, Nothing, Nothing) -getNextRcvQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe RcvQueue) -getNextRcvQueue db = \case +getNextRcvQueue :: DB.Connection -> RcvQueue -> IO (Maybe RcvQueue) +getNextRcvQueue db RcvQueue {dbNextRcvQueueId} = case dbNextRcvQueueId of Just rqId -> maybeFirstRow toRcvQueue $ DB.query @@ -412,32 +412,69 @@ getNextRcvQueue db = \case (rqId, False) _ -> pure Nothing -getNextSndQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe SndQueue) -getNextSndQueue _db _sqId_ = pure Nothing +getNextSndQueue :: DB.Connection -> SndQueue -> IO (Maybe SndQueue) +getNextSndQueue db SndQueue {dbNextSndQueueId} = case dbNextSndQueueId of + Just sqId -> + maybeFirstRow toSndQueue $ + DB.query + db + [sql| + SELECT q.host, q.port, s.key_hash, + q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, + q.snd_queue_action, q.snd_queue_action_ts, q.curr_snd_queue, q.next_snd_queue_id, + q.smp_client_version, q.created_at, q.updated_at + FROM snd_queues q + INNER JOIN servers s ON q.host = s.host AND q.port = s.port + WHERE q.snd_queue_id = ? AND q.curr_rcv_queue = ? + |] + (sqId, False) + _ -> pure Nothing -dbCreateNextRcvQueue :: DB.Connection -> RcvQueue -> RcvQueue -> IO () -dbCreateNextRcvQueue _db _rq _nextRq = pure () +dbCreateNextRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> RcvQueue -> IO () +dbCreateNextRcvQueue db connId RcvQueue {server = (SMPServer host port _), rcvId} rq' = do + rqId <- insertRcvQueue_ db connId rq' + DB.execute + db + [sql| + UPDATE rcv_queues + SET next_rcv_queue_id = ? + WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ? + |] + (rqId, host, port, rcvId, True) -dbCreateNextSndQueue :: DB.Connection -> SndQueue -> SndQueue -> IO () -dbCreateNextSndQueue _db _sq _nextSq = do - -- create next queue record - -- update current queue with the next queue ID - pure () +dbCreateNextSndQueue :: DB.Connection -> ConnId -> SndQueue -> SndQueue -> IO () +dbCreateNextSndQueue db connId SndQueue {server = (SMPServer host port _), sndId} sq' = do + sqId <- insertSndQueue_ db connId sq' + DB.execute + db + [sql| + UPDATE snd_queues + SET next_snd_queue_id = ? + WHERE host = ? AND port = ? AND snd_id = ? AND curr_snd_queue = ? + |] + (sqId, host, port, sndId, True) setRcvQueueAction :: DB.Connection -> RcvQueue -> Maybe RcvQueueAction -> IO () -setRcvQueueAction _db _rq _rqAction_ = pure () +setRcvQueueAction db RcvQueue {server = (SMPServer host port _), rcvId} rqAction_ = do + ts <- getCurrentTime + DB.execute + db + [sql| + UPDATE rcv_queues + SET rcv_queue_action = ?, rcv_queue_action_ts = ? + WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ? + |] + (rqAction_, ts, host, port, rcvId, True) switchCurrRcvQueue :: DB.Connection -> RcvQueue -> RcvQueue -> IO () -switchCurrRcvQueue _db _rq _nextRq = do - -- make a new queue a main one - -- delete old queue from the database - pure () +switchCurrRcvQueue db RcvQueue {server = (SMPServer host port _), rcvId} RcvQueue {dbNextRcvQueueId} = do + DB.execute db "DELETE FROM rcv_queues WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ?" (host, port, rcvId, True) + DB.execute db "UPDATE rcv_queues SET curr_rcv_queue = ? WHERE rcv_queue_id = ? AND curr_rcv_queue = ?" (True, dbNextRcvQueueId, False) switchCurrSndQueue :: DB.Connection -> SndQueue -> SndQueue -> IO () -switchCurrSndQueue _db _sq _nextSq = do - -- make new queue active - -- delete old queue from the database - pure () +switchCurrSndQueue db SndQueue {server = (SMPServer host port _), sndId} SndQueue {dbNextSndQueueId} = do + DB.execute db "DELETE FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND curr_snd_queue = ?" (host, port, sndId, True) + DB.execute db "UPDATE snd_queues SET curr_snd_queue = ? WHERE snd_queue_id = ? AND curr_snd_queue = ?" (True, dbNextSndQueueId, False) type SMPConfirmationRow = (SndPublicVerifyKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version) @@ -1161,29 +1198,31 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do -- * createRcvConn helpers -insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO () +insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64 insertRcvQueue_ db connId RcvQueue {..} = do qId <- newQueueId_ <$> DB.query_ db "SELECT rcv_queue_id FROM rcv_queues ORDER BY rcv_queue_id DESC LIMIT 1" DB.execute db [sql| INSERT INTO rcv_queues - (rcv_queue_id, host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?); + (rcv_queue_id, host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, curr_rcv_queue, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] - ((qId, host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (smpClientVersion, createdAt, updatedAt)) + ((qId, host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (currRcvQueue, smpClientVersion, createdAt, updatedAt)) + pure qId -- * createSndConn helpers -insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO () +insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64 insertSndQueue_ db connId SndQueue {..} = do qId <- newQueueId_ <$> DB.query_ db "SELECT snd_queue_id FROM snd_queues ORDER BY snd_queue_id DESC LIMIT 1" DB.execute db [sql| INSERT INTO snd_queues - (snd_queue_id, host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?); + (snd_queue_id, host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, curr_snd_queue, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] - ((qId, host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (smpClientVersion, createdAt, updatedAt)) + ((qId, host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (currSndQueue, smpClientVersion, createdAt, updatedAt)) + pure qId newQueueId_ :: [Only (Maybe Int64)] -> Int64 newQueueId_ [] = 1 @@ -1240,8 +1279,8 @@ toNtfCreds _ = Nothing getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue) getRcvQueueByConnId_ dbConn connId = - listToMaybe . map toRcvQueue - <$> DB.query + maybeFirstRow toRcvQueue $ + DB.query dbConn [sql| SELECT q.host, q.port, s.key_hash, @@ -1256,10 +1295,10 @@ getRcvQueueByConnId_ dbConn connId = (connId, True) getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue) -getSndQueueByConnId_ dbConn connId = - listToMaybe . map sndQueue - <$> DB.query - dbConn +getSndQueueByConnId_ db connId = + maybeFirstRow toSndQueue $ + DB.query + db [sql| SELECT q.host, q.port, s.key_hash, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, @@ -1270,11 +1309,17 @@ getSndQueueByConnId_ dbConn connId = WHERE q.conn_id = ? AND q.curr_snd_queue = ? |] (connId, True) - where - sndQueue (srvRow :. (sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sqAction_, sqActionTs_, currSndQueue, dbNextSndQueueId) :. (smpClientVersion, createdAt, updatedAt)) = - let server = toSMPServer srvRow - sndQueueAction = (,) <$> sqAction_ <*> sqActionTs_ - in SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sndQueueAction, currSndQueue, dbNextSndQueueId, smpClientVersion, createdAt, updatedAt} + +type SndQueueRow = + ServerRow + :. (SMP.SenderId, Maybe C.APublicVerifyKey, SMP.SndPrivateSignKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus, Maybe SndQueueAction, Maybe UTCTime, Bool, Maybe Int64) + :. (Version, UTCTime, UTCTime) + +toSndQueue :: SndQueueRow -> SndQueue +toSndQueue (srvRow :. (sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sqAction_, sqActionTs_, currSndQueue, dbNextSndQueueId) :. (smpClientVersion, createdAt, updatedAt)) = + let server = toSMPServer srvRow + sndQueueAction = (,) <$> sqAction_ <*> sqActionTs_ + in SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sndQueueAction, currSndQueue, dbNextSndQueueId, smpClientVersion, createdAt, updatedAt} -- * updateRcvIds helpers