store methods

This commit is contained in:
Evgeny Poberezkin
2022-08-27 17:35:54 +01:00
parent 4ab17dc449
commit a8caab810a
2 changed files with 95 additions and 50 deletions
+9 -9
View File
@@ -452,24 +452,24 @@ doRcvQueueAction c cData rq@RcvQueue {rcvQueueAction} sq =
where
withNextRcvQueue :: AgentMonad m => (AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m ()) -> m ()
withNextRcvQueue action = do
withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case
withStore' c (`getNextRcvQueue` rq) >>= \case
Just rq' -> action c cData rq sq rq'
_ -> do
-- notify agent internal error
pure ()
createNextRcvQueue :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> m ()
createNextRcvQueue c cData rq@RcvQueue {server, sndId} sq = do
createNextRcvQueue c cData@ConnData {connId} rq@RcvQueue {server, sndId} sq = do
clientVRange <- asks $ smpClientVRange . config
nextQueueUri <-
withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case
withStore' c (`getNextRcvQueue` rq) >>= \case
Just RcvQueue {server = smpServer, sndId = senderId, e2ePrivKey} -> do
let queueAddress = SMPQueueAddress {smpServer, senderId, dhPublicKey = C.publicKey e2ePrivKey}
pure SMPQueueUri {clientVRange, queueAddress}
_ -> do
srv <- getSMPServer c server
(rq', qUri) <- newRcvQueue c srv clientVRange False
withStore' c $ \db -> dbCreateNextRcvQueue db rq rq'
withStore' c $ \db -> dbCreateNextRcvQueue db connId rq rq'
pure qUri
void $ enqueueMessage c cData sq SMP.noMsgFlags QNEW {currentAddress = (server, sndId), nextQueueUri}
withStore' c $ \db -> setRcvQueueAction db rq Nothing
@@ -1341,13 +1341,13 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
case (nextQUri `compatibleVersion` clientVRange) of
Just qInfo@(Compatible nextQInfo) -> do
sq'@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue qInfo False
withStore' c $ \db -> dbCreateNextSndQueue db sq sq'
withStore' c $ \db -> dbCreateNextSndQueue db connId sq sq'
case (sndPublicKey, e2ePubKey) of
(Just nextSenderKey, Just dhPublicKey) -> do
let qAddr = (queueAddress (nextQInfo :: SMPQueueInfo)) {dhPublicKey}
nextQueueInfo = (nextQInfo :: SMPQueueInfo) {queueAddress = qAddr}
void . enqueueMessage c cData sq SMP.noMsgFlags $ QKEYS {nextSenderKey, nextQueueInfo}
rq' <- withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq)
rq' <- withStore' c (`getNextRcvQueue` rq)
notify . SWITCH SPStarted $ connectionStats conn rq' (Just sq')
_ -> throwError $ INTERNAL "absent sender keys"
_ -> throwError $ AGENT A_VERSION
@@ -1361,7 +1361,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
DuplexConnection _ _ sq -> do
clientVRange <- asks $ smpClientVRange . config
unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case
withStore' c (`getNextRcvQueue` rq) >>= \case
Just rq'@RcvQueue {server, sndId, e2ePrivKey = dhPrivKey, smpClientVersion = clntVer} -> do
unless (smpServer == server && senderId == sndId) . throwError $ INTERNAL "incorrect queue address"
let dhSecret = C.dh' dhPublicKey dhPrivKey
@@ -1377,7 +1377,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
rqReady :: (SMPServer, SMP.SenderId) -> m ()
rqReady (smpServer, senderId) = case conn of
DuplexConnection _ _ sq ->
withStore' c (`getNextSndQueue` dbNextSndQueueId sq) >>= \case
withStore' c (`getNextSndQueue` sq) >>= \case
Just sq'@SndQueue {server, sndId} -> do
unless (smpServer == server && senderId == sndId) . throwError $ INTERNAL "incorrect queue address"
void $ enqueueMessage c cData sq' SMP.noMsgFlags QTEST
@@ -1398,7 +1398,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
rqSwitch :: (SMPServer, SMP.SenderId) -> m ()
rqSwitch (smpServer, senderId) = case conn of
DuplexConnection _ _ sq@SndQueue {server, sndId} -> do
withStore' c (`getNextSndQueue` dbNextSndQueueId sq) >>= \case
withStore' c (`getNextSndQueue` sq) >>= \case
Just sq'@SndQueue {server = server', sndId = sndId'} -> do
unless (smpServer == server' && senderId == sndId') . throwError $ INTERNAL "incorrect queue address"
let qKey = (connId, server, sndId)
+86 -41
View File
@@ -272,7 +272,7 @@ createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
insertRcvQueue_ db connId q
void $ insertRcvQueue_ db connId q
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} = do
@@ -280,7 +280,7 @@ createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
-- TODO add queue ID in insertSndQueue_
insertSndQueue_ db connId q
void $ insertSndQueue_ db connId q
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
@@ -316,7 +316,7 @@ upgradeRcvConnToDuplex db connId sq@SndQueue {server} =
(SomeConn _ RcvConnection {}) -> do
upsertServer_ db server
-- TODO save with queue ID
insertSndQueue_ db connId sq
void $ insertSndQueue_ db connId sq
pure $ Right ()
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
@@ -325,7 +325,7 @@ upgradeSndConnToDuplex db connId rq@RcvQueue {server} =
getConn db connId $>>= \case
SomeConn _ SndConnection {} -> do
upsertServer_ db server
insertRcvQueue_ db connId rq
void $ insertRcvQueue_ db connId rq
pure $ Right ()
SomeConn c _ -> pure . Left . SEBadConnType $ connType c
@@ -393,8 +393,8 @@ setRcvQueueNtfCreds db connId clientNtfCreds =
Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret)
Nothing -> (Nothing, Nothing, Nothing, Nothing)
getNextRcvQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe RcvQueue)
getNextRcvQueue db = \case
getNextRcvQueue :: DB.Connection -> RcvQueue -> IO (Maybe RcvQueue)
getNextRcvQueue db RcvQueue {dbNextRcvQueueId} = case dbNextRcvQueueId of
Just rqId ->
maybeFirstRow toRcvQueue $
DB.query
@@ -412,32 +412,69 @@ getNextRcvQueue db = \case
(rqId, False)
_ -> pure Nothing
getNextSndQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe SndQueue)
getNextSndQueue _db _sqId_ = pure Nothing
getNextSndQueue :: DB.Connection -> SndQueue -> IO (Maybe SndQueue)
getNextSndQueue db SndQueue {dbNextSndQueueId} = case dbNextSndQueueId of
Just sqId ->
maybeFirstRow toSndQueue $
DB.query
db
[sql|
SELECT q.host, q.port, s.key_hash,
q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status,
q.snd_queue_action, q.snd_queue_action_ts, q.curr_snd_queue, q.next_snd_queue_id,
q.smp_client_version, q.created_at, q.updated_at
FROM snd_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.snd_queue_id = ? AND q.curr_rcv_queue = ?
|]
(sqId, False)
_ -> pure Nothing
dbCreateNextRcvQueue :: DB.Connection -> RcvQueue -> RcvQueue -> IO ()
dbCreateNextRcvQueue _db _rq _nextRq = pure ()
dbCreateNextRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> RcvQueue -> IO ()
dbCreateNextRcvQueue db connId RcvQueue {server = (SMPServer host port _), rcvId} rq' = do
rqId <- insertRcvQueue_ db connId rq'
DB.execute
db
[sql|
UPDATE rcv_queues
SET next_rcv_queue_id = ?
WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ?
|]
(rqId, host, port, rcvId, True)
dbCreateNextSndQueue :: DB.Connection -> SndQueue -> SndQueue -> IO ()
dbCreateNextSndQueue _db _sq _nextSq = do
-- create next queue record
-- update current queue with the next queue ID
pure ()
dbCreateNextSndQueue :: DB.Connection -> ConnId -> SndQueue -> SndQueue -> IO ()
dbCreateNextSndQueue db connId SndQueue {server = (SMPServer host port _), sndId} sq' = do
sqId <- insertSndQueue_ db connId sq'
DB.execute
db
[sql|
UPDATE snd_queues
SET next_snd_queue_id = ?
WHERE host = ? AND port = ? AND snd_id = ? AND curr_snd_queue = ?
|]
(sqId, host, port, sndId, True)
setRcvQueueAction :: DB.Connection -> RcvQueue -> Maybe RcvQueueAction -> IO ()
setRcvQueueAction _db _rq _rqAction_ = pure ()
setRcvQueueAction db RcvQueue {server = (SMPServer host port _), rcvId} rqAction_ = do
ts <- getCurrentTime
DB.execute
db
[sql|
UPDATE rcv_queues
SET rcv_queue_action = ?, rcv_queue_action_ts = ?
WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ?
|]
(rqAction_, ts, host, port, rcvId, True)
switchCurrRcvQueue :: DB.Connection -> RcvQueue -> RcvQueue -> IO ()
switchCurrRcvQueue _db _rq _nextRq = do
-- make a new queue a main one
-- delete old queue from the database
pure ()
switchCurrRcvQueue db RcvQueue {server = (SMPServer host port _), rcvId} RcvQueue {dbNextRcvQueueId} = do
DB.execute db "DELETE FROM rcv_queues WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ?" (host, port, rcvId, True)
DB.execute db "UPDATE rcv_queues SET curr_rcv_queue = ? WHERE rcv_queue_id = ? AND curr_rcv_queue = ?" (True, dbNextRcvQueueId, False)
switchCurrSndQueue :: DB.Connection -> SndQueue -> SndQueue -> IO ()
switchCurrSndQueue _db _sq _nextSq = do
-- make new queue active
-- delete old queue from the database
pure ()
switchCurrSndQueue db SndQueue {server = (SMPServer host port _), sndId} SndQueue {dbNextSndQueueId} = do
DB.execute db "DELETE FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND curr_snd_queue = ?" (host, port, sndId, True)
DB.execute db "UPDATE snd_queues SET curr_snd_queue = ? WHERE snd_queue_id = ? AND curr_snd_queue = ?" (True, dbNextSndQueueId, False)
type SMPConfirmationRow = (SndPublicVerifyKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version)
@@ -1161,29 +1198,31 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
-- * createRcvConn helpers
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
insertRcvQueue_ db connId RcvQueue {..} = do
qId <- newQueueId_ <$> DB.query_ db "SELECT rcv_queue_id FROM rcv_queues ORDER BY rcv_queue_id DESC LIMIT 1"
DB.execute
db
[sql|
INSERT INTO rcv_queues
(rcv_queue_id, host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
(rcv_queue_id, host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, curr_rcv_queue, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|]
((qId, host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (smpClientVersion, createdAt, updatedAt))
((qId, host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (currRcvQueue, smpClientVersion, createdAt, updatedAt))
pure qId
-- * createSndConn helpers
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
insertSndQueue_ db connId SndQueue {..} = do
qId <- newQueueId_ <$> DB.query_ db "SELECT snd_queue_id FROM snd_queues ORDER BY snd_queue_id DESC LIMIT 1"
DB.execute
db
[sql|
INSERT INTO snd_queues
(snd_queue_id, host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?);
(snd_queue_id, host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, curr_snd_queue, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|]
((qId, host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (smpClientVersion, createdAt, updatedAt))
((qId, host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (currSndQueue, smpClientVersion, createdAt, updatedAt))
pure qId
newQueueId_ :: [Only (Maybe Int64)] -> Int64
newQueueId_ [] = 1
@@ -1240,8 +1279,8 @@ toNtfCreds _ = Nothing
getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
getRcvQueueByConnId_ dbConn connId =
listToMaybe . map toRcvQueue
<$> DB.query
maybeFirstRow toRcvQueue $
DB.query
dbConn
[sql|
SELECT q.host, q.port, s.key_hash,
@@ -1256,10 +1295,10 @@ getRcvQueueByConnId_ dbConn connId =
(connId, True)
getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
getSndQueueByConnId_ dbConn connId =
listToMaybe . map sndQueue
<$> DB.query
dbConn
getSndQueueByConnId_ db connId =
maybeFirstRow toSndQueue $
DB.query
db
[sql|
SELECT q.host, q.port, s.key_hash,
q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status,
@@ -1270,11 +1309,17 @@ getSndQueueByConnId_ dbConn connId =
WHERE q.conn_id = ? AND q.curr_snd_queue = ?
|]
(connId, True)
where
sndQueue (srvRow :. (sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sqAction_, sqActionTs_, currSndQueue, dbNextSndQueueId) :. (smpClientVersion, createdAt, updatedAt)) =
let server = toSMPServer srvRow
sndQueueAction = (,) <$> sqAction_ <*> sqActionTs_
in SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sndQueueAction, currSndQueue, dbNextSndQueueId, smpClientVersion, createdAt, updatedAt}
type SndQueueRow =
ServerRow
:. (SMP.SenderId, Maybe C.APublicVerifyKey, SMP.SndPrivateSignKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus, Maybe SndQueueAction, Maybe UTCTime, Bool, Maybe Int64)
:. (Version, UTCTime, UTCTime)
toSndQueue :: SndQueueRow -> SndQueue
toSndQueue (srvRow :. (sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sqAction_, sqActionTs_, currSndQueue, dbNextSndQueueId) :. (smpClientVersion, createdAt, updatedAt)) =
let server = toSMPServer srvRow
sndQueueAction = (,) <$> sqAction_ <*> sqActionTs_
in SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sndQueueAction, currSndQueue, dbNextSndQueueId, smpClientVersion, createdAt, updatedAt}
-- * updateRcvIds helpers