mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-31 07:04:17 +00:00
store methods
This commit is contained in:
@@ -452,24 +452,24 @@ doRcvQueueAction c cData rq@RcvQueue {rcvQueueAction} sq =
|
||||
where
|
||||
withNextRcvQueue :: AgentMonad m => (AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m ()) -> m ()
|
||||
withNextRcvQueue action = do
|
||||
withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case
|
||||
withStore' c (`getNextRcvQueue` rq) >>= \case
|
||||
Just rq' -> action c cData rq sq rq'
|
||||
_ -> do
|
||||
-- notify agent internal error
|
||||
pure ()
|
||||
|
||||
createNextRcvQueue :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> m ()
|
||||
createNextRcvQueue c cData rq@RcvQueue {server, sndId} sq = do
|
||||
createNextRcvQueue c cData@ConnData {connId} rq@RcvQueue {server, sndId} sq = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
nextQueueUri <-
|
||||
withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case
|
||||
withStore' c (`getNextRcvQueue` rq) >>= \case
|
||||
Just RcvQueue {server = smpServer, sndId = senderId, e2ePrivKey} -> do
|
||||
let queueAddress = SMPQueueAddress {smpServer, senderId, dhPublicKey = C.publicKey e2ePrivKey}
|
||||
pure SMPQueueUri {clientVRange, queueAddress}
|
||||
_ -> do
|
||||
srv <- getSMPServer c server
|
||||
(rq', qUri) <- newRcvQueue c srv clientVRange False
|
||||
withStore' c $ \db -> dbCreateNextRcvQueue db rq rq'
|
||||
withStore' c $ \db -> dbCreateNextRcvQueue db connId rq rq'
|
||||
pure qUri
|
||||
void $ enqueueMessage c cData sq SMP.noMsgFlags QNEW {currentAddress = (server, sndId), nextQueueUri}
|
||||
withStore' c $ \db -> setRcvQueueAction db rq Nothing
|
||||
@@ -1341,13 +1341,13 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
|
||||
case (nextQUri `compatibleVersion` clientVRange) of
|
||||
Just qInfo@(Compatible nextQInfo) -> do
|
||||
sq'@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue qInfo False
|
||||
withStore' c $ \db -> dbCreateNextSndQueue db sq sq'
|
||||
withStore' c $ \db -> dbCreateNextSndQueue db connId sq sq'
|
||||
case (sndPublicKey, e2ePubKey) of
|
||||
(Just nextSenderKey, Just dhPublicKey) -> do
|
||||
let qAddr = (queueAddress (nextQInfo :: SMPQueueInfo)) {dhPublicKey}
|
||||
nextQueueInfo = (nextQInfo :: SMPQueueInfo) {queueAddress = qAddr}
|
||||
void . enqueueMessage c cData sq SMP.noMsgFlags $ QKEYS {nextSenderKey, nextQueueInfo}
|
||||
rq' <- withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq)
|
||||
rq' <- withStore' c (`getNextRcvQueue` rq)
|
||||
notify . SWITCH SPStarted $ connectionStats conn rq' (Just sq')
|
||||
_ -> throwError $ INTERNAL "absent sender keys"
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
@@ -1361,7 +1361,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
|
||||
DuplexConnection _ _ sq -> do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
|
||||
withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case
|
||||
withStore' c (`getNextRcvQueue` rq) >>= \case
|
||||
Just rq'@RcvQueue {server, sndId, e2ePrivKey = dhPrivKey, smpClientVersion = clntVer} -> do
|
||||
unless (smpServer == server && senderId == sndId) . throwError $ INTERNAL "incorrect queue address"
|
||||
let dhSecret = C.dh' dhPublicKey dhPrivKey
|
||||
@@ -1377,7 +1377,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
|
||||
rqReady :: (SMPServer, SMP.SenderId) -> m ()
|
||||
rqReady (smpServer, senderId) = case conn of
|
||||
DuplexConnection _ _ sq ->
|
||||
withStore' c (`getNextSndQueue` dbNextSndQueueId sq) >>= \case
|
||||
withStore' c (`getNextSndQueue` sq) >>= \case
|
||||
Just sq'@SndQueue {server, sndId} -> do
|
||||
unless (smpServer == server && senderId == sndId) . throwError $ INTERNAL "incorrect queue address"
|
||||
void $ enqueueMessage c cData sq' SMP.noMsgFlags QTEST
|
||||
@@ -1398,7 +1398,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
|
||||
rqSwitch :: (SMPServer, SMP.SenderId) -> m ()
|
||||
rqSwitch (smpServer, senderId) = case conn of
|
||||
DuplexConnection _ _ sq@SndQueue {server, sndId} -> do
|
||||
withStore' c (`getNextSndQueue` dbNextSndQueueId sq) >>= \case
|
||||
withStore' c (`getNextSndQueue` sq) >>= \case
|
||||
Just sq'@SndQueue {server = server', sndId = sndId'} -> do
|
||||
unless (smpServer == server' && senderId == sndId') . throwError $ INTERNAL "incorrect queue address"
|
||||
let qKey = (connId, server, sndId)
|
||||
|
||||
@@ -272,7 +272,7 @@ createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
insertRcvQueue_ db connId q
|
||||
void $ insertRcvQueue_ db connId q
|
||||
|
||||
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
|
||||
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} = do
|
||||
@@ -280,7 +280,7 @@ createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
-- TODO add queue ID in insertSndQueue_
|
||||
insertSndQueue_ db connId q
|
||||
void $ insertSndQueue_ db connId q
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
||||
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
|
||||
@@ -316,7 +316,7 @@ upgradeRcvConnToDuplex db connId sq@SndQueue {server} =
|
||||
(SomeConn _ RcvConnection {}) -> do
|
||||
upsertServer_ db server
|
||||
-- TODO save with queue ID
|
||||
insertSndQueue_ db connId sq
|
||||
void $ insertSndQueue_ db connId sq
|
||||
pure $ Right ()
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
|
||||
@@ -325,7 +325,7 @@ upgradeSndConnToDuplex db connId rq@RcvQueue {server} =
|
||||
getConn db connId $>>= \case
|
||||
SomeConn _ SndConnection {} -> do
|
||||
upsertServer_ db server
|
||||
insertRcvQueue_ db connId rq
|
||||
void $ insertRcvQueue_ db connId rq
|
||||
pure $ Right ()
|
||||
SomeConn c _ -> pure . Left . SEBadConnType $ connType c
|
||||
|
||||
@@ -393,8 +393,8 @@ setRcvQueueNtfCreds db connId clientNtfCreds =
|
||||
Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret)
|
||||
Nothing -> (Nothing, Nothing, Nothing, Nothing)
|
||||
|
||||
getNextRcvQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe RcvQueue)
|
||||
getNextRcvQueue db = \case
|
||||
getNextRcvQueue :: DB.Connection -> RcvQueue -> IO (Maybe RcvQueue)
|
||||
getNextRcvQueue db RcvQueue {dbNextRcvQueueId} = case dbNextRcvQueueId of
|
||||
Just rqId ->
|
||||
maybeFirstRow toRcvQueue $
|
||||
DB.query
|
||||
@@ -412,32 +412,69 @@ getNextRcvQueue db = \case
|
||||
(rqId, False)
|
||||
_ -> pure Nothing
|
||||
|
||||
getNextSndQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe SndQueue)
|
||||
getNextSndQueue _db _sqId_ = pure Nothing
|
||||
getNextSndQueue :: DB.Connection -> SndQueue -> IO (Maybe SndQueue)
|
||||
getNextSndQueue db SndQueue {dbNextSndQueueId} = case dbNextSndQueueId of
|
||||
Just sqId ->
|
||||
maybeFirstRow toSndQueue $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT q.host, q.port, s.key_hash,
|
||||
q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status,
|
||||
q.snd_queue_action, q.snd_queue_action_ts, q.curr_snd_queue, q.next_snd_queue_id,
|
||||
q.smp_client_version, q.created_at, q.updated_at
|
||||
FROM snd_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.snd_queue_id = ? AND q.curr_rcv_queue = ?
|
||||
|]
|
||||
(sqId, False)
|
||||
_ -> pure Nothing
|
||||
|
||||
dbCreateNextRcvQueue :: DB.Connection -> RcvQueue -> RcvQueue -> IO ()
|
||||
dbCreateNextRcvQueue _db _rq _nextRq = pure ()
|
||||
dbCreateNextRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> RcvQueue -> IO ()
|
||||
dbCreateNextRcvQueue db connId RcvQueue {server = (SMPServer host port _), rcvId} rq' = do
|
||||
rqId <- insertRcvQueue_ db connId rq'
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET next_rcv_queue_id = ?
|
||||
WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ?
|
||||
|]
|
||||
(rqId, host, port, rcvId, True)
|
||||
|
||||
dbCreateNextSndQueue :: DB.Connection -> SndQueue -> SndQueue -> IO ()
|
||||
dbCreateNextSndQueue _db _sq _nextSq = do
|
||||
-- create next queue record
|
||||
-- update current queue with the next queue ID
|
||||
pure ()
|
||||
dbCreateNextSndQueue :: DB.Connection -> ConnId -> SndQueue -> SndQueue -> IO ()
|
||||
dbCreateNextSndQueue db connId SndQueue {server = (SMPServer host port _), sndId} sq' = do
|
||||
sqId <- insertSndQueue_ db connId sq'
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
UPDATE snd_queues
|
||||
SET next_snd_queue_id = ?
|
||||
WHERE host = ? AND port = ? AND snd_id = ? AND curr_snd_queue = ?
|
||||
|]
|
||||
(sqId, host, port, sndId, True)
|
||||
|
||||
setRcvQueueAction :: DB.Connection -> RcvQueue -> Maybe RcvQueueAction -> IO ()
|
||||
setRcvQueueAction _db _rq _rqAction_ = pure ()
|
||||
setRcvQueueAction db RcvQueue {server = (SMPServer host port _), rcvId} rqAction_ = do
|
||||
ts <- getCurrentTime
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET rcv_queue_action = ?, rcv_queue_action_ts = ?
|
||||
WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ?
|
||||
|]
|
||||
(rqAction_, ts, host, port, rcvId, True)
|
||||
|
||||
switchCurrRcvQueue :: DB.Connection -> RcvQueue -> RcvQueue -> IO ()
|
||||
switchCurrRcvQueue _db _rq _nextRq = do
|
||||
-- make a new queue a main one
|
||||
-- delete old queue from the database
|
||||
pure ()
|
||||
switchCurrRcvQueue db RcvQueue {server = (SMPServer host port _), rcvId} RcvQueue {dbNextRcvQueueId} = do
|
||||
DB.execute db "DELETE FROM rcv_queues WHERE host = ? AND port = ? AND rcv_id = ? AND curr_rcv_queue = ?" (host, port, rcvId, True)
|
||||
DB.execute db "UPDATE rcv_queues SET curr_rcv_queue = ? WHERE rcv_queue_id = ? AND curr_rcv_queue = ?" (True, dbNextRcvQueueId, False)
|
||||
|
||||
switchCurrSndQueue :: DB.Connection -> SndQueue -> SndQueue -> IO ()
|
||||
switchCurrSndQueue _db _sq _nextSq = do
|
||||
-- make new queue active
|
||||
-- delete old queue from the database
|
||||
pure ()
|
||||
switchCurrSndQueue db SndQueue {server = (SMPServer host port _), sndId} SndQueue {dbNextSndQueueId} = do
|
||||
DB.execute db "DELETE FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND curr_snd_queue = ?" (host, port, sndId, True)
|
||||
DB.execute db "UPDATE snd_queues SET curr_snd_queue = ? WHERE snd_queue_id = ? AND curr_snd_queue = ?" (True, dbNextSndQueueId, False)
|
||||
|
||||
type SMPConfirmationRow = (SndPublicVerifyKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version)
|
||||
|
||||
@@ -1161,29 +1198,31 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
|
||||
insertRcvQueue_ db connId RcvQueue {..} = do
|
||||
qId <- newQueueId_ <$> DB.query_ db "SELECT rcv_queue_id FROM rcv_queues ORDER BY rcv_queue_id DESC LIMIT 1"
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO rcv_queues
|
||||
(rcv_queue_id, host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
(rcv_queue_id, host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, curr_rcv_queue, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
((qId, host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (smpClientVersion, createdAt, updatedAt))
|
||||
((qId, host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (currRcvQueue, smpClientVersion, createdAt, updatedAt))
|
||||
pure qId
|
||||
|
||||
-- * createSndConn helpers
|
||||
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
|
||||
insertSndQueue_ db connId SndQueue {..} = do
|
||||
qId <- newQueueId_ <$> DB.query_ db "SELECT snd_queue_id FROM snd_queues ORDER BY snd_queue_id DESC LIMIT 1"
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO snd_queues
|
||||
(snd_queue_id, host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
(snd_queue_id, host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, curr_snd_queue, smp_client_version, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
((qId, host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (smpClientVersion, createdAt, updatedAt))
|
||||
((qId, host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (currSndQueue, smpClientVersion, createdAt, updatedAt))
|
||||
pure qId
|
||||
|
||||
newQueueId_ :: [Only (Maybe Int64)] -> Int64
|
||||
newQueueId_ [] = 1
|
||||
@@ -1240,8 +1279,8 @@ toNtfCreds _ = Nothing
|
||||
|
||||
getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
|
||||
getRcvQueueByConnId_ dbConn connId =
|
||||
listToMaybe . map toRcvQueue
|
||||
<$> DB.query
|
||||
maybeFirstRow toRcvQueue $
|
||||
DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT q.host, q.port, s.key_hash,
|
||||
@@ -1256,10 +1295,10 @@ getRcvQueueByConnId_ dbConn connId =
|
||||
(connId, True)
|
||||
|
||||
getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
|
||||
getSndQueueByConnId_ dbConn connId =
|
||||
listToMaybe . map sndQueue
|
||||
<$> DB.query
|
||||
dbConn
|
||||
getSndQueueByConnId_ db connId =
|
||||
maybeFirstRow toSndQueue $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT q.host, q.port, s.key_hash,
|
||||
q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status,
|
||||
@@ -1270,11 +1309,17 @@ getSndQueueByConnId_ dbConn connId =
|
||||
WHERE q.conn_id = ? AND q.curr_snd_queue = ?
|
||||
|]
|
||||
(connId, True)
|
||||
where
|
||||
sndQueue (srvRow :. (sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sqAction_, sqActionTs_, currSndQueue, dbNextSndQueueId) :. (smpClientVersion, createdAt, updatedAt)) =
|
||||
let server = toSMPServer srvRow
|
||||
sndQueueAction = (,) <$> sqAction_ <*> sqActionTs_
|
||||
in SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sndQueueAction, currSndQueue, dbNextSndQueueId, smpClientVersion, createdAt, updatedAt}
|
||||
|
||||
type SndQueueRow =
|
||||
ServerRow
|
||||
:. (SMP.SenderId, Maybe C.APublicVerifyKey, SMP.SndPrivateSignKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus, Maybe SndQueueAction, Maybe UTCTime, Bool, Maybe Int64)
|
||||
:. (Version, UTCTime, UTCTime)
|
||||
|
||||
toSndQueue :: SndQueueRow -> SndQueue
|
||||
toSndQueue (srvRow :. (sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sqAction_, sqActionTs_, currSndQueue, dbNextSndQueueId) :. (smpClientVersion, createdAt, updatedAt)) =
|
||||
let server = toSMPServer srvRow
|
||||
sndQueueAction = (,) <$> sqAction_ <*> sqActionTs_
|
||||
in SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, sndQueueAction, currSndQueue, dbNextSndQueueId, smpClientVersion, createdAt, updatedAt}
|
||||
|
||||
-- * updateRcvIds helpers
|
||||
|
||||
|
||||
Reference in New Issue
Block a user