send/process "quota exceeded" message from SMP server when sender gets ERR QUOTA (#585)

* send "quota exceeded" message from SMP server when sender gets ERR QUOTA (ignored in the agent for now)

* send msg quota to the recipient to indicate that sender got ERR QUOTA, test

* switch between slow/fast retry intervals (tests do not pass yet)

* send QCONT message, refactor RetryInterval, test

* refactor

* remove comment

* remove space

* unit test for withRetryLock2

* refactor
This commit is contained in:
Evgeny Poberezkin
2023-01-04 14:10:13 +00:00
committed by GitHub
parent 470b512a88
commit 058e3ac55e
15 changed files with 490 additions and 205 deletions
+1
View File
@@ -357,6 +357,7 @@ test-suite smp-server-test
CoreTests.CryptoTests
CoreTests.EncodingTests
CoreTests.ProtocolErrorTests
CoreTests.RetryIntervalTests
CoreTests.VersionRangeTests
NtfClient
NtfServerTests
+113 -93
View File
@@ -791,7 +791,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
atomically $ beginAgentOperation c AOSndNetwork
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e)
Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd
Right (corrId, connId, cmd) -> processCmd (riFast ri) corrId connId cmdId cmd
where
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m ()
processCmd ri corrId connId cmdId command = case command of
@@ -964,22 +964,22 @@ queuePendingMsgs c sq msgIds = atomically $ do
modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds}
-- s <- readTVar (msgDeliveryOp c)
-- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s)
q <- getPendingMsgQ c sq
mapM_ (writeTQueue q) msgIds
(mq, _) <- getPendingMsgQ c sq
mapM_ (writeTQueue mq) msgIds
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId)
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId, TMVar ())
getPendingMsgQ c SndQueue {server, sndId} = do
let qKey = (server, sndId)
maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c)
where
newMsgQueue qKey = do
mq <- newTQueue
TM.insert qKey mq $ smpQueueMsgQueues c
pure mq
q <- (,) <$> newTQueue <*> newEmptyTMVar
TM.insert qKey q $ smpQueueMsgQueues c
pure q
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
mq <- atomically $ getPendingMsgQ c sq
(mq, qLock) <- atomically $ getPendingMsgQ c sq
ri <- asks $ messageRetryInterval . config
forever $ do
atomically $ endAgentOperation c AOSndNetwork
@@ -993,7 +993,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
Left (e :: E.SomeException) ->
notify $ MERR mId (INTERNAL $ show e)
Right (rq_, PendingMsgData {msgType, msgBody, msgFlags, internalTs}) ->
withRetryInterval ri $ \loop -> do
withRetryLock2 ri qLock $ \loop -> do
resp <- tryError $ case msgType of
AM_CONN_INFO -> sendConfirmation c sq msgBody
_ -> sendAgentMessage c sq msgFlags msgBody
@@ -1004,7 +1004,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
SMP SMP.QUOTA -> case msgType of
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
_ -> retrySndOp c loop
_ -> retrySndOp c $ loop RISlow
SMP SMP.AUTH -> case msgType of
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
@@ -1013,7 +1013,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- because the queue must be secured by the time the confirmation or the first HELLO is received
| duplexHandshake == Just True -> connErr
| otherwise ->
ifM (msgExpired helloTimeout) connErr (retrySndOp c loop)
ifM (msgExpired helloTimeout) connErr (retrySndOp c $ loop RIFast)
where
connErr = case rq_ of
-- party initiating connection
@@ -1022,6 +1022,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
_ -> connError msgId NOT_ACCEPTED
AM_REPLY_ -> notifyDel msgId err
AM_A_MSG_ -> notifyDel msgId err
AM_QCONT_ -> notifyDel msgId err
AM_QADD_ -> qError msgId "QADD: AUTH"
AM_QKEY_ -> qError msgId "QKEY: AUTH"
AM_QUSE_ -> qError msgId "QUSE: AUTH"
@@ -1031,7 +1032,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- the message sending would be retried
| temporaryOrHostError e -> do
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c loop)
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c $ loop RIFast)
| otherwise -> notifyDel msgId err
where
msgExpired timeoutSel = do
@@ -1071,6 +1072,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
qInfo <- createReplyQueue c cData sq srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
AM_QCONT_ -> pure ()
AM_QADD_ -> pure ()
AM_QKEY_ -> pure ()
AM_QUSE_ -> pure ()
@@ -1492,88 +1494,96 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
processSMP :: RcvQueue -> Connection c -> ConnData -> m ()
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $
case cmd of
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do
SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} <- decryptSMPMessage v rq msg
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
parseMessage msgBody
clientVRange <- asks $ smpClientVRange . config
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
case (e2eDhSecret, e2ePubKey_) of
(Nothing, Just e2ePubKey) -> do
let e2eDh = C.dh' e2ePubKey e2ePrivKey
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) ->
smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack
(SMP.PHEmpty, AgentInvitation {connReq, connInfo}) ->
smpInvitation connReq connInfo >> ack
_ -> prohibited >> ack
(Just e2eDh, Nothing) -> do
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
let RcvQueue {primary, dbReplaceQueueId} = rq
unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active
case (conn, dbReplaceQueueId) of
(DuplexConnection _ rqs _, Just replacedId) -> do
when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq
case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of
Just RcvQueue {server, rcvId} -> do
enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId
_ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection"
_ -> pure ()
tryError agentClientMsg >>= \case
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
HELLO -> helloMsg >> ackDel msgId
REPLY cReq -> replyMsg cReq >> ackDel msgId
-- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command
A_MSG body -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
QADD qs -> qDuplex "QADD" $ qAddMsg qs
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
-- no action needed for QTEST
-- any message in the new queue will mark it active and trigger deletion of the old queue
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
where
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
qDuplex name a = case conn of
DuplexConnection {} -> a conn >> ackDel msgId
_ -> qError $ name <> ": message must be sent to duplex connection"
Right _ -> prohibited >> ack
Left e@(AGENT A_DUPLICATE) -> do
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
| userAck -> ackDel internalId
| otherwise -> do
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
AgentMessage _ (A_MSG body) -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
_ -> pure ()
_ -> throwError e
Left e -> throwError e
where
agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage))
agentClientMsg = withStore c $ \db -> runExceptT $ do
agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
let msgType = agentMessageType agentMsg
internalHash = C.sha256Hash agentMsgBody
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId
let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash
recipient = (unId internalId, internalTs)
broker = (srvMsgId, systemToUTCTime srvTs)
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
liftIO $ createRcvMsg db connId rq rcvMsg
pure $ Just (internalId, msgMeta, aMessage)
_ -> pure Nothing
_ -> prohibited >> ack
_ -> prohibited >> ack
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
handleNotifyAck $
decryptSMPMessage v rq msg >>= \case
SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody
SMP.ClientRcvMsgQuota {} -> queueDrained >> ack
where
queueDrained = case conn of
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq)
_ -> pure ()
processClientMsg srvTs msgFlags msgBody = do
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
parseMessage msgBody
clientVRange <- asks $ smpClientVRange . config
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
case (e2eDhSecret, e2ePubKey_) of
(Nothing, Just e2ePubKey) -> do
let e2eDh = C.dh' e2ePubKey e2ePrivKey
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) ->
smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack
(SMP.PHEmpty, AgentInvitation {connReq, connInfo}) ->
smpInvitation connReq connInfo >> ack
_ -> prohibited >> ack
(Just e2eDh, Nothing) -> do
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
let RcvQueue {primary, dbReplaceQueueId} = rq
unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active
case (conn, dbReplaceQueueId) of
(DuplexConnection _ rqs _, Just replacedId) -> do
when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq
case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of
Just RcvQueue {server, rcvId} -> do
enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId
_ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection"
_ -> pure ()
tryError agentClientMsg >>= \case
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
HELLO -> helloMsg >> ackDel msgId
REPLY cReq -> replyMsg cReq >> ackDel msgId
-- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command
A_MSG body -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
QCONT addr -> qDuplex "QCONT" $ continueSending addr
QADD qs -> qDuplex "QADD" $ qAddMsg qs
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
-- no action needed for QTEST
-- any message in the new queue will mark it active and trigger deletion of the old queue
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
where
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
qDuplex name a = case conn of
DuplexConnection {} -> a conn >> ackDel msgId
_ -> qError $ name <> ": message must be sent to duplex connection"
Right _ -> prohibited >> ack
Left e@(AGENT A_DUPLICATE) -> do
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
| userAck -> ackDel internalId
| otherwise -> do
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
AgentMessage _ (A_MSG body) -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
_ -> pure ()
_ -> throwError e
Left e -> throwError e
where
agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage))
agentClientMsg = withStore c $ \db -> runExceptT $ do
agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
let msgType = agentMessageType agentMsg
internalHash = C.sha256Hash agentMsgBody
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId
let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash
recipient = (unId internalId, internalTs)
broker = (srvMsgId, systemToUTCTime srvTs)
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
liftIO $ createRcvMsg db connId rq rcvMsg
pure $ Just (internalId, msgMeta, aMessage)
_ -> pure Nothing
_ -> prohibited >> ack
_ -> prohibited >> ack
ack :: m ()
ack = enqueueCmd $ ICAck rId srvMsgId
ackDel :: InternalId -> m ()
@@ -1698,6 +1708,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR)
_ -> prohibited
continueSending :: (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
continueSending addr (DuplexConnection _ _ sqs) =
case findQ addr sqs of
Just sq -> do
logServer "<--" c srv rId "MSG <QCONT>"
atomically $ do
(_, qLock) <- getPendingMsgQ c sq
void $ tryPutTMVar qLock ()
Nothing -> qError "QCONT: queue address not found"
-- processed by queue sender
qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m ()
qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported"
+1 -1
View File
@@ -180,7 +180,7 @@ data AgentClient = AgentClient
activeSubs :: TRcvQueues,
pendingSubs :: TRcvQueues,
pendingMsgsQueued :: TMap SndQAddr Bool,
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId),
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId, TMVar ()),
smpQueueMsgDeliveries :: TMap SndQAddr (Async ()),
connCmdsQueued :: TMap ConnId Bool,
asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId),
+18 -6
View File
@@ -83,7 +83,7 @@ data AgentConfig = AgentConfig
smpCfg :: ProtocolClientConfig,
ntfCfg :: ProtocolClientConfig,
reconnectInterval :: RetryInterval,
messageRetryInterval :: RetryInterval,
messageRetryInterval :: RetryInterval2,
messageTimeout :: NominalDiffTime,
helloTimeout :: NominalDiffTime,
ntfCron :: Word16,
@@ -108,12 +108,24 @@ defaultReconnectInterval =
maxInterval = 180_000000
}
defaultMessageRetryInterval :: RetryInterval
defaultMessageRetryInterval :: RetryInterval2
defaultMessageRetryInterval =
RetryInterval
{ initialInterval = 1_000000,
increaseAfter = 10_000000,
maxInterval = 60_000000
RetryInterval2
{ riFast =
RetryInterval
{ initialInterval = 1_000000,
increaseAfter = 10_000000,
maxInterval = 60_000000
},
riSlow =
-- TODO: these timeouts can be increased once most clients are updates
-- to resume sending on QCONT messages.
-- After that local message expiration period should be also increased.
RetryInterval
{ initialInterval = 10_000000,
increaseAfter = 30_000000,
maxInterval = 300_000000
}
}
defaultAgentConfig :: AgentConfig
+11
View File
@@ -577,6 +577,7 @@ data AgentMessageType
| AM_HELLO_
| AM_REPLY_
| AM_A_MSG_
| AM_QCONT_
| AM_QADD_
| AM_QKEY_
| AM_QUSE_
@@ -590,6 +591,7 @@ instance Encoding AgentMessageType where
AM_HELLO_ -> "H"
AM_REPLY_ -> "R"
AM_A_MSG_ -> "M"
AM_QCONT_ -> "QC"
AM_QADD_ -> "QA"
AM_QKEY_ -> "QK"
AM_QUSE_ -> "QU"
@@ -603,6 +605,7 @@ instance Encoding AgentMessageType where
'M' -> pure AM_A_MSG_
'Q' ->
A.anyChar >>= \case
'C' -> pure AM_QCONT_
'A' -> pure AM_QADD_
'K' -> pure AM_QKEY_
'U' -> pure AM_QUSE_
@@ -623,6 +626,7 @@ agentMessageType = \case
-- REPLY is only used in v1
REPLY _ -> AM_REPLY_
A_MSG _ -> AM_A_MSG_
QCONT _ -> AM_QCONT_
QADD _ -> AM_QADD_
QKEY _ -> AM_QKEY_
QUSE _ -> AM_QUSE_
@@ -645,6 +649,7 @@ data AMsgType
= HELLO_
| REPLY_
| A_MSG_
| QCONT_
| QADD_
| QKEY_
| QUSE_
@@ -656,6 +661,7 @@ instance Encoding AMsgType where
HELLO_ -> "H"
REPLY_ -> "R"
A_MSG_ -> "M"
QCONT_ -> "QC"
QADD_ -> "QA"
QKEY_ -> "QK"
QUSE_ -> "QU"
@@ -667,6 +673,7 @@ instance Encoding AMsgType where
'M' -> pure A_MSG_
'Q' ->
A.anyChar >>= \case
'C' -> pure QCONT_
'A' -> pure QADD_
'K' -> pure QKEY_
'U' -> pure QUSE_
@@ -684,6 +691,8 @@ data AMessage
REPLY (L.NonEmpty SMPQueueInfo)
| -- | agent envelope for the client message
A_MSG MsgBody
| -- | the message instructing the client to continue sending messages (after ERR QUOTA)
QCONT SndQAddr
| -- add queue to connection (sent by recipient), with optional address of the replaced queue
QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr))
| -- key to secure the added queues and agree e2e encryption key (sent by sender)
@@ -701,6 +710,7 @@ instance Encoding AMessage where
HELLO -> smpEncode HELLO_
REPLY smpQueues -> smpEncode (REPLY_, smpQueues)
A_MSG body -> smpEncode (A_MSG_, Tail body)
QCONT addr -> smpEncode (QCONT_, addr)
QADD qs -> smpEncode (QADD_, qs)
QKEY qs -> smpEncode (QKEY_, qs)
QUSE qs -> smpEncode (QUSE_, qs)
@@ -711,6 +721,7 @@ instance Encoding AMessage where
HELLO_ -> pure HELLO
REPLY_ -> REPLY <$> smpP
A_MSG_ -> A_MSG . unTail <$> smpP
QCONT_ -> QCONT <$> smpP
QADD_ -> QADD <$> smpP
QKEY_ -> QKEY <$> smpP
QUSE_ -> QUSE <$> smpP
+55 -10
View File
@@ -1,10 +1,21 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Agent.RetryInterval where
module Simplex.Messaging.Agent.RetryInterval
( RetryInterval (..),
RetryInterval2 (..),
RetryIntervalMode (..),
withRetryInterval,
withRetryLock2,
)
where
import Control.Concurrent (threadDelay)
import Control.Concurrent (forkIO, threadDelay)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Simplex.Messaging.Util (whenM)
import UnliftIO.STM
data RetryInterval = RetryInterval
{ initialInterval :: Int,
@@ -12,17 +23,51 @@ data RetryInterval = RetryInterval
maxInterval :: Int
}
data RetryInterval2 = RetryInterval2
{ riSlow :: RetryInterval,
riFast :: RetryInterval
}
data RetryIntervalMode = RISlow | RIFast
deriving (Eq)
withRetryInterval :: forall m. MonadIO m => RetryInterval -> (m () -> m ()) -> m ()
withRetryInterval RetryInterval {initialInterval, increaseAfter, maxInterval} action =
callAction 0 initialInterval
withRetryInterval ri action = callAction 0 $ initialInterval ri
where
callAction :: Int -> Int -> m ()
callAction elapsedTime delay = action loop
callAction elapsed delay = action loop
where
loop = do
let newDelay =
if elapsedTime < increaseAfter || delay == maxInterval
then delay
else min (delay * 3 `div` 2) maxInterval
liftIO $ threadDelay delay
callAction (elapsedTime + delay) newDelay
let elapsed' = elapsed + delay
callAction elapsed' $ nextDelay elapsed' delay ri
-- This function allows action to toggle between slow and fast retry intervals.
withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> ((RetryIntervalMode -> m ()) -> m ()) -> m ()
withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
callAction (0, initialInterval riSlow) (0, initialInterval riFast)
where
callAction :: (Int, Int) -> (Int, Int) -> m ()
callAction slow fast = action loop
where
loop = \case
RISlow -> run slow riSlow (`callAction` fast)
RIFast -> run fast riFast (callAction slow)
run (elapsed, delay) ri call = do
wait delay
let elapsed' = elapsed + delay
call (elapsed', nextDelay elapsed' delay ri)
wait delay = do
waiting <- newTVarIO True
_ <- liftIO . forkIO $ do
threadDelay delay
atomically $ whenM (readTVar waiting) $ void $ tryPutTMVar lock ()
atomically $ do
takeTMVar lock
writeTVar waiting False
nextDelay :: Int -> Int -> RetryInterval -> Int
nextDelay elapsed delay RetryInterval {increaseAfter, maxInterval} =
if elapsed < increaseAfter || delay == maxInterval
then delay
else min (delay * 3 `div` 2) maxInterval
+71 -35
View File
@@ -141,6 +141,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isPrint, isSpace)
import Data.Functor (($>))
import Data.Kind
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
@@ -303,12 +304,17 @@ data RcvMessage = RcvMessage
deriving (Eq, Show)
-- | received message without server/recipient encryption
data Message = Message
{ msgId :: MsgId,
msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
data Message
= Message
{ msgId :: MsgId,
msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
| MessageQuota
{ msgId :: MsgId,
msgTs :: SystemTime
}
instance StrEncoding RcvMessage where
strEncode RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} =
@@ -328,44 +334,72 @@ instance StrEncoding RcvMessage where
newtype EncRcvMsgBody = EncRcvMsgBody ByteString
deriving (Eq, Show)
data RcvMsgBody = RcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
data RcvMsgBody
= RcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
| RcvMsgQuota
{ msgTs :: SystemTime
}
msgQuotaTag :: ByteString
msgQuotaTag = "QUOTA"
encodeRcvMsgBody :: RcvMsgBody -> C.MaxLenBS MaxRcvMessageLen
encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} =
let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ')
in C.appendMaxLenBS rcvMeta msgBody
encodeRcvMsgBody = \case
RcvMsgBody {msgTs, msgFlags, msgBody} ->
let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ')
in C.appendMaxLenBS rcvMeta msgBody
RcvMsgQuota {msgTs} ->
C.unsafeMaxLenBS $ msgQuotaTag <> " " <> smpEncode msgTs
data ClientRcvMsgBody = ClientRcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: ByteString
}
data ClientRcvMsgBody
= ClientRcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: ByteString
}
| ClientRcvMsgQuota
{ msgTs :: SystemTime
}
clientRcvMsgBodyP :: Parser ClientRcvMsgBody
clientRcvMsgBodyP = do
msgTs <- smpP
msgFlags <- smpP
Tail msgBody <- _smpP
pure ClientRcvMsgBody {msgTs, msgFlags, msgBody}
clientRcvMsgBodyP = msgQuotaP <|> msgBodyP
where
msgQuotaP = A.string msgQuotaTag *> (ClientRcvMsgQuota <$> _smpP)
msgBodyP = do
msgTs <- smpP
msgFlags <- smpP
Tail msgBody <- _smpP
pure ClientRcvMsgBody {msgTs, msgFlags, msgBody}
instance StrEncoding Message where
strEncode Message {msgId, msgTs, msgFlags, msgBody} =
B.unwords
[ strEncode msgId,
strEncode msgTs,
"flags=" <> strEncode msgFlags,
strEncode msgBody
]
strEncode = \case
Message {msgId, msgTs, msgFlags, msgBody} ->
B.unwords
[ strEncode msgId,
strEncode msgTs,
"flags=" <> strEncode msgFlags,
strEncode msgBody
]
MessageQuota {msgId, msgTs} ->
B.unwords
[ strEncode msgId,
strEncode msgTs,
"quota"
]
strP = do
msgId <- strP_
msgTs <- strP_
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
msgBody <- strP
pure Message {msgId, msgTs, msgFlags, msgBody}
msgQuotaP msgId msgTs <|> msgP msgId msgTs
where
msgQuotaP msgId msgTs = "quota" $> MessageQuota {msgId, msgTs}
msgP msgId msgTs = do
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
msgBody <- strP
pure Message {msgId, msgTs, msgFlags, msgBody}
type EncNMsgMeta = ByteString
@@ -377,7 +411,9 @@ data SMPMsgMeta = SMPMsgMeta
deriving (Show)
rcvMessageMeta :: MsgId -> ClientRcvMsgBody -> SMPMsgMeta
rcvMessageMeta msgId ClientRcvMsgBody {msgTs, msgFlags} = SMPMsgMeta {msgId, msgTs, msgFlags}
rcvMessageMeta msgId = \case
ClientRcvMsgBody {msgTs, msgFlags} -> SMPMsgMeta {msgId, msgTs, msgFlags}
ClientRcvMsgQuota {msgTs} -> SMPMsgMeta {msgId, msgTs, msgFlags = noMsgFlags}
data NMsgMeta = NMsgMeta
{ msgId :: MsgId,
+34 -25
View File
@@ -538,20 +538,22 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
| otherwise = case status qr of
QueueOff -> return $ err AUTH
QueueActive ->
mapM mkMessage (C.maxLenBS msgBody) >>= \case
case C.maxLenBS msgBody of
Left _ -> pure $ err LARGE_MSG
Right msg -> do
resp@(_, _, sent) <- time "SEND" $ do
Right body -> do
msg_ <- time "SEND" $ do
q <- getStoreMsgQueue "SEND" $ recipientId qr
expireMessages q
atomically $ ifM (isFull q) (pure $ err QUOTA) (writeMsg q msg $> ok)
when (sent == OK) . time "SEND ok" $ do
when (notification msgFlags) $
atomically . trySendNotification msg =<< asks idsDrg
stats <- asks serverStats
atomically $ modifyTVar (msgSent stats) (+ 1)
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
pure resp
atomically . writeMsg q =<< mkMessage body
case msg_ of
Nothing -> pure $ err QUOTA
Just msg -> time "SEND ok" $ do
when (notification msgFlags) $
atomically . trySendNotification msg =<< asks idsDrg
stats <- asks serverStats
atomically $ modifyTVar (msgSent stats) (+ 1)
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
pure ok
where
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
mkMessage body = do
@@ -572,12 +574,14 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
unlessM (isFullTBQueue q) $ do
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
unlessM (isFullTBQueue q) $ case msg of
Message {msgId, msgTs} -> do
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
_ -> pure ()
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg = do
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
let msgMeta = NMsgMeta {msgId, msgTs}
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
@@ -615,17 +619,22 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
time name = timed name queueId
encryptMsg :: QueueRec -> Message -> RcvMessage
encryptMsg qr Message {msgId, msgTs, msgFlags, msgBody}
| thVersion == 1 || thVersion == 2 = encrypt msgBody
| otherwise = encrypt $ encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody}
encryptMsg qr msg = case msg of
Message {msgFlags, msgBody}
| thVersion == 1 || thVersion == 2 -> encrypt msgFlags msgBody
| otherwise -> encrypt msgFlags $ encodeRcvMsgBody RcvMsgBody {msgTs = msgTs', msgFlags, msgBody}
MessageQuota {} ->
encrypt noMsgFlags $ encodeRcvMsgBody (RcvMsgQuota msgTs')
where
encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage
encrypt body =
let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId) body
in RcvMessage msgId msgTs msgFlags encBody
encrypt :: KnownNat i => MsgFlags -> C.MaxLenBS i -> RcvMessage
encrypt msgFlags body =
let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body
in RcvMessage msgId' msgTs' msgFlags encBody
msgId' = msgId (msg :: Message)
msgTs' = msgTs (msg :: Message)
setDelivered :: Sub -> Message -> STM Bool
setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId
setDelivered s msg = tryPutTMVar (delivered s) $ msgId (msg :: Message)
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
@@ -717,7 +726,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages
addToMsgQueue rId msg = do
full <- atomically $ do
q <- getMsgQueue ms rId quota
ifM (isFull q) (pure True) (writeMsg q msg $> False)
isNothing <$> writeMsg q msg
when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message))
updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do
let nonce = C.cbNonce msgId
+1 -1
View File
@@ -39,7 +39,7 @@ data ServerConfig = ServerConfig
{ transports :: [(ServiceName, ATransport)],
tbqSize :: Natural,
serverTbqSize :: Natural,
msgQueueQuota :: Natural,
msgQueueQuota :: Int,
queueIdBytes :: Int,
msgIdBytes :: Int,
storeLogFile :: Maybe FilePath,
+2 -3
View File
@@ -19,13 +19,12 @@ instance StrEncoding MsgLogRecord where
strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP
class MonadMsgStore s q m | s -> q where
getMsgQueue :: s -> RecipientId -> Natural -> m q
getMsgQueue :: s -> RecipientId -> Int -> m q
delMsgQueue :: s -> RecipientId -> m ()
flushMsgQueue :: s -> RecipientId -> m [Message]
class MonadMsgQueue q m where
isFull :: q -> m Bool
writeMsg :: q -> Message -> m () -- non blocking
writeMsg :: q -> Message -> m (Maybe Message) -- non blocking
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
peekMsg :: q -> m Message -- blocking
tryDelMsg :: q -> MsgId -> m Bool -- non blocking
+60 -28
View File
@@ -7,22 +7,31 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TupleSections #-}
module Simplex.Messaging.Server.MsgStore.STM where
module Simplex.Messaging.Server.MsgStore.STM
( STMMsgStore,
MsgQueue,
newMsgStore,
)
where
import Control.Concurrent.STM.TBQueue (flushTBQueue)
import Control.Concurrent.STM.TQueue (flushTQueue)
import Control.Monad (when)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
import Data.Int (Int64)
import Data.Time.Clock.System (SystemTime (systemSeconds))
import Numeric.Natural
import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId)
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import UnliftIO.STM
newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message}
data MsgQueue = MsgQueue
{ msgQueue :: TQueue Message,
quota :: Int,
canWrite :: TVar Bool,
size :: TVar Int
}
type STMMsgStore = TMap RecipientId MsgQueue
@@ -30,54 +39,77 @@ newMsgStore :: STM STMMsgStore
newMsgStore = TM.empty
instance MonadMsgStore STMMsgStore MsgQueue STM where
getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue
getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
where
newQ = do
q <- MsgQueue <$> newTBQueue quota
msgQueue <- newTQueue
canWrite <- newTVar True
size <- newTVar 0
let q = MsgQueue {msgQueue, quota, canWrite, size}
TM.insert rId q st
return q
pure q
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
delMsgQueue st rId = TM.delete rId st
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
flushMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (flushTBQueue . msgQueue)
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
instance MonadMsgQueue MsgQueue STM where
isFull :: MsgQueue -> STM Bool
isFull = isFullTBQueue . msgQueue
writeMsg :: MsgQueue -> Message -> STM ()
writeMsg = writeTBQueue . msgQueue
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
canWrt <- readTVar canWrite
empty <- isEmptyTQueue q
if canWrt || empty
then do
canWrt' <- (quota >) <$> readTVar size
writeTVar canWrite canWrt'
modifyTVar' size (+ 1)
if canWrt'
then writeTQueue q msg $> Just msg
else writeTQueue q msgQuota $> Nothing
else pure Nothing
where
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
tryPeekMsg = tryPeekTBQueue . msgQueue
tryPeekMsg = tryPeekTQueue . msgQueue
{-# INLINE tryPeekMsg #-}
peekMsg :: MsgQueue -> STM Message
peekMsg = peekTBQueue . msgQueue
peekMsg = peekTQueue . msgQueue
{-# INLINE peekMsg #-}
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
tryDelMsg (MsgQueue q) msgId' =
tryPeekTBQueue q >>= \case
Just Message {msgId}
| msgId == msgId' || B.null msgId' -> tryReadTBQueue q $> True
tryDelMsg mq msgId' =
tryPeekMsg mq >>= \case
Just msg
| msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True
| otherwise -> pure False
_ -> pure False
-- atomic delete (== read) last and peek next message if available
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
tryDelPeekMsg (MsgQueue q) msgId' =
tryPeekTBQueue q >>= \case
msg_@(Just Message {msgId})
| msgId == msgId' || B.null msgId' -> (True,) <$> (tryReadTBQueue q >> tryPeekTBQueue q)
tryDelPeekMsg mq msgId' =
tryPeekMsg mq >>= \case
msg_@(Just msg)
| msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq)
| otherwise -> pure (False, msg_)
_ -> pure (False, Nothing)
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
deleteExpiredMsgs (MsgQueue q) old = loop
deleteExpiredMsgs mq old = loop
where
loop = tryPeekTBQueue q >>= mapM_ delOldMsg
delOldMsg Message {msgTs} =
when (systemSeconds msgTs < old) $
tryReadTBQueue q >> loop
loop = tryPeekMsg mq >>= mapM_ delOldMsg
delOldMsg = \case
Message {msgTs} ->
when (systemSeconds msgTs < old) $
tryDeleteMsg mq >> loop
_ -> pure ()
tryDeleteMsg :: MsgQueue -> STM ()
tryDeleteMsg MsgQueue {msgQueue = q, size} =
tryReadTQueue q >>= \case
Just _ -> modifyTVar' size (subtract 1)
_ -> pure ()
+28
View File
@@ -78,6 +78,8 @@ agentTests (ATransport t) = do
smpAgentTest2_2_1 $ testConcurrentMsgDelivery t
it "should deliver messages if one of connections has quota exceeded" $
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
it "should resume delivering messages after exceeding quota once all messages are received" $
smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t
tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent)
tGetAgent h = do
@@ -430,6 +432,32 @@ testMsgDeliveryQuotaExceeded _ alice bob = do
-- if delivery is blocked it won't go further
alice <# ("", "bob2", SENT 4)
testResumeDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO ()
testResumeDeliveryQuotaExceeded _ alice bob = do
connect (alice, "alice") (bob, "bob")
forM_ [1 .. 4 :: Int] $ \i -> do
let corrId = bshow i
msg = "message " <> bshow i
(_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg)
alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False
("5", "bob", Right (MID 8)) <- alice #: ("5", "bob", "SEND F :over quota")
alice #:# "the last message not sent yet"
bob <#= \case ("", "alice", Msg "message 1") -> True; _ -> False
bob #: ("1", "alice", "ACK 4") #> ("1", "alice", OK)
alice #:# "the last message not sent"
bob <#= \case ("", "alice", Msg "message 2") -> True; _ -> False
bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK)
alice #:# "the last message not sent"
bob <#= \case ("", "alice", Msg "message 3") -> True; _ -> False
bob #: ("3", "alice", "ACK 6") #> ("3", "alice", OK)
alice #:# "the last message not sent"
bob <#= \case ("", "alice", Msg "message 4") -> True; _ -> False
bob #: ("4", "alice", "ACK 7") #> ("4", "alice", OK)
alice <# ("", "bob", SENT 8)
bob <#= \case ("", "alice", Msg "over quota") -> True; _ -> False
-- message 8 is skipped because of alice agent sending "QCONT" message
bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK)
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = do
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV")
+61
View File
@@ -0,0 +1,61 @@
{-# LANGUAGE ScopedTypeVariables #-}
module CoreTests.RetryIntervalTests where
import Control.Concurrent.STM
import Control.Monad (when)
import Data.Time.Clock (UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds)
import Simplex.Messaging.Agent.RetryInterval
import Test.Hspec
retryIntervalTests :: Spec
retryIntervalTests = do
describe "Retry interval with 2 modes and lock" $ do
testRetryIntervalSameMode
testRetryIntervalSwitchMode
testRI :: RetryInterval2
testRI =
RetryInterval2
{ riSlow =
RetryInterval
{ initialInterval = 20000,
increaseAfter = 40000,
maxInterval = 40000
},
riFast =
RetryInterval
{ initialInterval = 10000,
increaseAfter = 20000,
maxInterval = 40000
}
}
testRetryIntervalSameMode :: Spec
testRetryIntervalSameMode =
it "should increase elapased time and interval when the mode stays the same" $ do
lock <- newEmptyTMVarIO
intervals <- newTVarIO []
ts <- newTVarIO =<< getCurrentTime
withRetryLock2 testRI lock $ \loop -> do
ints <- addInterval intervals ts
when (length ints < 9) $ loop RIFast
(reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4, 4]
testRetryIntervalSwitchMode :: Spec
testRetryIntervalSwitchMode =
it "should increase elapased time and interval when the mode stays the same" $ do
lock <- newEmptyTMVarIO
intervals <- newTVarIO []
ts <- newTVarIO =<< getCurrentTime
withRetryLock2 testRI lock $ \loop -> do
ints <- addInterval intervals ts
when (length ints < 11) $ loop $ if length ints <= 5 then RIFast else RISlow
(reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 2, 2, 3, 4, 4]
addInterval :: TVar [Int] -> TVar UTCTime -> IO [Int]
addInterval intervals ts = do
ts' <- getCurrentTime
atomically $ do
int :: Int <- truncate . (* 100) . nominalDiffTimeToSeconds <$> stateTVar ts (\t -> (diffUTCTime ts' t, ts'))
stateTVar intervals $ \ints -> (int : ints, int : ints)
+32 -3
View File
@@ -49,6 +49,7 @@ serverTests t@(ATransport t') = do
describe "switch subscription to another TCP connection" $ testSwitchSub t
describe "GET command" $ testGetCommand t'
describe "GET & SUB commands" $ testGetSubCommands t'
describe "Exceeding queue quota" $ testExceedQueueQuota t'
describe "Store log" $ testWithStoreLog t
describe "Restore messages" $ testRestoreMessages t
describe "Restore messages (v2)" $ testRestoreMessagesV2 t
@@ -104,9 +105,11 @@ decryptMsgV2 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either C.Cry
decryptMsgV2 dhShared = C.cbDecrypt dhShared . C.cbNonce
decryptMsgV3 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either String MsgBody
decryptMsgV3 dhShared nonce body = do
ClientRcvMsgBody {msgBody} <- parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body)
pure msgBody
decryptMsgV3 dhShared nonce body =
case parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body) of
Right ClientRcvMsgBody {msgBody} -> Right msgBody
Right ClientRcvMsgQuota {} -> Left "ClientRcvMsgQuota"
Left e -> Left e
testCreateSecureV2 :: forall c. Transport c => TProxy c -> Spec
testCreateSecureV2 _ =
@@ -494,6 +497,32 @@ testGetSubCommands t =
Resp "12" _ OK <- signSendRecv rh2 rKey ("12", rId, GET)
pure ()
testExceedQueueQuota :: forall c. Transport c => TProxy c -> Spec
testExceedQueueQuota t =
it "should reply with ERR QUOTA to sender and send QUOTA message to the recipient" $ do
withSmpServerConfigOn (ATransport t) cfg {msgQueueQuota = 2} testPort $ \_ ->
testSMPClient @c $ \sh -> testSMPClient @c $ \rh -> do
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
(sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub
let dec = decryptMsgV3 dhShared
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello 1")
Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, _SEND "hello 2")
Resp "3" _ (ERR QUOTA) <- signSendRecv sh sKey ("3", sId, _SEND "hello 3")
Resp "" _ (Msg mId1 msg1) <- tGet1 rh
(dec mId1 msg1, Right "hello 1") #== "hello 1"
Resp "4" _ (Msg mId2 msg2) <- signSendRecv rh rKey ("4", rId, ACK mId1)
(dec mId2 msg2, Right "hello 2") #== "hello 2"
Resp "5" _ (ERR QUOTA) <- signSendRecv sh sKey ("5", sId, _SEND "hello 3")
Resp "6" _ (Msg mId3 msg3) <- signSendRecv rh rKey ("6", rId, ACK mId2)
(dec mId3 msg3, Left "ClientRcvMsgQuota") #== "ClientRcvMsgQuota"
Resp "7" _ (ERR QUOTA) <- signSendRecv sh sKey ("7", sId, _SEND "hello 3")
Resp "8" _ OK <- signSendRecv rh rKey ("8", rId, ACK mId3)
Resp "9" _ OK <- signSendRecv sh sKey ("9", sId, _SEND "hello 3")
Resp "" _ (Msg mId4 msg4) <- tGet1 rh
(dec mId4 msg4, Right "hello 3") #== "hello 3"
Resp "10" _ OK <- signSendRecv rh rKey ("10", rId, ACK mId4)
pure ()
testWithStoreLog :: ATransport -> Spec
testWithStoreLog at@(ATransport t) =
it "should store simplex queues to log and restore them after server restart" $ do
+2
View File
@@ -6,6 +6,7 @@ import CLITests
import CoreTests.CryptoTests
import CoreTests.EncodingTests
import CoreTests.ProtocolErrorTests
import CoreTests.RetryIntervalTests
import CoreTests.VersionRangeTests
import NtfServerTests (ntfServerTests)
import ServerTests
@@ -31,6 +32,7 @@ main = do
describe "Protocol error tests" protocolErrorTests
describe "Version range" versionRangeTests
describe "Encryption tests" cryptoTests
describe "Retry interval tests" retryIntervalTests
describe "SMP server via TLS" $ serverTests (transport @TLS)
describe "SMP server via WebSockets" $ serverTests (transport @WS)
describe "Notifications server" $ ntfServerTests (transport @TLS)