server_key_hash fields (#643)

* server_key_hash fields

* test

* refactor

* fix

* order

* use sync command in test

* refactor

---------

Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com>
This commit is contained in:
spaced4ndy
2023-02-18 01:24:32 +04:00
committed by GitHub
parent 2ddfb044fc
commit c0dcf283eb
9 changed files with 143 additions and 66 deletions
+1 -1
View File
@@ -757,7 +757,7 @@ sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do
enqueueCommand :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> m ()
enqueueCommand c corrId connId server aCommand = do
resumeSrvCmds c server
commandId <- withStore' c $ \db -> createCommand db corrId connId server aCommand
commandId <- withStore c $ \db -> createCommand db corrId connId server aCommand
queuePendingCommands c server [commandId]
resumeSrvCmds :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
@@ -82,11 +82,11 @@ processNtfSub c (connId, cmd) = do
case clientNtfCreds of
Just ClientNtfCreds {notifierId} -> do
let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey
withStore' c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate
addNtfNTFWorker ntfServer
Nothing -> do
let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew
withStore' c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey
addNtfSMPWorker smpServer
(Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer, smpServer = smpServer', ntfQueueId}, action_)) -> do
case (clientNtfCreds, ntfQueueId) of
+2
View File
@@ -510,6 +510,8 @@ data StoreError
SEUserNotFound
| -- | Connection not found (or both queues absent).
SEConnNotFound
| -- | Server not found.
SEServerNotFound
| -- | Connection already used.
SEConnDuplicate
| -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa.
+69 -59
View File
@@ -400,16 +400,16 @@ updateNewConnSnd db connId sq =
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertRcvQueue_ db connId q
void $ insertRcvQueue_ db connId q serverKeyHash_
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertSndQueue_ db connId q
void $ insertSndQueue_ db connId q serverKeyHash_
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
@@ -447,8 +447,8 @@ addConnRcvQueue db connId rq =
addConnRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
addConnRcvQueue_ db connId rq@RcvQueue {server} = do
upsertServer_ db server
insertRcvQueue_ db connId rq
serverKeyHash_ <- createServer_ db server
insertRcvQueue_ db connId rq serverKeyHash_
addConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
addConnSndQueue db connId sq =
@@ -459,8 +459,8 @@ addConnSndQueue db connId sq =
addConnSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
addConnSndQueue_ db connId sq@SndQueue {server} = do
upsertServer_ db server
insertSndQueue_ db connId sq
serverKeyHash_ <- createServer_ db server
insertSndQueue_ db connId sq serverKeyHash_
setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status =
@@ -858,18 +858,21 @@ updateRatchet db connId rc skipped = do
forM_ (M.assocs mks) $ \(msgN, mk) ->
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO AsyncCmdId
createCommand db corrId connId srv cmd = do
DB.execute
db
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command) VALUES (?,?,?,?,?,?)"
(host_, port_, corrId, connId, agentCommandTag cmd, cmd)
insertedRowId db
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO (Either StoreError AsyncCmdId)
createCommand db corrId connId srv_ cmd = runExceptT $ do
(host_, port_, serverKeyHash_) <- serverFields
liftIO $ do
DB.execute
db
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command, server_key_hash) VALUES (?,?,?,?,?,?,?)"
(host_, port_, corrId, connId, agentCommandTag cmd, cmd, serverKeyHash_)
insertedRowId db
where
(host_, port_) =
case srv of
Just (SMPServer host port _) -> (Just host, Just port)
_ -> (Nothing, Nothing)
serverFields :: ExceptT StoreError IO (Maybe (NonEmpty TransportHost), Maybe ServiceName, Maybe C.KeyHash)
serverFields = case srv_ of
Just srv@(SMPServer host port _) ->
(Just host,Just port,) <$> ExceptT (getServerKeyHash_ db srv)
Nothing -> pure (Nothing, Nothing, Nothing)
insertedRowId :: DB.Connection -> IO Int64
insertedRowId db = fromOnly . head <$> DB.query_ db "SELECT last_insert_rowid()"
@@ -880,7 +883,7 @@ getPendingCommands db connId = do
<$> DB.query
db
[sql|
SELECT c.host, c.port, s.key_hash, c.command_id
SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.command_id
FROM commands c
LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
WHERE conn_id = ?
@@ -1003,7 +1006,7 @@ getNtfSubscription db connId =
DB.query
db
[sql|
SELECT s.host, s.port, s.key_hash, ns.ntf_host, ns.ntf_port, ns.ntf_key_hash,
SELECT s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash,
nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts
FROM ntf_subscriptions nsb
JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port
@@ -1021,21 +1024,23 @@ getNtfSubscription db connId =
_ -> Nothing
in (NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action)
createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO ()
createNtfSubscription db ntfSubscription action = do
let NtfSubscription {connId, smpServer = (SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription
createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ())
createNtfSubscription db ntfSubscription action = runExceptT $ do
let NtfSubscription {connId, smpServer = smpServer@(SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription
smpServerKeyHash_ <- ExceptT $ getServerKeyHash_ db smpServer
actionTs <- liftIO getCurrentTime
DB.execute
db
[sql|
INSERT INTO ntf_subscriptions
(conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id,
ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts)
VALUES (?,?,?,?,?,?,?,?,?,?,?)
|]
( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId)
:. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs)
)
liftIO $
DB.execute
db
[sql|
INSERT INTO ntf_subscriptions
(conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id,
ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts, smp_server_key_hash)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
|]
( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId)
:. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, smpServerKeyHash_)
)
where
(ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action
@@ -1136,7 +1141,7 @@ getNextNtfSubNTFAction db ntfServer@(NtfServer ntfHost ntfPort _) = do
DB.query
db
[sql|
SELECT ns.conn_id, s.host, s.port, s.key_hash,
SELECT ns.conn_id, s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash),
ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action
FROM ntf_subscriptions ns
JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port
@@ -1321,20 +1326,25 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
{- ORMOLU_ENABLE -}
-- * Server upsert helper
-- * Server helper
upsertServer_ :: DB.Connection -> SMPServer -> IO ()
upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do
DB.executeNamed
dbConn
[sql|
INSERT INTO servers (host, port, key_hash) VALUES (:host,:port,:key_hash)
ON CONFLICT (host, port) DO UPDATE SET
host=excluded.host,
port=excluded.port,
key_hash=excluded.key_hash;
|]
[":host" := host, ":port" := port, ":key_hash" := keyHash]
-- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored.
createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash)
createServer_ db newSrv@ProtocolServer {host, port, keyHash} =
getServerKeyHash_ db newSrv >>= \case
Right keyHash_ -> pure keyHash_
Left _ -> insertNewServer_ $> Nothing
where
insertNewServer_ =
DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash)
-- | Returns the stored server key hash if it is different from the passed one, or the error if the server does not exist.
getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash))
getServerKeyHash_ db ProtocolServer {host, port, keyHash} = do
firstRow useKeyHash SEServerNotFound $
DB.query db "SELECT key_hash FROM servers WHERE host = ? AND port = ?" (host, port)
where
useKeyHash (Only keyHash') = if keyHash /= keyHash' then Just keyHash else Nothing
upsertNtfServer_ :: DB.Connection -> NtfServer -> IO ()
upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
@@ -1351,30 +1361,30 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
-- * createRcvConn helpers
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
insertRcvQueue_ db connId' RcvQueue {..} = do
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> Maybe C.KeyHash -> IO Int64
insertRcvQueue_ db connId' RcvQueue {..} serverKeyHash_ = do
qId <- newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')
DB.execute
db
[sql|
INSERT INTO rcv_queues
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|]
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, dbReplaceQueueId, smpClientVersion))
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_))
pure qId
-- * createSndConn helpers
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
insertSndQueue_ db connId' SndQueue {..} = do
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> Maybe C.KeyHash -> IO Int64
insertSndQueue_ db connId' SndQueue {..} serverKeyHash_ = do
qId <- newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')
DB.execute
db
[sql|
INSERT INTO snd_queues
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?);
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|]
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion))
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_))
pure qId
newQueueId_ :: [Only Int64] -> Int64
@@ -1443,7 +1453,7 @@ getRcvQueuesByConnId_ db connId =
rcvQueueQuery :: Query
rcvQueueQuery =
[sql|
SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version, q.delete_errors,
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
@@ -1477,7 +1487,7 @@ getSndQueuesByConnId_ dbConn connId =
<$> DB.query
dbConn
[sql|
SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
FROM snd_queues q
JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN connections c ON q.conn_id = c.conn_id
@@ -40,6 +40,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queu
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230217_server_key_hash
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -59,7 +60,8 @@ schemaMigrations =
("m20220915_connection_queues", m20220915_connection_queues),
("m20230110_users", m20230110_users),
("m20230117_fkey_indexes", m20230117_fkey_indexes),
("m20230120_delete_errors", m20230120_delete_errors)
("m20230120_delete_errors", m20230120_delete_errors),
("m20230217_server_key_hash", m20230217_server_key_hash)
]
-- | The list of migrations in ascending order by date
@@ -0,0 +1,20 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230217_server_key_hash where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
-- server_key_hash is not null for records whose entities refer to a server
-- that was previously saved with the same host and port but different key hash
m20230217_server_key_hash :: Query
m20230217_server_key_hash =
[sql|
ALTER TABLE rcv_queues ADD COLUMN server_key_hash BLOB;
ALTER TABLE snd_queues ADD COLUMN server_key_hash BLOB;
ALTER TABLE ntf_subscriptions ADD COLUMN smp_server_key_hash BLOB;
ALTER TABLE commands ADD COLUMN server_key_hash BLOB;
|]
@@ -48,6 +48,7 @@ CREATE TABLE rcv_queues(
rcv_primary INTEGER CHECK(rcv_primary NOT NULL),
replace_rcv_queue_id INTEGER NULL,
delete_errors INTEGER DEFAULT 0 CHECK(delete_errors NOT NULL),
server_key_hash BLOB,
PRIMARY KEY(host, port, rcv_id),
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE,
@@ -68,6 +69,7 @@ CREATE TABLE snd_queues(
snd_queue_id INTEGER CHECK(snd_queue_id NOT NULL),
snd_primary INTEGER CHECK(snd_primary NOT NULL),
replace_snd_queue_id INTEGER NULL,
server_key_hash BLOB,
PRIMARY KEY(host, port, snd_id),
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE
@@ -199,6 +201,7 @@ CREATE TABLE ntf_subscriptions(
updated_by_supervisor INTEGER NOT NULL DEFAULT 0, -- to be checked on updates by workers to not overwrite supervisor command(state still should be updated)
created_at TEXT NOT NULL DEFAULT(datetime('now')),
updated_at TEXT NOT NULL DEFAULT(datetime('now')),
smp_server_key_hash BLOB,
PRIMARY KEY(conn_id),
FOREIGN KEY(smp_host, smp_port) REFERENCES servers(host, port)
ON DELETE SET NULL ON UPDATE CASCADE,
@@ -214,6 +217,7 @@ CREATE TABLE commands(
command_tag BLOB NOT NULL,
command BLOB NOT NULL,
agent_version INTEGER NOT NULL DEFAULT 1,
server_key_hash BLOB,
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE
);