|
|
|
@@ -400,16 +400,16 @@ updateNewConnSnd db connId sq =
|
|
|
|
|
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
|
|
|
|
|
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
|
|
|
|
createConn_ gVar cData $ \connId -> do
|
|
|
|
|
upsertServer_ db server
|
|
|
|
|
serverKeyHash_ <- createServer_ db server
|
|
|
|
|
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
|
|
|
|
void $ insertRcvQueue_ db connId q
|
|
|
|
|
void $ insertRcvQueue_ db connId q serverKeyHash_
|
|
|
|
|
|
|
|
|
|
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
|
|
|
|
|
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
|
|
|
|
createConn_ gVar cData $ \connId -> do
|
|
|
|
|
upsertServer_ db server
|
|
|
|
|
serverKeyHash_ <- createServer_ db server
|
|
|
|
|
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
|
|
|
|
void $ insertSndQueue_ db connId q
|
|
|
|
|
void $ insertSndQueue_ db connId q serverKeyHash_
|
|
|
|
|
|
|
|
|
|
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
|
|
|
|
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
|
|
|
|
@@ -447,8 +447,8 @@ addConnRcvQueue db connId rq =
|
|
|
|
|
|
|
|
|
|
addConnRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
|
|
|
|
|
addConnRcvQueue_ db connId rq@RcvQueue {server} = do
|
|
|
|
|
upsertServer_ db server
|
|
|
|
|
insertRcvQueue_ db connId rq
|
|
|
|
|
serverKeyHash_ <- createServer_ db server
|
|
|
|
|
insertRcvQueue_ db connId rq serverKeyHash_
|
|
|
|
|
|
|
|
|
|
addConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
|
|
|
|
|
addConnSndQueue db connId sq =
|
|
|
|
@@ -459,8 +459,8 @@ addConnSndQueue db connId sq =
|
|
|
|
|
|
|
|
|
|
addConnSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
|
|
|
|
|
addConnSndQueue_ db connId sq@SndQueue {server} = do
|
|
|
|
|
upsertServer_ db server
|
|
|
|
|
insertSndQueue_ db connId sq
|
|
|
|
|
serverKeyHash_ <- createServer_ db server
|
|
|
|
|
insertSndQueue_ db connId sq serverKeyHash_
|
|
|
|
|
|
|
|
|
|
setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
|
|
|
|
|
setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status =
|
|
|
|
@@ -858,18 +858,21 @@ updateRatchet db connId rc skipped = do
|
|
|
|
|
forM_ (M.assocs mks) $ \(msgN, mk) ->
|
|
|
|
|
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
|
|
|
|
|
|
|
|
|
|
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO AsyncCmdId
|
|
|
|
|
createCommand db corrId connId srv cmd = do
|
|
|
|
|
DB.execute
|
|
|
|
|
db
|
|
|
|
|
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command) VALUES (?,?,?,?,?,?)"
|
|
|
|
|
(host_, port_, corrId, connId, agentCommandTag cmd, cmd)
|
|
|
|
|
insertedRowId db
|
|
|
|
|
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO (Either StoreError AsyncCmdId)
|
|
|
|
|
createCommand db corrId connId srv_ cmd = runExceptT $ do
|
|
|
|
|
(host_, port_, serverKeyHash_) <- serverFields
|
|
|
|
|
liftIO $ do
|
|
|
|
|
DB.execute
|
|
|
|
|
db
|
|
|
|
|
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command, server_key_hash) VALUES (?,?,?,?,?,?,?)"
|
|
|
|
|
(host_, port_, corrId, connId, agentCommandTag cmd, cmd, serverKeyHash_)
|
|
|
|
|
insertedRowId db
|
|
|
|
|
where
|
|
|
|
|
(host_, port_) =
|
|
|
|
|
case srv of
|
|
|
|
|
Just (SMPServer host port _) -> (Just host, Just port)
|
|
|
|
|
_ -> (Nothing, Nothing)
|
|
|
|
|
serverFields :: ExceptT StoreError IO (Maybe (NonEmpty TransportHost), Maybe ServiceName, Maybe C.KeyHash)
|
|
|
|
|
serverFields = case srv_ of
|
|
|
|
|
Just srv@(SMPServer host port _) ->
|
|
|
|
|
(Just host,Just port,) <$> ExceptT (getServerKeyHash_ db srv)
|
|
|
|
|
Nothing -> pure (Nothing, Nothing, Nothing)
|
|
|
|
|
|
|
|
|
|
insertedRowId :: DB.Connection -> IO Int64
|
|
|
|
|
insertedRowId db = fromOnly . head <$> DB.query_ db "SELECT last_insert_rowid()"
|
|
|
|
@@ -880,7 +883,7 @@ getPendingCommands db connId = do
|
|
|
|
|
<$> DB.query
|
|
|
|
|
db
|
|
|
|
|
[sql|
|
|
|
|
|
SELECT c.host, c.port, s.key_hash, c.command_id
|
|
|
|
|
SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.command_id
|
|
|
|
|
FROM commands c
|
|
|
|
|
LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
|
|
|
|
|
WHERE conn_id = ?
|
|
|
|
@@ -1003,7 +1006,7 @@ getNtfSubscription db connId =
|
|
|
|
|
DB.query
|
|
|
|
|
db
|
|
|
|
|
[sql|
|
|
|
|
|
SELECT s.host, s.port, s.key_hash, ns.ntf_host, ns.ntf_port, ns.ntf_key_hash,
|
|
|
|
|
SELECT s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash,
|
|
|
|
|
nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts
|
|
|
|
|
FROM ntf_subscriptions nsb
|
|
|
|
|
JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port
|
|
|
|
@@ -1021,21 +1024,23 @@ getNtfSubscription db connId =
|
|
|
|
|
_ -> Nothing
|
|
|
|
|
in (NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action)
|
|
|
|
|
|
|
|
|
|
createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO ()
|
|
|
|
|
createNtfSubscription db ntfSubscription action = do
|
|
|
|
|
let NtfSubscription {connId, smpServer = (SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription
|
|
|
|
|
createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ())
|
|
|
|
|
createNtfSubscription db ntfSubscription action = runExceptT $ do
|
|
|
|
|
let NtfSubscription {connId, smpServer = smpServer@(SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription
|
|
|
|
|
smpServerKeyHash_ <- ExceptT $ getServerKeyHash_ db smpServer
|
|
|
|
|
actionTs <- liftIO getCurrentTime
|
|
|
|
|
DB.execute
|
|
|
|
|
db
|
|
|
|
|
[sql|
|
|
|
|
|
INSERT INTO ntf_subscriptions
|
|
|
|
|
(conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id,
|
|
|
|
|
ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts)
|
|
|
|
|
VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
|
|
|
|
|]
|
|
|
|
|
( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId)
|
|
|
|
|
:. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs)
|
|
|
|
|
)
|
|
|
|
|
liftIO $
|
|
|
|
|
DB.execute
|
|
|
|
|
db
|
|
|
|
|
[sql|
|
|
|
|
|
INSERT INTO ntf_subscriptions
|
|
|
|
|
(conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id,
|
|
|
|
|
ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts, smp_server_key_hash)
|
|
|
|
|
VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
|
|
|
|
|
|]
|
|
|
|
|
( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId)
|
|
|
|
|
:. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, smpServerKeyHash_)
|
|
|
|
|
)
|
|
|
|
|
where
|
|
|
|
|
(ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action
|
|
|
|
|
|
|
|
|
@@ -1136,7 +1141,7 @@ getNextNtfSubNTFAction db ntfServer@(NtfServer ntfHost ntfPort _) = do
|
|
|
|
|
DB.query
|
|
|
|
|
db
|
|
|
|
|
[sql|
|
|
|
|
|
SELECT ns.conn_id, s.host, s.port, s.key_hash,
|
|
|
|
|
SELECT ns.conn_id, s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash),
|
|
|
|
|
ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action
|
|
|
|
|
FROM ntf_subscriptions ns
|
|
|
|
|
JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port
|
|
|
|
@@ -1321,20 +1326,25 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
|
|
|
|
|
|
|
|
|
|
{- ORMOLU_ENABLE -}
|
|
|
|
|
|
|
|
|
|
-- * Server upsert helper
|
|
|
|
|
-- * Server helper
|
|
|
|
|
|
|
|
|
|
upsertServer_ :: DB.Connection -> SMPServer -> IO ()
|
|
|
|
|
upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do
|
|
|
|
|
DB.executeNamed
|
|
|
|
|
dbConn
|
|
|
|
|
[sql|
|
|
|
|
|
INSERT INTO servers (host, port, key_hash) VALUES (:host,:port,:key_hash)
|
|
|
|
|
ON CONFLICT (host, port) DO UPDATE SET
|
|
|
|
|
host=excluded.host,
|
|
|
|
|
port=excluded.port,
|
|
|
|
|
key_hash=excluded.key_hash;
|
|
|
|
|
|]
|
|
|
|
|
[":host" := host, ":port" := port, ":key_hash" := keyHash]
|
|
|
|
|
-- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored.
|
|
|
|
|
createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash)
|
|
|
|
|
createServer_ db newSrv@ProtocolServer {host, port, keyHash} =
|
|
|
|
|
getServerKeyHash_ db newSrv >>= \case
|
|
|
|
|
Right keyHash_ -> pure keyHash_
|
|
|
|
|
Left _ -> insertNewServer_ $> Nothing
|
|
|
|
|
where
|
|
|
|
|
insertNewServer_ =
|
|
|
|
|
DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash)
|
|
|
|
|
|
|
|
|
|
-- | Returns the stored server key hash if it is different from the passed one, or the error if the server does not exist.
|
|
|
|
|
getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash))
|
|
|
|
|
getServerKeyHash_ db ProtocolServer {host, port, keyHash} = do
|
|
|
|
|
firstRow useKeyHash SEServerNotFound $
|
|
|
|
|
DB.query db "SELECT key_hash FROM servers WHERE host = ? AND port = ?" (host, port)
|
|
|
|
|
where
|
|
|
|
|
useKeyHash (Only keyHash') = if keyHash /= keyHash' then Just keyHash else Nothing
|
|
|
|
|
|
|
|
|
|
upsertNtfServer_ :: DB.Connection -> NtfServer -> IO ()
|
|
|
|
|
upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
|
|
|
|
@@ -1351,30 +1361,30 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
|
|
|
|
|
|
|
|
|
|
-- * createRcvConn helpers
|
|
|
|
|
|
|
|
|
|
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
|
|
|
|
|
insertRcvQueue_ db connId' RcvQueue {..} = do
|
|
|
|
|
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> Maybe C.KeyHash -> IO Int64
|
|
|
|
|
insertRcvQueue_ db connId' RcvQueue {..} serverKeyHash_ = do
|
|
|
|
|
qId <- newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')
|
|
|
|
|
DB.execute
|
|
|
|
|
db
|
|
|
|
|
[sql|
|
|
|
|
|
INSERT INTO rcv_queues
|
|
|
|
|
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
|
|
|
|
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
|
|
|
|
|]
|
|
|
|
|
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, dbReplaceQueueId, smpClientVersion))
|
|
|
|
|
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_))
|
|
|
|
|
pure qId
|
|
|
|
|
|
|
|
|
|
-- * createSndConn helpers
|
|
|
|
|
|
|
|
|
|
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
|
|
|
|
|
insertSndQueue_ db connId' SndQueue {..} = do
|
|
|
|
|
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> Maybe C.KeyHash -> IO Int64
|
|
|
|
|
insertSndQueue_ db connId' SndQueue {..} serverKeyHash_ = do
|
|
|
|
|
qId <- newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')
|
|
|
|
|
DB.execute
|
|
|
|
|
db
|
|
|
|
|
[sql|
|
|
|
|
|
INSERT INTO snd_queues
|
|
|
|
|
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?);
|
|
|
|
|
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
|
|
|
|
|]
|
|
|
|
|
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion))
|
|
|
|
|
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_))
|
|
|
|
|
pure qId
|
|
|
|
|
|
|
|
|
|
newQueueId_ :: [Only Int64] -> Int64
|
|
|
|
@@ -1443,7 +1453,7 @@ getRcvQueuesByConnId_ db connId =
|
|
|
|
|
rcvQueueQuery :: Query
|
|
|
|
|
rcvQueueQuery =
|
|
|
|
|
[sql|
|
|
|
|
|
SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
|
|
|
|
SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
|
|
|
|
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
|
|
|
|
|
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version, q.delete_errors,
|
|
|
|
|
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
|
|
|
|
@@ -1477,7 +1487,7 @@ getSndQueuesByConnId_ dbConn connId =
|
|
|
|
|
<$> DB.query
|
|
|
|
|
dbConn
|
|
|
|
|
[sql|
|
|
|
|
|
SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
|
|
|
|
SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
|
|
|
|
FROM snd_queues q
|
|
|
|
|
JOIN servers s ON q.host = s.host AND q.port = s.port
|
|
|
|
|
JOIN connections c ON q.conn_id = c.conn_id
|
|
|
|
|