refactor store

This commit is contained in:
Evgeny Poberezkin
2022-08-25 21:14:56 +01:00
parent 2f77f16276
commit cac30ca341
5 changed files with 48 additions and 39 deletions
+3 -3
View File
@@ -305,7 +305,7 @@ newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> SConnectionMode c ->
newConn c connId enableNtfs cMode = do
srv <- getSMPServer c
clientVRange <- asks $ smpClientVRange . config
(rq, qUri) <- newRcvQueue c srv clientVRange
(rq, qUri) <- newRcvQueue c srv clientVRange False
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent
@@ -368,7 +368,7 @@ joinConn c connId enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} = do
srv <- getSMPServer c
(rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion
(rq, qUri) <- newRcvQueue c srv (versionToRange smpClientVersion) False
let qInfo = toVersionT qUri smpClientVersion
addSubscription c rq connId
withStore c $ \db -> upgradeSndConnToDuplex db connId rq
@@ -459,7 +459,7 @@ createNextRcvQueue c cData rq@RcvQueue {server, sndId} sq = do
pure SMPQueueUri {clientVRange, queueAddress}
_ -> do
srv <- getSMPServer c
(rq', qUri) <- newRcvQueue c srv clientVRange
(rq', qUri) <- newRcvQueue c srv clientVRange True
withStore' c $ \db -> dbCreateNextRcvQueue db rq rq'
pure qUri
void $ enqueueMessage c cData sq SMP.noMsgFlags QNEW {currentAddress = (server, sndId), nextQueueUri}
+6 -4
View File
@@ -471,10 +471,10 @@ protocolClientError protocolError_ = \case
e@PCESignatureError {} -> INTERNAL $ show e
e@PCEIOError {} -> INTERNAL $ show e
newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c srv vRange =
newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> VersionRange -> Bool -> m (RcvQueue, SMPQueueUri)
newRcvQueue c srv vRange next =
asks (cmdSignAlg . config) >>= \case
C.SignAlg a -> newRcvQueue_ a c srv vRange
C.SignAlg a -> newRcvQueue_ a c srv vRange next
newRcvQueue_ ::
(C.SignatureAlgorithm a, C.AlgorithmI a, AgentMonad m) =>
@@ -482,8 +482,9 @@ newRcvQueue_ ::
AgentClient ->
SMPServer ->
VersionRange ->
Bool ->
m (RcvQueue, SMPQueueUri)
newRcvQueue_ a c srv vRange = do
newRcvQueue_ a c srv vRange next = do
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(dhKey, privDhKey) <- liftIO C.generateKeyPair'
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
@@ -504,6 +505,7 @@ newRcvQueue_ a c srv vRange = do
sndPublicKey = Nothing,
status = New,
rcvQueueAction = Nothing,
nextRcvQueue = next,
dbNextRcvQueueId = Nothing,
clientNtfCreds = Nothing,
smpClientVersion = maxVersion vRange,
+2
View File
@@ -61,6 +61,8 @@ data RcvQueue = RcvQueue
status :: QueueStatus,
-- | action to perform, to be done on connection subscription, if it fails and not reset
rcvQueueAction :: Maybe (RcvQueueAction, UTCTime),
-- | True if this is the queue the connection is switching to, rather than the current queue
nextRcvQueue :: Bool,
-- | database ID of the new queue created for this queue to switch to (queue rotation)
dbNextRcvQueueId :: Maybe Int64,
-- | credentials used in context of notifications
+35 -32
View File
@@ -15,6 +15,7 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
@@ -387,13 +388,13 @@ setRcvQueueNtfCreds db connId clientNtfCreds =
getNextRcvQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe RcvQueue)
getNextRcvQueue db = \case
Just rqId ->
listToMaybe . map rcvQueue
<$> DB.query
maybeFirstRow toRcvQueue $
DB.query
db
[sql|
SELECT q.host, q.port, s.key_hash,
q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.snd_key, q.status,
q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue_id,
q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue, q.next_rcv_queue_id,
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret,
q.smp_client_version, q.created_at, q.updated_at
FROM rcv_queues q
@@ -401,15 +402,6 @@ getNextRcvQueue db = \case
WHERE q.rcv_queue_id = ? AND q.next_rcv_queue = ?
|]
(rqId, True)
where
rcvQueue (srvRow :. (rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status) :. (rqAction_, rqActionTs_, dbNextRcvQueueId) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) :. (smpClientVersion_, createdAt, updatedAt)) =
let server = toSMPServer srvRow
smpClientVersion = fromMaybe 1 smpClientVersion_
rcvQueueAction = (,) <$> rqAction_ <*> rqActionTs_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status, rcvQueueAction, dbNextRcvQueueId, smpClientVersion, clientNtfCreds, createdAt, updatedAt}
_ -> pure Nothing
getNextSndQueue :: DB.Connection -> Maybe Int64 -> IO (Maybe SndQueue)
@@ -1182,38 +1174,58 @@ getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConn dbConn connId =
getConnData dbConn connId >>= \case
Nothing -> pure $ Left SEConnNotFound
Just (connData, cMode) -> do
Just (cData, cMode) -> do
rQ <- getRcvQueueByConnId_ dbConn connId
sQ <- getSndQueueByConnId_ dbConn connId
pure $ case (rQ, sQ, cMode) of
(Just rcvQ, Just sndQ, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ)
(Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
(Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
(Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection connData rcvQ)
(Just rcvQ, Just sndQ, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rcvQ sndQ)
(Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rcvQ)
(Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sndQ)
(Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rcvQ)
_ -> Left SEConnNotFound
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData dbConn connId' =
connData
<$> DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId')
maybeFirstRow toConnData $
DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId')
where
connData [(connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake)] = Just (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode)
connData _ = Nothing
toConnData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode)
type RcvQueueRow =
ServerRow
:. (SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe C.APublicVerifyKey, QueueStatus)
:. (Maybe RcvQueueAction, Maybe UTCTime, Bool, Maybe Int64)
:. NtfCredsRow
:. (Maybe Version, UTCTime, UTCTime)
type ServerRow = (NonEmpty TransportHost, String, C.KeyHash)
type NtfCredsRow = (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret)
toRcvQueue :: RcvQueueRow -> RcvQueue
toRcvQueue (srvRow :. (rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status) :. (rqAction_, rqActionTs_, nextRcvQueue, dbNextRcvQueueId) :. ntfCredsRow :. (smpClientVersion_, createdAt, updatedAt)) =
let server = toSMPServer srvRow
smpClientVersion = fromMaybe 1 smpClientVersion_
rcvQueueAction = (,) <$> rqAction_ <*> rqActionTs_
clientNtfCreds = toNtfCreds ntfCredsRow
in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status, rcvQueueAction, nextRcvQueue, dbNextRcvQueueId, smpClientVersion, clientNtfCreds, createdAt, updatedAt}
toSMPServer :: ServerRow -> SMPServer
toSMPServer (host, port, keyHash) = SMPServer host port keyHash
toNtfCreds :: NtfCredsRow -> Maybe ClientNtfCreds
toNtfCreds (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) = Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
toNtfCreds _ = Nothing
getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
getRcvQueueByConnId_ dbConn connId =
listToMaybe . map rcvQueue
listToMaybe . map toRcvQueue
<$> DB.query
dbConn
[sql|
SELECT q.host, q.port, s.key_hash,
q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.snd_key, q.status,
q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue_id,
q.rcv_queue_action, q.rcv_queue_action_ts, q.next_rcv_queue, q.next_rcv_queue_id,
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret,
q.smp_client_version, q.created_at, q.updated_at
FROM rcv_queues q
@@ -1221,15 +1233,6 @@ getRcvQueueByConnId_ dbConn connId =
WHERE q.conn_id = ? AND q.next_rcv_queue = ?
|]
(connId, False)
where
rcvQueue (srvRow :. (rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status) :. (rqAction_, rqActionTs_, dbNextRcvQueueId) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) :. (smpClientVersion_, createdAt, updatedAt)) =
let server = toSMPServer srvRow
smpClientVersion = fromMaybe 1 smpClientVersion_
rcvQueueAction = (,) <$> rqAction_ <*> rqActionTs_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndPublicKey, status, rcvQueueAction, dbNextRcvQueueId, smpClientVersion, clientNtfCreds, createdAt, updatedAt}
getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
getSndQueueByConnId_ dbConn connId =
+2
View File
@@ -166,6 +166,7 @@ rcvQueue1 =
sndId = "2345",
sndPublicKey = Nothing,
status = New,
nextRcvQueue = False,
dbNextRcvQueueId = Nothing,
rcvQueueAction = Nothing,
clientNtfCreds = Nothing,
@@ -349,6 +350,7 @@ testUpgradeSndConnToDuplex =
sndId = "4567",
sndPublicKey = Nothing,
status = New,
nextRcvQueue = False,
dbNextRcvQueueId = Nothing,
rcvQueueAction = Nothing,
clientNtfCreds = Nothing,