From 058e3ac55e8577280267f9341ccd7d3e971bc51a Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Wed, 4 Jan 2023 14:10:13 +0000 Subject: [PATCH] send/process "quota exceeded" message from SMP server when sender gets ERR QUOTA (#585) * send "quota exceeded" message from SMP server when sender gets ERR QUOTA (ignored in the agent for now) * send msg quota to the recipient to indicate that sender got ERR QUOTA, test * switch between slow/fast retry intervals (tests do not pass yet) * send QCONT message, refactor RetryInterval, test * refactor * remove comment * remove space * unit test for withRetryLock2 * refactor --- simplexmq.cabal | 1 + src/Simplex/Messaging/Agent.hs | 206 ++++++++++--------- src/Simplex/Messaging/Agent/Client.hs | 2 +- src/Simplex/Messaging/Agent/Env/SQLite.hs | 24 ++- src/Simplex/Messaging/Agent/Protocol.hs | 11 + src/Simplex/Messaging/Agent/RetryInterval.hs | 65 +++++- src/Simplex/Messaging/Protocol.hs | 106 ++++++---- src/Simplex/Messaging/Server.hs | 59 +++--- src/Simplex/Messaging/Server/Env/STM.hs | 2 +- src/Simplex/Messaging/Server/MsgStore.hs | 5 +- src/Simplex/Messaging/Server/MsgStore/STM.hs | 88 +++++--- tests/AgentTests.hs | 28 +++ tests/CoreTests/RetryIntervalTests.hs | 61 ++++++ tests/ServerTests.hs | 35 +++- tests/Test.hs | 2 + 15 files changed, 490 insertions(+), 205 deletions(-) create mode 100644 tests/CoreTests/RetryIntervalTests.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index d8f11dfab..e01dd0a15 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -357,6 +357,7 @@ test-suite smp-server-test CoreTests.CryptoTests CoreTests.EncodingTests CoreTests.ProtocolErrorTests + CoreTests.RetryIntervalTests CoreTests.VersionRangeTests NtfClient NtfServerTests diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 763c51522..e5c964b0d 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -791,7 +791,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do atomically $ beginAgentOperation c AOSndNetwork E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e) - Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd + Right (corrId, connId, cmd) -> processCmd (riFast ri) corrId connId cmdId cmd where processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m () processCmd ri corrId connId cmdId command = case command of @@ -964,22 +964,22 @@ queuePendingMsgs c sq msgIds = atomically $ do modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds} -- s <- readTVar (msgDeliveryOp c) -- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s) - q <- getPendingMsgQ c sq - mapM_ (writeTQueue q) msgIds + (mq, _) <- getPendingMsgQ c sq + mapM_ (writeTQueue mq) msgIds -getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId) +getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId, TMVar ()) getPendingMsgQ c SndQueue {server, sndId} = do let qKey = (server, sndId) maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c) where newMsgQueue qKey = do - mq <- newTQueue - TM.insert qKey mq $ smpQueueMsgQueues c - pure mq + q <- (,) <$> newTQueue <*> newEmptyTMVar + TM.insert qKey q $ smpQueueMsgQueues c + pure q runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do - mq <- atomically $ getPendingMsgQ c sq + (mq, qLock) <- atomically $ getPendingMsgQ c sq ri <- asks $ messageRetryInterval . config forever $ do atomically $ endAgentOperation c AOSndNetwork @@ -993,7 +993,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh Left (e :: E.SomeException) -> notify $ MERR mId (INTERNAL $ show e) Right (rq_, PendingMsgData {msgType, msgBody, msgFlags, internalTs}) -> - withRetryInterval ri $ \loop -> do + withRetryLock2 ri qLock $ \loop -> do resp <- tryError $ case msgType of AM_CONN_INFO -> sendConfirmation c sq msgBody _ -> sendAgentMessage c sq msgFlags msgBody @@ -1004,7 +1004,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh SMP SMP.QUOTA -> case msgType of AM_CONN_INFO -> connError msgId NOT_AVAILABLE AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE - _ -> retrySndOp c loop + _ -> retrySndOp c $ loop RISlow SMP SMP.AUTH -> case msgType of AM_CONN_INFO -> connError msgId NOT_AVAILABLE AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE @@ -1013,7 +1013,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh -- because the queue must be secured by the time the confirmation or the first HELLO is received | duplexHandshake == Just True -> connErr | otherwise -> - ifM (msgExpired helloTimeout) connErr (retrySndOp c loop) + ifM (msgExpired helloTimeout) connErr (retrySndOp c $ loop RIFast) where connErr = case rq_ of -- party initiating connection @@ -1022,6 +1022,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh _ -> connError msgId NOT_ACCEPTED AM_REPLY_ -> notifyDel msgId err AM_A_MSG_ -> notifyDel msgId err + AM_QCONT_ -> notifyDel msgId err AM_QADD_ -> qError msgId "QADD: AUTH" AM_QKEY_ -> qError msgId "QKEY: AUTH" AM_QUSE_ -> qError msgId "QUSE: AUTH" @@ -1031,7 +1032,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh -- the message sending would be retried | temporaryOrHostError e -> do let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout - ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c loop) + ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c $ loop RIFast) | otherwise -> notifyDel msgId err where msgExpired timeoutSel = do @@ -1071,6 +1072,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh qInfo <- createReplyQueue c cData sq srv void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo] AM_A_MSG_ -> notify $ SENT mId + AM_QCONT_ -> pure () AM_QADD_ -> pure () AM_QKEY_ -> pure () AM_QUSE_ -> pure () @@ -1492,88 +1494,96 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm processSMP :: RcvQueue -> Connection c -> ConnData -> m () processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $ case cmd of - SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do - SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} <- decryptSMPMessage v rq msg - clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <- - parseMessage msgBody - clientVRange <- asks $ smpClientVRange . config - unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION - case (e2eDhSecret, e2ePubKey_) of - (Nothing, Just e2ePubKey) -> do - let e2eDh = C.dh' e2ePubKey e2ePrivKey - decryptClientMessage e2eDh clientMsg >>= \case - (SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) -> - smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack - (SMP.PHEmpty, AgentInvitation {connReq, connInfo}) -> - smpInvitation connReq connInfo >> ack - _ -> prohibited >> ack - (Just e2eDh, Nothing) -> do - decryptClientMessage e2eDh clientMsg >>= \case - (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do - -- primary queue is set as Active in helloMsg, below is to set additional queues Active - let RcvQueue {primary, dbReplaceQueueId} = rq - unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active - case (conn, dbReplaceQueueId) of - (DuplexConnection _ rqs _, Just replacedId) -> do - when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq - case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of - Just RcvQueue {server, rcvId} -> do - enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId - _ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection" - _ -> pure () - tryError agentClientMsg >>= \case - Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of - HELLO -> helloMsg >> ackDel msgId - REPLY cReq -> replyMsg cReq >> ackDel msgId - -- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command - A_MSG body -> do - logServer "<--" c srv rId "MSG " - notify $ MSG msgMeta msgFlags body - QADD qs -> qDuplex "QADD" $ qAddMsg qs - QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs - QUSE qs -> qDuplex "QUSE" $ qUseMsg qs - -- no action needed for QTEST - -- any message in the new queue will mark it active and trigger deletion of the old queue - QTEST _ -> logServer "<--" c srv rId "MSG " >> ackDel msgId - where - qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m () - qDuplex name a = case conn of - DuplexConnection {} -> a conn >> ackDel msgId - _ -> qError $ name <> ": message must be sent to duplex connection" - Right _ -> prohibited >> ack - Left e@(AGENT A_DUPLICATE) -> do - withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case - Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} - | userAck -> ackDel internalId - | otherwise -> do - liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case - AgentMessage _ (A_MSG body) -> do - logServer "<--" c srv rId "MSG " - notify $ MSG msgMeta msgFlags body - _ -> pure () - _ -> throwError e - Left e -> throwError e - where - agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage)) - agentClientMsg = withStore c $ \db -> runExceptT $ do - agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg - liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case - agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do - let msgType = agentMessageType agentMsg - internalHash = C.sha256Hash agentMsgBody - internalTs <- liftIO getCurrentTime - (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId - let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash - recipient = (unId internalId, internalTs) - broker = (srvMsgId, systemToUTCTime srvTs) - msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId} - rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash} - liftIO $ createRcvMsg db connId rq rcvMsg - pure $ Just (internalId, msgMeta, aMessage) - _ -> pure Nothing - _ -> prohibited >> ack - _ -> prohibited >> ack + SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> + handleNotifyAck $ + decryptSMPMessage v rq msg >>= \case + SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody + SMP.ClientRcvMsgQuota {} -> queueDrained >> ack where + queueDrained = case conn of + DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq) + _ -> pure () + processClientMsg srvTs msgFlags msgBody = do + clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <- + parseMessage msgBody + clientVRange <- asks $ smpClientVRange . config + unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION + case (e2eDhSecret, e2ePubKey_) of + (Nothing, Just e2ePubKey) -> do + let e2eDh = C.dh' e2ePubKey e2ePrivKey + decryptClientMessage e2eDh clientMsg >>= \case + (SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) -> + smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack + (SMP.PHEmpty, AgentInvitation {connReq, connInfo}) -> + smpInvitation connReq connInfo >> ack + _ -> prohibited >> ack + (Just e2eDh, Nothing) -> do + decryptClientMessage e2eDh clientMsg >>= \case + (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do + -- primary queue is set as Active in helloMsg, below is to set additional queues Active + let RcvQueue {primary, dbReplaceQueueId} = rq + unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active + case (conn, dbReplaceQueueId) of + (DuplexConnection _ rqs _, Just replacedId) -> do + when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq + case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of + Just RcvQueue {server, rcvId} -> do + enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId + _ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection" + _ -> pure () + tryError agentClientMsg >>= \case + Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of + HELLO -> helloMsg >> ackDel msgId + REPLY cReq -> replyMsg cReq >> ackDel msgId + -- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command + A_MSG body -> do + logServer "<--" c srv rId "MSG " + notify $ MSG msgMeta msgFlags body + QCONT addr -> qDuplex "QCONT" $ continueSending addr + QADD qs -> qDuplex "QADD" $ qAddMsg qs + QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs + QUSE qs -> qDuplex "QUSE" $ qUseMsg qs + -- no action needed for QTEST + -- any message in the new queue will mark it active and trigger deletion of the old queue + QTEST _ -> logServer "<--" c srv rId "MSG " >> ackDel msgId + where + qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m () + qDuplex name a = case conn of + DuplexConnection {} -> a conn >> ackDel msgId + _ -> qError $ name <> ": message must be sent to duplex connection" + Right _ -> prohibited >> ack + Left e@(AGENT A_DUPLICATE) -> do + withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case + Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} + | userAck -> ackDel internalId + | otherwise -> do + liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case + AgentMessage _ (A_MSG body) -> do + logServer "<--" c srv rId "MSG " + notify $ MSG msgMeta msgFlags body + _ -> pure () + _ -> throwError e + Left e -> throwError e + where + agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage)) + agentClientMsg = withStore c $ \db -> runExceptT $ do + agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg + liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case + agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do + let msgType = agentMessageType agentMsg + internalHash = C.sha256Hash agentMsgBody + internalTs <- liftIO getCurrentTime + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId + let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash + recipient = (unId internalId, internalTs) + broker = (srvMsgId, systemToUTCTime srvTs) + msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId} + rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash} + liftIO $ createRcvMsg db connId rq rcvMsg + pure $ Just (internalId, msgMeta, aMessage) + _ -> pure Nothing + _ -> prohibited >> ack + _ -> prohibited >> ack ack :: m () ack = enqueueCmd $ ICAck rId srvMsgId ackDel :: InternalId -> m () @@ -1698,6 +1708,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR) _ -> prohibited + continueSending :: (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m () + continueSending addr (DuplexConnection _ _ sqs) = + case findQ addr sqs of + Just sq -> do + logServer "<--" c srv rId "MSG " + atomically $ do + (_, qLock) <- getPendingMsgQ c sq + void $ tryPutTMVar qLock () + Nothing -> qError "QCONT: queue address not found" + -- processed by queue sender qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m () qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported" diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 700459cc2..e47509be3 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -180,7 +180,7 @@ data AgentClient = AgentClient activeSubs :: TRcvQueues, pendingSubs :: TRcvQueues, pendingMsgsQueued :: TMap SndQAddr Bool, - smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId), + smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId, TMVar ()), smpQueueMsgDeliveries :: TMap SndQAddr (Async ()), connCmdsQueued :: TMap ConnId Bool, asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId), diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index ce754e5e9..920a9b4dd 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -83,7 +83,7 @@ data AgentConfig = AgentConfig smpCfg :: ProtocolClientConfig, ntfCfg :: ProtocolClientConfig, reconnectInterval :: RetryInterval, - messageRetryInterval :: RetryInterval, + messageRetryInterval :: RetryInterval2, messageTimeout :: NominalDiffTime, helloTimeout :: NominalDiffTime, ntfCron :: Word16, @@ -108,12 +108,24 @@ defaultReconnectInterval = maxInterval = 180_000000 } -defaultMessageRetryInterval :: RetryInterval +defaultMessageRetryInterval :: RetryInterval2 defaultMessageRetryInterval = - RetryInterval - { initialInterval = 1_000000, - increaseAfter = 10_000000, - maxInterval = 60_000000 + RetryInterval2 + { riFast = + RetryInterval + { initialInterval = 1_000000, + increaseAfter = 10_000000, + maxInterval = 60_000000 + }, + riSlow = + -- TODO: these timeouts can be increased once most clients are updates + -- to resume sending on QCONT messages. + -- After that local message expiration period should be also increased. + RetryInterval + { initialInterval = 10_000000, + increaseAfter = 30_000000, + maxInterval = 300_000000 + } } defaultAgentConfig :: AgentConfig diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 9d2b33a70..783cde141 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -577,6 +577,7 @@ data AgentMessageType | AM_HELLO_ | AM_REPLY_ | AM_A_MSG_ + | AM_QCONT_ | AM_QADD_ | AM_QKEY_ | AM_QUSE_ @@ -590,6 +591,7 @@ instance Encoding AgentMessageType where AM_HELLO_ -> "H" AM_REPLY_ -> "R" AM_A_MSG_ -> "M" + AM_QCONT_ -> "QC" AM_QADD_ -> "QA" AM_QKEY_ -> "QK" AM_QUSE_ -> "QU" @@ -603,6 +605,7 @@ instance Encoding AgentMessageType where 'M' -> pure AM_A_MSG_ 'Q' -> A.anyChar >>= \case + 'C' -> pure AM_QCONT_ 'A' -> pure AM_QADD_ 'K' -> pure AM_QKEY_ 'U' -> pure AM_QUSE_ @@ -623,6 +626,7 @@ agentMessageType = \case -- REPLY is only used in v1 REPLY _ -> AM_REPLY_ A_MSG _ -> AM_A_MSG_ + QCONT _ -> AM_QCONT_ QADD _ -> AM_QADD_ QKEY _ -> AM_QKEY_ QUSE _ -> AM_QUSE_ @@ -645,6 +649,7 @@ data AMsgType = HELLO_ | REPLY_ | A_MSG_ + | QCONT_ | QADD_ | QKEY_ | QUSE_ @@ -656,6 +661,7 @@ instance Encoding AMsgType where HELLO_ -> "H" REPLY_ -> "R" A_MSG_ -> "M" + QCONT_ -> "QC" QADD_ -> "QA" QKEY_ -> "QK" QUSE_ -> "QU" @@ -667,6 +673,7 @@ instance Encoding AMsgType where 'M' -> pure A_MSG_ 'Q' -> A.anyChar >>= \case + 'C' -> pure QCONT_ 'A' -> pure QADD_ 'K' -> pure QKEY_ 'U' -> pure QUSE_ @@ -684,6 +691,8 @@ data AMessage REPLY (L.NonEmpty SMPQueueInfo) | -- | agent envelope for the client message A_MSG MsgBody + | -- | the message instructing the client to continue sending messages (after ERR QUOTA) + QCONT SndQAddr | -- add queue to connection (sent by recipient), with optional address of the replaced queue QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr)) | -- key to secure the added queues and agree e2e encryption key (sent by sender) @@ -701,6 +710,7 @@ instance Encoding AMessage where HELLO -> smpEncode HELLO_ REPLY smpQueues -> smpEncode (REPLY_, smpQueues) A_MSG body -> smpEncode (A_MSG_, Tail body) + QCONT addr -> smpEncode (QCONT_, addr) QADD qs -> smpEncode (QADD_, qs) QKEY qs -> smpEncode (QKEY_, qs) QUSE qs -> smpEncode (QUSE_, qs) @@ -711,6 +721,7 @@ instance Encoding AMessage where HELLO_ -> pure HELLO REPLY_ -> REPLY <$> smpP A_MSG_ -> A_MSG . unTail <$> smpP + QCONT_ -> QCONT <$> smpP QADD_ -> QADD <$> smpP QKEY_ -> QKEY <$> smpP QUSE_ -> QUSE <$> smpP diff --git a/src/Simplex/Messaging/Agent/RetryInterval.hs b/src/Simplex/Messaging/Agent/RetryInterval.hs index 048b9e09c..3d5cfcbae 100644 --- a/src/Simplex/Messaging/Agent/RetryInterval.hs +++ b/src/Simplex/Messaging/Agent/RetryInterval.hs @@ -1,10 +1,21 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} -module Simplex.Messaging.Agent.RetryInterval where +module Simplex.Messaging.Agent.RetryInterval + ( RetryInterval (..), + RetryInterval2 (..), + RetryIntervalMode (..), + withRetryInterval, + withRetryLock2, + ) +where -import Control.Concurrent (threadDelay) +import Control.Concurrent (forkIO, threadDelay) +import Control.Monad (void) import Control.Monad.IO.Class (MonadIO, liftIO) +import Simplex.Messaging.Util (whenM) +import UnliftIO.STM data RetryInterval = RetryInterval { initialInterval :: Int, @@ -12,17 +23,51 @@ data RetryInterval = RetryInterval maxInterval :: Int } +data RetryInterval2 = RetryInterval2 + { riSlow :: RetryInterval, + riFast :: RetryInterval + } + +data RetryIntervalMode = RISlow | RIFast + deriving (Eq) + withRetryInterval :: forall m. MonadIO m => RetryInterval -> (m () -> m ()) -> m () -withRetryInterval RetryInterval {initialInterval, increaseAfter, maxInterval} action = - callAction 0 initialInterval +withRetryInterval ri action = callAction 0 $ initialInterval ri where callAction :: Int -> Int -> m () - callAction elapsedTime delay = action loop + callAction elapsed delay = action loop where loop = do - let newDelay = - if elapsedTime < increaseAfter || delay == maxInterval - then delay - else min (delay * 3 `div` 2) maxInterval liftIO $ threadDelay delay - callAction (elapsedTime + delay) newDelay + let elapsed' = elapsed + delay + callAction elapsed' $ nextDelay elapsed' delay ri + +-- This function allows action to toggle between slow and fast retry intervals. +withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> ((RetryIntervalMode -> m ()) -> m ()) -> m () +withRetryLock2 RetryInterval2 {riSlow, riFast} lock action = + callAction (0, initialInterval riSlow) (0, initialInterval riFast) + where + callAction :: (Int, Int) -> (Int, Int) -> m () + callAction slow fast = action loop + where + loop = \case + RISlow -> run slow riSlow (`callAction` fast) + RIFast -> run fast riFast (callAction slow) + run (elapsed, delay) ri call = do + wait delay + let elapsed' = elapsed + delay + call (elapsed', nextDelay elapsed' delay ri) + wait delay = do + waiting <- newTVarIO True + _ <- liftIO . forkIO $ do + threadDelay delay + atomically $ whenM (readTVar waiting) $ void $ tryPutTMVar lock () + atomically $ do + takeTMVar lock + writeTVar waiting False + +nextDelay :: Int -> Int -> RetryInterval -> Int +nextDelay elapsed delay RetryInterval {increaseAfter, maxInterval} = + if elapsed < increaseAfter || delay == maxInterval + then delay + else min (delay * 3 `div` 2) maxInterval diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 11b586f4b..e28ac1b6c 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -141,6 +141,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isPrint, isSpace) +import Data.Functor (($>)) import Data.Kind import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L @@ -303,12 +304,17 @@ data RcvMessage = RcvMessage deriving (Eq, Show) -- | received message without server/recipient encryption -data Message = Message - { msgId :: MsgId, - msgTs :: SystemTime, - msgFlags :: MsgFlags, - msgBody :: C.MaxLenBS MaxMessageLen - } +data Message + = Message + { msgId :: MsgId, + msgTs :: SystemTime, + msgFlags :: MsgFlags, + msgBody :: C.MaxLenBS MaxMessageLen + } + | MessageQuota + { msgId :: MsgId, + msgTs :: SystemTime + } instance StrEncoding RcvMessage where strEncode RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = @@ -328,44 +334,72 @@ instance StrEncoding RcvMessage where newtype EncRcvMsgBody = EncRcvMsgBody ByteString deriving (Eq, Show) -data RcvMsgBody = RcvMsgBody - { msgTs :: SystemTime, - msgFlags :: MsgFlags, - msgBody :: C.MaxLenBS MaxMessageLen - } +data RcvMsgBody + = RcvMsgBody + { msgTs :: SystemTime, + msgFlags :: MsgFlags, + msgBody :: C.MaxLenBS MaxMessageLen + } + | RcvMsgQuota + { msgTs :: SystemTime + } + +msgQuotaTag :: ByteString +msgQuotaTag = "QUOTA" encodeRcvMsgBody :: RcvMsgBody -> C.MaxLenBS MaxRcvMessageLen -encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} = - let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ') - in C.appendMaxLenBS rcvMeta msgBody +encodeRcvMsgBody = \case + RcvMsgBody {msgTs, msgFlags, msgBody} -> + let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ') + in C.appendMaxLenBS rcvMeta msgBody + RcvMsgQuota {msgTs} -> + C.unsafeMaxLenBS $ msgQuotaTag <> " " <> smpEncode msgTs -data ClientRcvMsgBody = ClientRcvMsgBody - { msgTs :: SystemTime, - msgFlags :: MsgFlags, - msgBody :: ByteString - } +data ClientRcvMsgBody + = ClientRcvMsgBody + { msgTs :: SystemTime, + msgFlags :: MsgFlags, + msgBody :: ByteString + } + | ClientRcvMsgQuota + { msgTs :: SystemTime + } clientRcvMsgBodyP :: Parser ClientRcvMsgBody -clientRcvMsgBodyP = do - msgTs <- smpP - msgFlags <- smpP - Tail msgBody <- _smpP - pure ClientRcvMsgBody {msgTs, msgFlags, msgBody} +clientRcvMsgBodyP = msgQuotaP <|> msgBodyP + where + msgQuotaP = A.string msgQuotaTag *> (ClientRcvMsgQuota <$> _smpP) + msgBodyP = do + msgTs <- smpP + msgFlags <- smpP + Tail msgBody <- _smpP + pure ClientRcvMsgBody {msgTs, msgFlags, msgBody} instance StrEncoding Message where - strEncode Message {msgId, msgTs, msgFlags, msgBody} = - B.unwords - [ strEncode msgId, - strEncode msgTs, - "flags=" <> strEncode msgFlags, - strEncode msgBody - ] + strEncode = \case + Message {msgId, msgTs, msgFlags, msgBody} -> + B.unwords + [ strEncode msgId, + strEncode msgTs, + "flags=" <> strEncode msgFlags, + strEncode msgBody + ] + MessageQuota {msgId, msgTs} -> + B.unwords + [ strEncode msgId, + strEncode msgTs, + "quota" + ] strP = do msgId <- strP_ msgTs <- strP_ - msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags - msgBody <- strP - pure Message {msgId, msgTs, msgFlags, msgBody} + msgQuotaP msgId msgTs <|> msgP msgId msgTs + where + msgQuotaP msgId msgTs = "quota" $> MessageQuota {msgId, msgTs} + msgP msgId msgTs = do + msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags + msgBody <- strP + pure Message {msgId, msgTs, msgFlags, msgBody} type EncNMsgMeta = ByteString @@ -377,7 +411,9 @@ data SMPMsgMeta = SMPMsgMeta deriving (Show) rcvMessageMeta :: MsgId -> ClientRcvMsgBody -> SMPMsgMeta -rcvMessageMeta msgId ClientRcvMsgBody {msgTs, msgFlags} = SMPMsgMeta {msgId, msgTs, msgFlags} +rcvMessageMeta msgId = \case + ClientRcvMsgBody {msgTs, msgFlags} -> SMPMsgMeta {msgId, msgTs, msgFlags} + ClientRcvMsgQuota {msgTs} -> SMPMsgMeta {msgId, msgTs, msgFlags = noMsgFlags} data NMsgMeta = NMsgMeta { msgId :: MsgId, diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 9fdb99599..1bf3d46df 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -538,20 +538,22 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv | otherwise = case status qr of QueueOff -> return $ err AUTH QueueActive -> - mapM mkMessage (C.maxLenBS msgBody) >>= \case + case C.maxLenBS msgBody of Left _ -> pure $ err LARGE_MSG - Right msg -> do - resp@(_, _, sent) <- time "SEND" $ do + Right body -> do + msg_ <- time "SEND" $ do q <- getStoreMsgQueue "SEND" $ recipientId qr expireMessages q - atomically $ ifM (isFull q) (pure $ err QUOTA) (writeMsg q msg $> ok) - when (sent == OK) . time "SEND ok" $ do - when (notification msgFlags) $ - atomically . trySendNotification msg =<< asks idsDrg - stats <- asks serverStats - atomically $ modifyTVar (msgSent stats) (+ 1) - atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) - pure resp + atomically . writeMsg q =<< mkMessage body + case msg_ of + Nothing -> pure $ err QUOTA + Just msg -> time "SEND ok" $ do + when (notification msgFlags) $ + atomically . trySendNotification msg =<< asks idsDrg + stats <- asks serverStats + atomically $ modifyTVar (msgSent stats) (+ 1) + atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) + pure ok where mkMessage :: C.MaxLenBS MaxMessageLen -> m Message mkMessage body = do @@ -572,12 +574,14 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM () writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} = - unlessM (isFullTBQueue q) $ do - (nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg - writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)] + unlessM (isFullTBQueue q) $ case msg of + Message {msgId, msgTs} -> do + (nmsgNonce, encNMsgMeta) <- mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg + writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)] + _ -> pure () - mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta) - mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do + mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta) + mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg = do cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg let msgMeta = NMsgMeta {msgId, msgTs} encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128 @@ -615,17 +619,22 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv time name = timed name queueId encryptMsg :: QueueRec -> Message -> RcvMessage - encryptMsg qr Message {msgId, msgTs, msgFlags, msgBody} - | thVersion == 1 || thVersion == 2 = encrypt msgBody - | otherwise = encrypt $ encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} + encryptMsg qr msg = case msg of + Message {msgFlags, msgBody} + | thVersion == 1 || thVersion == 2 -> encrypt msgFlags msgBody + | otherwise -> encrypt msgFlags $ encodeRcvMsgBody RcvMsgBody {msgTs = msgTs', msgFlags, msgBody} + MessageQuota {} -> + encrypt noMsgFlags $ encodeRcvMsgBody (RcvMsgQuota msgTs') where - encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage - encrypt body = - let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId) body - in RcvMessage msgId msgTs msgFlags encBody + encrypt :: KnownNat i => MsgFlags -> C.MaxLenBS i -> RcvMessage + encrypt msgFlags body = + let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body + in RcvMessage msgId' msgTs' msgFlags encBody + msgId' = msgId (msg :: Message) + msgTs' = msgTs (msg :: Message) setDelivered :: Sub -> Message -> STM Bool - setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId + setDelivered s msg = tryPutTMVar (delivered s) $ msgId (msg :: Message) getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do @@ -717,7 +726,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages addToMsgQueue rId msg = do full <- atomically $ do q <- getMsgQueue ms rId quota - ifM (isFull q) (pure True) (writeMsg q msg $> False) + isNothing <$> writeMsg q msg when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message)) updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do let nonce = C.cbNonce msgId diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 61458fdc8..644136571 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -39,7 +39,7 @@ data ServerConfig = ServerConfig { transports :: [(ServiceName, ATransport)], tbqSize :: Natural, serverTbqSize :: Natural, - msgQueueQuota :: Natural, + msgQueueQuota :: Int, queueIdBytes :: Int, msgIdBytes :: Int, storeLogFile :: Maybe FilePath, diff --git a/src/Simplex/Messaging/Server/MsgStore.hs b/src/Simplex/Messaging/Server/MsgStore.hs index 476a03f3f..565c89e1e 100644 --- a/src/Simplex/Messaging/Server/MsgStore.hs +++ b/src/Simplex/Messaging/Server/MsgStore.hs @@ -19,13 +19,12 @@ instance StrEncoding MsgLogRecord where strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP class MonadMsgStore s q m | s -> q where - getMsgQueue :: s -> RecipientId -> Natural -> m q + getMsgQueue :: s -> RecipientId -> Int -> m q delMsgQueue :: s -> RecipientId -> m () flushMsgQueue :: s -> RecipientId -> m [Message] class MonadMsgQueue q m where - isFull :: q -> m Bool - writeMsg :: q -> Message -> m () -- non blocking + writeMsg :: q -> Message -> m (Maybe Message) -- non blocking tryPeekMsg :: q -> m (Maybe Message) -- non blocking peekMsg :: q -> m Message -- blocking tryDelMsg :: q -> MsgId -> m Bool -- non blocking diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 5905a1789..4e27e599f 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -7,22 +7,31 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TupleSections #-} -module Simplex.Messaging.Server.MsgStore.STM where +module Simplex.Messaging.Server.MsgStore.STM + ( STMMsgStore, + MsgQueue, + newMsgStore, + ) +where -import Control.Concurrent.STM.TBQueue (flushTBQueue) +import Control.Concurrent.STM.TQueue (flushTQueue) import Control.Monad (when) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) import Data.Time.Clock.System (SystemTime (systemSeconds)) -import Numeric.Natural import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId) import Simplex.Messaging.Server.MsgStore import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import UnliftIO.STM -newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message} +data MsgQueue = MsgQueue + { msgQueue :: TQueue Message, + quota :: Int, + canWrite :: TVar Bool, + size :: TVar Int + } type STMMsgStore = TMap RecipientId MsgQueue @@ -30,54 +39,77 @@ newMsgStore :: STM STMMsgStore newMsgStore = TM.empty instance MonadMsgStore STMMsgStore MsgQueue STM where - getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue + getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st where newQ = do - q <- MsgQueue <$> newTBQueue quota + msgQueue <- newTQueue + canWrite <- newTVar True + size <- newTVar 0 + let q = MsgQueue {msgQueue, quota, canWrite, size} TM.insert rId q st - return q + pure q delMsgQueue :: STMMsgStore -> RecipientId -> STM () delMsgQueue st rId = TM.delete rId st flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] - flushMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (flushTBQueue . msgQueue) + flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue) instance MonadMsgQueue MsgQueue STM where - isFull :: MsgQueue -> STM Bool - isFull = isFullTBQueue . msgQueue - - writeMsg :: MsgQueue -> Message -> STM () - writeMsg = writeTBQueue . msgQueue + writeMsg :: MsgQueue -> Message -> STM (Maybe Message) + writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do + canWrt <- readTVar canWrite + empty <- isEmptyTQueue q + if canWrt || empty + then do + canWrt' <- (quota >) <$> readTVar size + writeTVar canWrite canWrt' + modifyTVar' size (+ 1) + if canWrt' + then writeTQueue q msg $> Just msg + else writeTQueue q msgQuota $> Nothing + else pure Nothing + where + msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg} tryPeekMsg :: MsgQueue -> STM (Maybe Message) - tryPeekMsg = tryPeekTBQueue . msgQueue + tryPeekMsg = tryPeekTQueue . msgQueue + {-# INLINE tryPeekMsg #-} peekMsg :: MsgQueue -> STM Message - peekMsg = peekTBQueue . msgQueue + peekMsg = peekTQueue . msgQueue + {-# INLINE peekMsg #-} tryDelMsg :: MsgQueue -> MsgId -> STM Bool - tryDelMsg (MsgQueue q) msgId' = - tryPeekTBQueue q >>= \case - Just Message {msgId} - | msgId == msgId' || B.null msgId' -> tryReadTBQueue q $> True + tryDelMsg mq msgId' = + tryPeekMsg mq >>= \case + Just msg + | msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True | otherwise -> pure False _ -> pure False -- atomic delete (== read) last and peek next message if available tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message) - tryDelPeekMsg (MsgQueue q) msgId' = - tryPeekTBQueue q >>= \case - msg_@(Just Message {msgId}) - | msgId == msgId' || B.null msgId' -> (True,) <$> (tryReadTBQueue q >> tryPeekTBQueue q) + tryDelPeekMsg mq msgId' = + tryPeekMsg mq >>= \case + msg_@(Just msg) + | msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq) | otherwise -> pure (False, msg_) _ -> pure (False, Nothing) deleteExpiredMsgs :: MsgQueue -> Int64 -> STM () - deleteExpiredMsgs (MsgQueue q) old = loop + deleteExpiredMsgs mq old = loop where - loop = tryPeekTBQueue q >>= mapM_ delOldMsg - delOldMsg Message {msgTs} = - when (systemSeconds msgTs < old) $ - tryReadTBQueue q >> loop + loop = tryPeekMsg mq >>= mapM_ delOldMsg + delOldMsg = \case + Message {msgTs} -> + when (systemSeconds msgTs < old) $ + tryDeleteMsg mq >> loop + _ -> pure () + +tryDeleteMsg :: MsgQueue -> STM () +tryDeleteMsg MsgQueue {msgQueue = q, size} = + tryReadTQueue q >>= \case + Just _ -> modifyTVar' size (subtract 1) + _ -> pure () diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 1568eda57..bdc3a28d0 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -78,6 +78,8 @@ agentTests (ATransport t) = do smpAgentTest2_2_1 $ testConcurrentMsgDelivery t it "should deliver messages if one of connections has quota exceeded" $ smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t + it "should resume delivering messages after exceeding quota once all messages are received" $ + smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent) tGetAgent h = do @@ -430,6 +432,32 @@ testMsgDeliveryQuotaExceeded _ alice bob = do -- if delivery is blocked it won't go further alice <# ("", "bob2", SENT 4) +testResumeDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO () +testResumeDeliveryQuotaExceeded _ alice bob = do + connect (alice, "alice") (bob, "bob") + forM_ [1 .. 4 :: Int] $ \i -> do + let corrId = bshow i + msg = "message " <> bshow i + (_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg) + alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False + ("5", "bob", Right (MID 8)) <- alice #: ("5", "bob", "SEND F :over quota") + alice #:# "the last message not sent yet" + bob <#= \case ("", "alice", Msg "message 1") -> True; _ -> False + bob #: ("1", "alice", "ACK 4") #> ("1", "alice", OK) + alice #:# "the last message not sent" + bob <#= \case ("", "alice", Msg "message 2") -> True; _ -> False + bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK) + alice #:# "the last message not sent" + bob <#= \case ("", "alice", Msg "message 3") -> True; _ -> False + bob #: ("3", "alice", "ACK 6") #> ("3", "alice", OK) + alice #:# "the last message not sent" + bob <#= \case ("", "alice", Msg "message 4") -> True; _ -> False + bob #: ("4", "alice", "ACK 7") #> ("4", "alice", OK) + alice <# ("", "bob", SENT 8) + bob <#= \case ("", "alice", Msg "over quota") -> True; _ -> False + -- message 8 is skipped because of alice agent sending "QCONT" message + bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK) + connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO () connect (h1, name1) (h2, name2) = do ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV") diff --git a/tests/CoreTests/RetryIntervalTests.hs b/tests/CoreTests/RetryIntervalTests.hs new file mode 100644 index 000000000..5495e2a3a --- /dev/null +++ b/tests/CoreTests/RetryIntervalTests.hs @@ -0,0 +1,61 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +module CoreTests.RetryIntervalTests where + +import Control.Concurrent.STM +import Control.Monad (when) +import Data.Time.Clock (UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds) +import Simplex.Messaging.Agent.RetryInterval +import Test.Hspec + +retryIntervalTests :: Spec +retryIntervalTests = do + describe "Retry interval with 2 modes and lock" $ do + testRetryIntervalSameMode + testRetryIntervalSwitchMode + +testRI :: RetryInterval2 +testRI = + RetryInterval2 + { riSlow = + RetryInterval + { initialInterval = 20000, + increaseAfter = 40000, + maxInterval = 40000 + }, + riFast = + RetryInterval + { initialInterval = 10000, + increaseAfter = 20000, + maxInterval = 40000 + } + } + +testRetryIntervalSameMode :: Spec +testRetryIntervalSameMode = + it "should increase elapased time and interval when the mode stays the same" $ do + lock <- newEmptyTMVarIO + intervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + withRetryLock2 testRI lock $ \loop -> do + ints <- addInterval intervals ts + when (length ints < 9) $ loop RIFast + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4, 4] + +testRetryIntervalSwitchMode :: Spec +testRetryIntervalSwitchMode = + it "should increase elapased time and interval when the mode stays the same" $ do + lock <- newEmptyTMVarIO + intervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + withRetryLock2 testRI lock $ \loop -> do + ints <- addInterval intervals ts + when (length ints < 11) $ loop $ if length ints <= 5 then RIFast else RISlow + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 2, 2, 3, 4, 4] + +addInterval :: TVar [Int] -> TVar UTCTime -> IO [Int] +addInterval intervals ts = do + ts' <- getCurrentTime + atomically $ do + int :: Int <- truncate . (* 100) . nominalDiffTimeToSeconds <$> stateTVar ts (\t -> (diffUTCTime ts' t, ts')) + stateTVar intervals $ \ints -> (int : ints, int : ints) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 5dd98fe14..561da754e 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -49,6 +49,7 @@ serverTests t@(ATransport t') = do describe "switch subscription to another TCP connection" $ testSwitchSub t describe "GET command" $ testGetCommand t' describe "GET & SUB commands" $ testGetSubCommands t' + describe "Exceeding queue quota" $ testExceedQueueQuota t' describe "Store log" $ testWithStoreLog t describe "Restore messages" $ testRestoreMessages t describe "Restore messages (v2)" $ testRestoreMessagesV2 t @@ -104,9 +105,11 @@ decryptMsgV2 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either C.Cry decryptMsgV2 dhShared = C.cbDecrypt dhShared . C.cbNonce decryptMsgV3 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either String MsgBody -decryptMsgV3 dhShared nonce body = do - ClientRcvMsgBody {msgBody} <- parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body) - pure msgBody +decryptMsgV3 dhShared nonce body = + case parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body) of + Right ClientRcvMsgBody {msgBody} -> Right msgBody + Right ClientRcvMsgQuota {} -> Left "ClientRcvMsgQuota" + Left e -> Left e testCreateSecureV2 :: forall c. Transport c => TProxy c -> Spec testCreateSecureV2 _ = @@ -494,6 +497,32 @@ testGetSubCommands t = Resp "12" _ OK <- signSendRecv rh2 rKey ("12", rId, GET) pure () +testExceedQueueQuota :: forall c. Transport c => TProxy c -> Spec +testExceedQueueQuota t = + it "should reply with ERR QUOTA to sender and send QUOTA message to the recipient" $ do + withSmpServerConfigOn (ATransport t) cfg {msgQueueQuota = 2} testPort $ \_ -> + testSMPClient @c $ \sh -> testSMPClient @c $ \rh -> do + (sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519 + (sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub + let dec = decryptMsgV3 dhShared + Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello 1") + Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, _SEND "hello 2") + Resp "3" _ (ERR QUOTA) <- signSendRecv sh sKey ("3", sId, _SEND "hello 3") + Resp "" _ (Msg mId1 msg1) <- tGet1 rh + (dec mId1 msg1, Right "hello 1") #== "hello 1" + Resp "4" _ (Msg mId2 msg2) <- signSendRecv rh rKey ("4", rId, ACK mId1) + (dec mId2 msg2, Right "hello 2") #== "hello 2" + Resp "5" _ (ERR QUOTA) <- signSendRecv sh sKey ("5", sId, _SEND "hello 3") + Resp "6" _ (Msg mId3 msg3) <- signSendRecv rh rKey ("6", rId, ACK mId2) + (dec mId3 msg3, Left "ClientRcvMsgQuota") #== "ClientRcvMsgQuota" + Resp "7" _ (ERR QUOTA) <- signSendRecv sh sKey ("7", sId, _SEND "hello 3") + Resp "8" _ OK <- signSendRecv rh rKey ("8", rId, ACK mId3) + Resp "9" _ OK <- signSendRecv sh sKey ("9", sId, _SEND "hello 3") + Resp "" _ (Msg mId4 msg4) <- tGet1 rh + (dec mId4 msg4, Right "hello 3") #== "hello 3" + Resp "10" _ OK <- signSendRecv rh rKey ("10", rId, ACK mId4) + pure () + testWithStoreLog :: ATransport -> Spec testWithStoreLog at@(ATransport t) = it "should store simplex queues to log and restore them after server restart" $ do diff --git a/tests/Test.hs b/tests/Test.hs index fbf8a4af6..d074d9c86 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -6,6 +6,7 @@ import CLITests import CoreTests.CryptoTests import CoreTests.EncodingTests import CoreTests.ProtocolErrorTests +import CoreTests.RetryIntervalTests import CoreTests.VersionRangeTests import NtfServerTests (ntfServerTests) import ServerTests @@ -31,6 +32,7 @@ main = do describe "Protocol error tests" protocolErrorTests describe "Version range" versionRangeTests describe "Encryption tests" cryptoTests + describe "Retry interval tests" retryIntervalTests describe "SMP server via TLS" $ serverTests (transport @TLS) describe "SMP server via WebSockets" $ serverTests (transport @WS) describe "Notifications server" $ ntfServerTests (transport @TLS)