diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 74228e427..5b376008c 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -445,9 +445,17 @@ doRcvQueueAction :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQu doRcvQueueAction c cData rq@RcvQueue {rcvQueueAction} sq = forM_ rcvQueueAction $ \(a, _ts) -> case a of RQACreateNextQueue -> createNextRcvQueue c cData rq sq - RQASecureNextQueue -> withNextRcvQueue c rq $ secureNextRcvQueue cData sq - RQASuspendCurrQueue -> withNextRcvQueue c rq suspendCurrRcvQueue - RQADeleteCurrQueue -> withNextRcvQueue c rq deleteCurrRcvQueue + RQASecureNextQueue -> withNextRcvQueue secureNextRcvQueue + RQASuspendCurrQueue -> withNextRcvQueue suspendCurrRcvQueue + RQADeleteCurrQueue -> withNextRcvQueue deleteCurrRcvQueue + where + withNextRcvQueue :: AgentMonad m => (AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m ()) -> m () + withNextRcvQueue action = do + withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case + Just rq' -> action c cData rq sq rq' + _ -> do + -- notify agent internal error + pure () createNextRcvQueue :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> m () createNextRcvQueue c cData rq@RcvQueue {server, sndId} sq = do @@ -465,46 +473,36 @@ createNextRcvQueue c cData rq@RcvQueue {server, sndId} sq = do void $ enqueueMessage c cData sq SMP.noMsgFlags QNEW {currentAddress = (server, sndId), nextQueueUri} withStore' c $ \db -> setRcvQueueAction db rq Nothing -secureNextRcvQueue :: AgentMonad m => ConnData -> SndQueue -> AgentClient -> RcvQueue -> RcvQueue -> m () -secureNextRcvQueue cData sq c rq nextRq@RcvQueue {server, sndId, status, sndPublicKey} = do +secureNextRcvQueue :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m () +secureNextRcvQueue c cData rq sq rq'@RcvQueue {server, sndId, status, sndPublicKey} = do when (status == Confirmed) $ case sndPublicKey of Just sKey -> do secureQueue c rq sKey - withStore' c $ \db -> setRcvQueueStatus db nextRq Secured + withStore' c $ \db -> setRcvQueueStatus db rq' Secured _ -> do -- notify user: no sender key pure () void . enqueueMessage c cData sq SMP.noMsgFlags $ QREADY (server, sndId) withStore' c $ \db -> setRcvQueueAction db rq Nothing -suspendCurrRcvQueue :: AgentMonad m => AgentClient -> RcvQueue -> RcvQueue -> m () -suspendCurrRcvQueue c currRq nextRq = do - -- Suspend curr queue - -- if 0 messages left: - -- - currRcvQueueDrained c currRq nextRq +suspendCurrRcvQueue :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m () +suspendCurrRcvQueue c cData rq sq rq' = do + msgCount <- suspendQueue c rq + withStore' c $ \db -> setRcvQueueStatus db rq Disabled + when (msgCount == 0) $ currRcvQueueDrained c cData rq sq rq' -currRcvQueueDrained :: AgentMonad m => AgentClient -> RcvQueue -> RcvQueue -> m () -currRcvQueueDrained c currRq nextRq = do - -- old queue status = Disabled - -- rcv_queue_action = RQADeleteQueue - -- - deleteCurrRcvQueue c currRq nextRq +currRcvQueueDrained :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m () +currRcvQueueDrained c cData rq sq rq' = do + withStore' c $ \db -> setRcvQueueAction db rq $ Just RQADeleteCurrQueue + deleteCurrRcvQueue c cData rq sq rq' -deleteCurrRcvQueue :: AgentMonad m => AgentClient -> RcvQueue -> RcvQueue -> m () -deleteCurrRcvQueue _c _currRq _nextRq = do - -- delete old queue - -- make a new queue a main one - -- get message from a new queue storage and process it (possibly send to the processing queue) - pure () - -withNextRcvQueue :: AgentMonad m => AgentClient -> RcvQueue -> (AgentClient -> RcvQueue -> RcvQueue -> m ()) -> m () -withNextRcvQueue c rq action = do - withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) >>= \case - Just nextRq -> action c rq nextRq - _ -> do - -- notify agent internal error - pure () +deleteCurrRcvQueue :: AgentMonad m => AgentClient -> ConnData -> RcvQueue -> SndQueue -> RcvQueue -> m () +deleteCurrRcvQueue c ConnData {connId} rq _sq rq'@RcvQueue {server, rcvId} = do + deleteQueue c rq + withStore' c $ \db -> switchCurrRcvQueue db rq rq' + atomically $ + TM.lookupDelete (connId, server, rcvId) (nextRcvQueueMsgs c) + >>= mapM_ (mapM_ . writeTBQueue $ msgQ c) subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) subscribeConnections' _ [] = pure M.empty @@ -1096,16 +1094,12 @@ subscriber c@AgentClient {msgQ} = forever $ do Right _ -> return () processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m () -processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cmd) = - withStore c (\db -> getRcvConn db srv rId) >>= \case - -- TODO somehow it should get next queue if the message is to it - SomeConn _ conn@(DuplexConnection cData rq _) -> processSMP conn cData rq - SomeConn _ conn@(RcvConnection cData rq) -> processSMP conn cData rq - SomeConn _ conn@(ContactConnection cData rq) -> processSMP conn cData rq - _ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND) +processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cmd) = do + (rq, SomeConn _ conn) <- withStore c $ \db -> getRcvConn db srv rId + processSMP conn (connData conn) rq where processSMP :: Connection c -> ConnData -> RcvQueue -> m () - processSMP conn cData@ConnData {connId, duplexHandshake} rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} = + processSMP conn cData@ConnData {connId, duplexHandshake} rq@RcvQueue {e2ePrivKey, e2eDhSecret, status, currRcvQueue} = case cmd of SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} <- decryptSMPMessage v rq msg @@ -1304,29 +1298,32 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm -- processed by queue sender rqNewMsg :: (SMPServer, SMP.SenderId) -> SMPQueueUri -> m () - rqNewMsg _currAddr nextQUri = case conn of - DuplexConnection _ _ sq -> do - -- TODO check that current address matches - clientVRange <- asks $ smpClientVRange . config - case (nextQUri `compatibleVersion` clientVRange) of - Just qInfo@(Compatible nextQInfo) -> do - sq'@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue qInfo False - withStore' c $ \db -> dbCreateNextSndQueue db sq sq' - case (sndPublicKey, e2ePubKey) of - (Just nextSenderKey, Just dhPublicKey) -> do - let qAddr = (queueAddress (nextQInfo :: SMPQueueInfo)) {dhPublicKey} - nextQueueInfo = (nextQInfo :: SMPQueueInfo) {queueAddress = qAddr} - void . enqueueMessage c cData sq SMP.noMsgFlags $ QKEYS {nextSenderKey, nextQueueInfo} - rq' <- withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) - notify . SWITCH SPStarted $ connectionStats conn rq' (Just sq') - _ -> throwError $ INTERNAL "absent sender keys" - _ -> throwError $ AGENT A_VERSION - _ -> throwError $ INTERNAL "message can only be sent to duplex connection" + rqNewMsg (smpServer, senderId) nextQUri + | currRcvQueue = case conn of + DuplexConnection _ _ sq@SndQueue {server, sndId} + | smpServer == server && senderId == sndId -> do + clientVRange <- asks $ smpClientVRange . config + case (nextQUri `compatibleVersion` clientVRange) of + Just qInfo@(Compatible nextQInfo) -> do + sq'@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue qInfo False + withStore' c $ \db -> dbCreateNextSndQueue db sq sq' + case (sndPublicKey, e2ePubKey) of + (Just nextSenderKey, Just dhPublicKey) -> do + let qAddr = (queueAddress (nextQInfo :: SMPQueueInfo)) {dhPublicKey} + nextQueueInfo = (nextQInfo :: SMPQueueInfo) {queueAddress = qAddr} + void . enqueueMessage c cData sq SMP.noMsgFlags $ QKEYS {nextSenderKey, nextQueueInfo} + rq' <- withStore' c (`getNextRcvQueue` dbNextRcvQueueId rq) + notify . SWITCH SPStarted $ connectionStats conn rq' (Just sq') + _ -> throwError $ INTERNAL "absent sender keys" + _ -> throwError $ AGENT A_VERSION + | otherwise -> throwError $ INTERNAL "incorrect queue address" + _ -> throwError $ INTERNAL "message can only be sent to duplex connection" + | otherwise = throwError $ INTERNAL "message can only be sent to current queue" -- processed by queue recipient rqKeys :: SndPublicVerifyKey -> SMPQueueInfo -> m () -> m () - rqKeys senderKey qInfo ackDelete = - case conn of + rqKeys senderKey qInfo ackDelete + | currRcvQueue = case conn of DuplexConnection _ _ sq -> do clientVRange <- asks $ smpClientVRange . config case qInfo `proveCompatible` clientVRange of @@ -1339,38 +1336,41 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm setRcvQueueConfirmedE2E db rq' senderKey dhSecret $ min clntVer clntVer' setRcvQueueAction db rq $ Just RQASecureNextQueue ackDelete - secureNextRcvQueue cData sq c rq rq' + secureNextRcvQueue c cData rq sq rq' | otherwise -> throwError $ INTERNAL "incorrect queue address" _ -> throwError $ INTERNAL "message can only be sent during rotation" _ -> throwError $ AGENT A_VERSION _ -> throwError $ INTERNAL "message can only be sent to duplex connection" + | otherwise = throwError $ INTERNAL "message can only be sent to current queue" -- processed by queue sender rqReady :: (SMPServer, SMP.SenderId) -> m () - rqReady (smpServer, senderId) = - case conn of + rqReady (smpServer, senderId) + | currRcvQueue = case conn of DuplexConnection _ _ sq -> withStore' c (`getNextSndQueue` dbNextSndQueueId sq) >>= \case Just sq'@SndQueue {server, sndId} | server == smpServer && sndId == senderId -> - void . enqueueMessage c cData sq' SMP.noMsgFlags $ QHELLO + void $ enqueueMessage c cData sq' SMP.noMsgFlags QHELLO | otherwise -> throwError $ INTERNAL "incorrect queue address" _ -> throwError $ INTERNAL "message can only be sent during rotation" _ -> throwError $ INTERNAL "message can only be sent to duplex connection" + | otherwise = throwError $ INTERNAL "message can only be sent to current queue" -- processed by queue recipient, received from the new queue rqHello :: m () -> m () - rqHello ackDelete = do - -- validate it's the next queue, or send error to the client - -- Enqueue QSWITCH message to the sender - -- snd_switch_action = RQASuspendCurrQueue - -- new queue status = Active - -- currRq <- load current queue - -- - ackDelete - -- - -- suspendCurrRcvQueue currRq rq - pure () + rqHello ackDelete + | currRcvQueue = throwError $ INTERNAL "message can only be sent to the next queue" + | otherwise = case conn of + DuplexConnection _ currRq sq -> do + let RcvQueue {server, sndId} = rq + void . enqueueMessage c cData sq SMP.noMsgFlags $ QSWITCH (server, sndId) + withStore' c $ \db -> do + setRcvQueueStatus db rq Active + setRcvQueueAction db currRq $ Just RQASuspendCurrQueue + ackDelete + suspendCurrRcvQueue c cData currRq sq rq + _ -> throwError $ INTERNAL "message can only be sent to duplex connection" -- processed by queue sender rqSwitch :: (SMPServer, SMP.SenderId) -> m () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 024ee9218..3bb1862fd 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -161,6 +161,7 @@ data AgentClient = AgentClient connMsgsQueued :: TMap ConnId Bool, smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId), smpQueueMsgDeliveries :: TMap (ConnId, SMPServer, SMP.SenderId) (Async ()), + nextRcvQueueMsgs :: TMap (ConnId, SMPServer, SMP.RecipientId) [ServerTransmission BrokerMsg], ntfNetworkOp :: TVar AgentOpState, rcvNetworkOp :: TVar AgentOpState, msgDeliveryOp :: TVar AgentOpState, @@ -212,6 +213,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do connMsgsQueued <- TM.empty smpQueueMsgQueues <- TM.empty smpQueueMsgDeliveries <- TM.empty + nextRcvQueueMsgs <- TM.empty ntfNetworkOp <- newTVar $ AgentOpState False 0 rcvNetworkOp <- newTVar $ AgentOpState False 0 msgDeliveryOp <- newTVar $ AgentOpState False 0 @@ -223,7 +225,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i') lock <- newTMVar () - return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} + return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} agentDbPath :: AgentClient -> FilePath agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 302268eaf..c49b289d2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -43,6 +43,7 @@ module Simplex.Messaging.Agent.Store.SQLite dbCreateNextRcvQueue, dbCreateNextSndQueue, setRcvQueueAction, + switchCurrRcvQueue, -- Confirmations createConfirmation, acceptConfirmation, @@ -280,19 +281,26 @@ createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh -- TODO add queue ID in insertSndQueue_ insertSndQueue_ db connId q -getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError SomeConn) -getRcvConn db ProtocolServer {host, port} rcvId = - DB.queryNamed - db - [sql| - SELECT q.conn_id - FROM rcv_queues q - WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id; - |] - [":host" := host, ":port" := port, ":rcv_id" := rcvId] - >>= \case - [Only connId] -> getConn db connId - _ -> pure $ Left SEConnNotFound +getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) +getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do + (rq, connId) <- + ExceptT . firstRow (\(qRow :. Only connId) -> (toRcvQueue qRow, connId)) SEConnNotFound $ + DB.query + db + [sql| + SELECT q.host, q.port, s.key_hash, + q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.snd_key, q.status, + q.rcv_queue_action, q.rcv_queue_action_ts, q.curr_rcv_queue, q.next_rcv_queue_id, + q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret, + q.smp_client_version, q.created_at, q.updated_at, + q.conn_id + FROM rcv_queues q + INNER JOIN servers s ON q.host = s.host AND q.port = s.port + WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? + |] + (host, port, rcvId) + conn <- ExceptT $ getConn db connId + pure (rq, conn) deleteConn :: DB.Connection -> ConnId -> IO () deleteConn db connId = @@ -313,13 +321,12 @@ upgradeRcvConnToDuplex db connId sq@SndQueue {server} = upgradeSndConnToDuplex :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ()) upgradeSndConnToDuplex db connId rq@RcvQueue {server} = - getConn db connId >>= \case - Right (SomeConn _ SndConnection {}) -> do + getConn db connId $>>= \case + SomeConn _ SndConnection {} -> do upsertServer_ db server insertRcvQueue_ db connId rq pure $ Right () - Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c - _ -> pure $ Left SEConnNotFound + SomeConn c _ -> pure . Left . SEBadConnType $ connType c setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status = @@ -419,6 +426,11 @@ dbCreateNextSndQueue _db _sq _nextSq = do setRcvQueueAction :: DB.Connection -> RcvQueue -> Maybe RcvQueueAction -> IO () setRcvQueueAction _db _rq _rqAction_ = pure () +switchCurrRcvQueue :: DB.Connection -> RcvQueue -> RcvQueue -> IO () +switchCurrRcvQueue _db _rq _nextRq = do + -- make a new queue a main one + pure () + type SMPConfirmationRow = (SndPublicVerifyKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version) smpConfirmation :: SMPConfirmationRow -> SMPConfirmation diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 4a6566482..eb4691522 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -267,7 +267,7 @@ testGetRcvConn = g <- newTVarIO =<< drgNew _ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation getRcvConn db smpServer recipientId - `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1)) + `shouldReturn` Right (rcvQueue1, SomeConn SCRcv (RcvConnection cData1 rcvQueue1)) testDeleteRcvConn :: SpecWith SQLiteStore testDeleteRcvConn =