mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-25 20:44:49 +00:00
send/process "quota exceeded" message from SMP server when sender gets ERR QUOTA (#585)
* send "quota exceeded" message from SMP server when sender gets ERR QUOTA (ignored in the agent for now) * send msg quota to the recipient to indicate that sender got ERR QUOTA, test * switch between slow/fast retry intervals (tests do not pass yet) * send QCONT message, refactor RetryInterval, test * refactor * remove comment * remove space * unit test for withRetryLock2 * refactor
This commit is contained in:
committed by
GitHub
parent
470b512a88
commit
058e3ac55e
@@ -357,6 +357,7 @@ test-suite smp-server-test
|
||||
CoreTests.CryptoTests
|
||||
CoreTests.EncodingTests
|
||||
CoreTests.ProtocolErrorTests
|
||||
CoreTests.RetryIntervalTests
|
||||
CoreTests.VersionRangeTests
|
||||
NtfClient
|
||||
NtfServerTests
|
||||
|
||||
+113
-93
@@ -791,7 +791,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
atomically $ beginAgentOperation c AOSndNetwork
|
||||
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
|
||||
Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e)
|
||||
Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd
|
||||
Right (corrId, connId, cmd) -> processCmd (riFast ri) corrId connId cmdId cmd
|
||||
where
|
||||
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m ()
|
||||
processCmd ri corrId connId cmdId command = case command of
|
||||
@@ -964,22 +964,22 @@ queuePendingMsgs c sq msgIds = atomically $ do
|
||||
modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds}
|
||||
-- s <- readTVar (msgDeliveryOp c)
|
||||
-- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s)
|
||||
q <- getPendingMsgQ c sq
|
||||
mapM_ (writeTQueue q) msgIds
|
||||
(mq, _) <- getPendingMsgQ c sq
|
||||
mapM_ (writeTQueue mq) msgIds
|
||||
|
||||
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId)
|
||||
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId, TMVar ())
|
||||
getPendingMsgQ c SndQueue {server, sndId} = do
|
||||
let qKey = (server, sndId)
|
||||
maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c)
|
||||
where
|
||||
newMsgQueue qKey = do
|
||||
mq <- newTQueue
|
||||
TM.insert qKey mq $ smpQueueMsgQueues c
|
||||
pure mq
|
||||
q <- (,) <$> newTQueue <*> newEmptyTMVar
|
||||
TM.insert qKey q $ smpQueueMsgQueues c
|
||||
pure q
|
||||
|
||||
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
|
||||
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
|
||||
mq <- atomically $ getPendingMsgQ c sq
|
||||
(mq, qLock) <- atomically $ getPendingMsgQ c sq
|
||||
ri <- asks $ messageRetryInterval . config
|
||||
forever $ do
|
||||
atomically $ endAgentOperation c AOSndNetwork
|
||||
@@ -993,7 +993,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
Left (e :: E.SomeException) ->
|
||||
notify $ MERR mId (INTERNAL $ show e)
|
||||
Right (rq_, PendingMsgData {msgType, msgBody, msgFlags, internalTs}) ->
|
||||
withRetryInterval ri $ \loop -> do
|
||||
withRetryLock2 ri qLock $ \loop -> do
|
||||
resp <- tryError $ case msgType of
|
||||
AM_CONN_INFO -> sendConfirmation c sq msgBody
|
||||
_ -> sendAgentMessage c sq msgFlags msgBody
|
||||
@@ -1004,7 +1004,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
SMP SMP.QUOTA -> case msgType of
|
||||
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
|
||||
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
|
||||
_ -> retrySndOp c loop
|
||||
_ -> retrySndOp c $ loop RISlow
|
||||
SMP SMP.AUTH -> case msgType of
|
||||
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
|
||||
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
|
||||
@@ -1013,7 +1013,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
-- because the queue must be secured by the time the confirmation or the first HELLO is received
|
||||
| duplexHandshake == Just True -> connErr
|
||||
| otherwise ->
|
||||
ifM (msgExpired helloTimeout) connErr (retrySndOp c loop)
|
||||
ifM (msgExpired helloTimeout) connErr (retrySndOp c $ loop RIFast)
|
||||
where
|
||||
connErr = case rq_ of
|
||||
-- party initiating connection
|
||||
@@ -1022,6 +1022,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
_ -> connError msgId NOT_ACCEPTED
|
||||
AM_REPLY_ -> notifyDel msgId err
|
||||
AM_A_MSG_ -> notifyDel msgId err
|
||||
AM_QCONT_ -> notifyDel msgId err
|
||||
AM_QADD_ -> qError msgId "QADD: AUTH"
|
||||
AM_QKEY_ -> qError msgId "QKEY: AUTH"
|
||||
AM_QUSE_ -> qError msgId "QUSE: AUTH"
|
||||
@@ -1031,7 +1032,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
-- the message sending would be retried
|
||||
| temporaryOrHostError e -> do
|
||||
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
|
||||
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c loop)
|
||||
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c $ loop RIFast)
|
||||
| otherwise -> notifyDel msgId err
|
||||
where
|
||||
msgExpired timeoutSel = do
|
||||
@@ -1071,6 +1072,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
|
||||
AM_A_MSG_ -> notify $ SENT mId
|
||||
AM_QCONT_ -> pure ()
|
||||
AM_QADD_ -> pure ()
|
||||
AM_QKEY_ -> pure ()
|
||||
AM_QUSE_ -> pure ()
|
||||
@@ -1492,88 +1494,96 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
processSMP :: RcvQueue -> Connection c -> ConnData -> m ()
|
||||
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $
|
||||
case cmd of
|
||||
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do
|
||||
SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} <- decryptSMPMessage v rq msg
|
||||
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
|
||||
parseMessage msgBody
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
|
||||
case (e2eDhSecret, e2ePubKey_) of
|
||||
(Nothing, Just e2ePubKey) -> do
|
||||
let e2eDh = C.dh' e2ePubKey e2ePrivKey
|
||||
decryptClientMessage e2eDh clientMsg >>= \case
|
||||
(SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) ->
|
||||
smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack
|
||||
(SMP.PHEmpty, AgentInvitation {connReq, connInfo}) ->
|
||||
smpInvitation connReq connInfo >> ack
|
||||
_ -> prohibited >> ack
|
||||
(Just e2eDh, Nothing) -> do
|
||||
decryptClientMessage e2eDh clientMsg >>= \case
|
||||
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
|
||||
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
|
||||
let RcvQueue {primary, dbReplaceQueueId} = rq
|
||||
unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active
|
||||
case (conn, dbReplaceQueueId) of
|
||||
(DuplexConnection _ rqs _, Just replacedId) -> do
|
||||
when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq
|
||||
case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of
|
||||
Just RcvQueue {server, rcvId} -> do
|
||||
enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId
|
||||
_ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection"
|
||||
_ -> pure ()
|
||||
tryError agentClientMsg >>= \case
|
||||
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
|
||||
HELLO -> helloMsg >> ackDel msgId
|
||||
REPLY cReq -> replyMsg cReq >> ackDel msgId
|
||||
-- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command
|
||||
A_MSG body -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
QADD qs -> qDuplex "QADD" $ qAddMsg qs
|
||||
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
|
||||
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
|
||||
-- no action needed for QTEST
|
||||
-- any message in the new queue will mark it active and trigger deletion of the old queue
|
||||
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
|
||||
where
|
||||
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
|
||||
qDuplex name a = case conn of
|
||||
DuplexConnection {} -> a conn >> ackDel msgId
|
||||
_ -> qError $ name <> ": message must be sent to duplex connection"
|
||||
Right _ -> prohibited >> ack
|
||||
Left e@(AGENT A_DUPLICATE) -> do
|
||||
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
|
||||
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
|
||||
| userAck -> ackDel internalId
|
||||
| otherwise -> do
|
||||
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
|
||||
AgentMessage _ (A_MSG body) -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
_ -> pure ()
|
||||
_ -> throwError e
|
||||
Left e -> throwError e
|
||||
where
|
||||
agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage))
|
||||
agentClientMsg = withStore c $ \db -> runExceptT $ do
|
||||
agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg
|
||||
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
|
||||
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
|
||||
let msgType = agentMessageType agentMsg
|
||||
internalHash = C.sha256Hash agentMsgBody
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId
|
||||
let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash
|
||||
recipient = (unId internalId, internalTs)
|
||||
broker = (srvMsgId, systemToUTCTime srvTs)
|
||||
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
|
||||
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
|
||||
liftIO $ createRcvMsg db connId rq rcvMsg
|
||||
pure $ Just (internalId, msgMeta, aMessage)
|
||||
_ -> pure Nothing
|
||||
_ -> prohibited >> ack
|
||||
_ -> prohibited >> ack
|
||||
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
|
||||
handleNotifyAck $
|
||||
decryptSMPMessage v rq msg >>= \case
|
||||
SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody
|
||||
SMP.ClientRcvMsgQuota {} -> queueDrained >> ack
|
||||
where
|
||||
queueDrained = case conn of
|
||||
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq)
|
||||
_ -> pure ()
|
||||
processClientMsg srvTs msgFlags msgBody = do
|
||||
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
|
||||
parseMessage msgBody
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
|
||||
case (e2eDhSecret, e2ePubKey_) of
|
||||
(Nothing, Just e2ePubKey) -> do
|
||||
let e2eDh = C.dh' e2ePubKey e2ePrivKey
|
||||
decryptClientMessage e2eDh clientMsg >>= \case
|
||||
(SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) ->
|
||||
smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack
|
||||
(SMP.PHEmpty, AgentInvitation {connReq, connInfo}) ->
|
||||
smpInvitation connReq connInfo >> ack
|
||||
_ -> prohibited >> ack
|
||||
(Just e2eDh, Nothing) -> do
|
||||
decryptClientMessage e2eDh clientMsg >>= \case
|
||||
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
|
||||
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
|
||||
let RcvQueue {primary, dbReplaceQueueId} = rq
|
||||
unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active
|
||||
case (conn, dbReplaceQueueId) of
|
||||
(DuplexConnection _ rqs _, Just replacedId) -> do
|
||||
when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq
|
||||
case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of
|
||||
Just RcvQueue {server, rcvId} -> do
|
||||
enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId
|
||||
_ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection"
|
||||
_ -> pure ()
|
||||
tryError agentClientMsg >>= \case
|
||||
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
|
||||
HELLO -> helloMsg >> ackDel msgId
|
||||
REPLY cReq -> replyMsg cReq >> ackDel msgId
|
||||
-- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command
|
||||
A_MSG body -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
QCONT addr -> qDuplex "QCONT" $ continueSending addr
|
||||
QADD qs -> qDuplex "QADD" $ qAddMsg qs
|
||||
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
|
||||
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
|
||||
-- no action needed for QTEST
|
||||
-- any message in the new queue will mark it active and trigger deletion of the old queue
|
||||
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
|
||||
where
|
||||
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
|
||||
qDuplex name a = case conn of
|
||||
DuplexConnection {} -> a conn >> ackDel msgId
|
||||
_ -> qError $ name <> ": message must be sent to duplex connection"
|
||||
Right _ -> prohibited >> ack
|
||||
Left e@(AGENT A_DUPLICATE) -> do
|
||||
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
|
||||
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
|
||||
| userAck -> ackDel internalId
|
||||
| otherwise -> do
|
||||
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
|
||||
AgentMessage _ (A_MSG body) -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
_ -> pure ()
|
||||
_ -> throwError e
|
||||
Left e -> throwError e
|
||||
where
|
||||
agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage))
|
||||
agentClientMsg = withStore c $ \db -> runExceptT $ do
|
||||
agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg
|
||||
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
|
||||
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
|
||||
let msgType = agentMessageType agentMsg
|
||||
internalHash = C.sha256Hash agentMsgBody
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId
|
||||
let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash
|
||||
recipient = (unId internalId, internalTs)
|
||||
broker = (srvMsgId, systemToUTCTime srvTs)
|
||||
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
|
||||
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
|
||||
liftIO $ createRcvMsg db connId rq rcvMsg
|
||||
pure $ Just (internalId, msgMeta, aMessage)
|
||||
_ -> pure Nothing
|
||||
_ -> prohibited >> ack
|
||||
_ -> prohibited >> ack
|
||||
ack :: m ()
|
||||
ack = enqueueCmd $ ICAck rId srvMsgId
|
||||
ackDel :: InternalId -> m ()
|
||||
@@ -1698,6 +1708,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR)
|
||||
_ -> prohibited
|
||||
|
||||
continueSending :: (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
|
||||
continueSending addr (DuplexConnection _ _ sqs) =
|
||||
case findQ addr sqs of
|
||||
Just sq -> do
|
||||
logServer "<--" c srv rId "MSG <QCONT>"
|
||||
atomically $ do
|
||||
(_, qLock) <- getPendingMsgQ c sq
|
||||
void $ tryPutTMVar qLock ()
|
||||
Nothing -> qError "QCONT: queue address not found"
|
||||
|
||||
-- processed by queue sender
|
||||
qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m ()
|
||||
qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported"
|
||||
|
||||
@@ -180,7 +180,7 @@ data AgentClient = AgentClient
|
||||
activeSubs :: TRcvQueues,
|
||||
pendingSubs :: TRcvQueues,
|
||||
pendingMsgsQueued :: TMap SndQAddr Bool,
|
||||
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId),
|
||||
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId, TMVar ()),
|
||||
smpQueueMsgDeliveries :: TMap SndQAddr (Async ()),
|
||||
connCmdsQueued :: TMap ConnId Bool,
|
||||
asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId),
|
||||
|
||||
@@ -83,7 +83,7 @@ data AgentConfig = AgentConfig
|
||||
smpCfg :: ProtocolClientConfig,
|
||||
ntfCfg :: ProtocolClientConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
messageRetryInterval :: RetryInterval,
|
||||
messageRetryInterval :: RetryInterval2,
|
||||
messageTimeout :: NominalDiffTime,
|
||||
helloTimeout :: NominalDiffTime,
|
||||
ntfCron :: Word16,
|
||||
@@ -108,12 +108,24 @@ defaultReconnectInterval =
|
||||
maxInterval = 180_000000
|
||||
}
|
||||
|
||||
defaultMessageRetryInterval :: RetryInterval
|
||||
defaultMessageRetryInterval :: RetryInterval2
|
||||
defaultMessageRetryInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = 1_000000,
|
||||
increaseAfter = 10_000000,
|
||||
maxInterval = 60_000000
|
||||
RetryInterval2
|
||||
{ riFast =
|
||||
RetryInterval
|
||||
{ initialInterval = 1_000000,
|
||||
increaseAfter = 10_000000,
|
||||
maxInterval = 60_000000
|
||||
},
|
||||
riSlow =
|
||||
-- TODO: these timeouts can be increased once most clients are updates
|
||||
-- to resume sending on QCONT messages.
|
||||
-- After that local message expiration period should be also increased.
|
||||
RetryInterval
|
||||
{ initialInterval = 10_000000,
|
||||
increaseAfter = 30_000000,
|
||||
maxInterval = 300_000000
|
||||
}
|
||||
}
|
||||
|
||||
defaultAgentConfig :: AgentConfig
|
||||
|
||||
@@ -577,6 +577,7 @@ data AgentMessageType
|
||||
| AM_HELLO_
|
||||
| AM_REPLY_
|
||||
| AM_A_MSG_
|
||||
| AM_QCONT_
|
||||
| AM_QADD_
|
||||
| AM_QKEY_
|
||||
| AM_QUSE_
|
||||
@@ -590,6 +591,7 @@ instance Encoding AgentMessageType where
|
||||
AM_HELLO_ -> "H"
|
||||
AM_REPLY_ -> "R"
|
||||
AM_A_MSG_ -> "M"
|
||||
AM_QCONT_ -> "QC"
|
||||
AM_QADD_ -> "QA"
|
||||
AM_QKEY_ -> "QK"
|
||||
AM_QUSE_ -> "QU"
|
||||
@@ -603,6 +605,7 @@ instance Encoding AgentMessageType where
|
||||
'M' -> pure AM_A_MSG_
|
||||
'Q' ->
|
||||
A.anyChar >>= \case
|
||||
'C' -> pure AM_QCONT_
|
||||
'A' -> pure AM_QADD_
|
||||
'K' -> pure AM_QKEY_
|
||||
'U' -> pure AM_QUSE_
|
||||
@@ -623,6 +626,7 @@ agentMessageType = \case
|
||||
-- REPLY is only used in v1
|
||||
REPLY _ -> AM_REPLY_
|
||||
A_MSG _ -> AM_A_MSG_
|
||||
QCONT _ -> AM_QCONT_
|
||||
QADD _ -> AM_QADD_
|
||||
QKEY _ -> AM_QKEY_
|
||||
QUSE _ -> AM_QUSE_
|
||||
@@ -645,6 +649,7 @@ data AMsgType
|
||||
= HELLO_
|
||||
| REPLY_
|
||||
| A_MSG_
|
||||
| QCONT_
|
||||
| QADD_
|
||||
| QKEY_
|
||||
| QUSE_
|
||||
@@ -656,6 +661,7 @@ instance Encoding AMsgType where
|
||||
HELLO_ -> "H"
|
||||
REPLY_ -> "R"
|
||||
A_MSG_ -> "M"
|
||||
QCONT_ -> "QC"
|
||||
QADD_ -> "QA"
|
||||
QKEY_ -> "QK"
|
||||
QUSE_ -> "QU"
|
||||
@@ -667,6 +673,7 @@ instance Encoding AMsgType where
|
||||
'M' -> pure A_MSG_
|
||||
'Q' ->
|
||||
A.anyChar >>= \case
|
||||
'C' -> pure QCONT_
|
||||
'A' -> pure QADD_
|
||||
'K' -> pure QKEY_
|
||||
'U' -> pure QUSE_
|
||||
@@ -684,6 +691,8 @@ data AMessage
|
||||
REPLY (L.NonEmpty SMPQueueInfo)
|
||||
| -- | agent envelope for the client message
|
||||
A_MSG MsgBody
|
||||
| -- | the message instructing the client to continue sending messages (after ERR QUOTA)
|
||||
QCONT SndQAddr
|
||||
| -- add queue to connection (sent by recipient), with optional address of the replaced queue
|
||||
QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr))
|
||||
| -- key to secure the added queues and agree e2e encryption key (sent by sender)
|
||||
@@ -701,6 +710,7 @@ instance Encoding AMessage where
|
||||
HELLO -> smpEncode HELLO_
|
||||
REPLY smpQueues -> smpEncode (REPLY_, smpQueues)
|
||||
A_MSG body -> smpEncode (A_MSG_, Tail body)
|
||||
QCONT addr -> smpEncode (QCONT_, addr)
|
||||
QADD qs -> smpEncode (QADD_, qs)
|
||||
QKEY qs -> smpEncode (QKEY_, qs)
|
||||
QUSE qs -> smpEncode (QUSE_, qs)
|
||||
@@ -711,6 +721,7 @@ instance Encoding AMessage where
|
||||
HELLO_ -> pure HELLO
|
||||
REPLY_ -> REPLY <$> smpP
|
||||
A_MSG_ -> A_MSG . unTail <$> smpP
|
||||
QCONT_ -> QCONT <$> smpP
|
||||
QADD_ -> QADD <$> smpP
|
||||
QKEY_ -> QKEY <$> smpP
|
||||
QUSE_ -> QUSE <$> smpP
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Agent.RetryInterval where
|
||||
module Simplex.Messaging.Agent.RetryInterval
|
||||
( RetryInterval (..),
|
||||
RetryInterval2 (..),
|
||||
RetryIntervalMode (..),
|
||||
withRetryInterval,
|
||||
withRetryLock2,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Concurrent (forkIO, threadDelay)
|
||||
import Control.Monad (void)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Simplex.Messaging.Util (whenM)
|
||||
import UnliftIO.STM
|
||||
|
||||
data RetryInterval = RetryInterval
|
||||
{ initialInterval :: Int,
|
||||
@@ -12,17 +23,51 @@ data RetryInterval = RetryInterval
|
||||
maxInterval :: Int
|
||||
}
|
||||
|
||||
data RetryInterval2 = RetryInterval2
|
||||
{ riSlow :: RetryInterval,
|
||||
riFast :: RetryInterval
|
||||
}
|
||||
|
||||
data RetryIntervalMode = RISlow | RIFast
|
||||
deriving (Eq)
|
||||
|
||||
withRetryInterval :: forall m. MonadIO m => RetryInterval -> (m () -> m ()) -> m ()
|
||||
withRetryInterval RetryInterval {initialInterval, increaseAfter, maxInterval} action =
|
||||
callAction 0 initialInterval
|
||||
withRetryInterval ri action = callAction 0 $ initialInterval ri
|
||||
where
|
||||
callAction :: Int -> Int -> m ()
|
||||
callAction elapsedTime delay = action loop
|
||||
callAction elapsed delay = action loop
|
||||
where
|
||||
loop = do
|
||||
let newDelay =
|
||||
if elapsedTime < increaseAfter || delay == maxInterval
|
||||
then delay
|
||||
else min (delay * 3 `div` 2) maxInterval
|
||||
liftIO $ threadDelay delay
|
||||
callAction (elapsedTime + delay) newDelay
|
||||
let elapsed' = elapsed + delay
|
||||
callAction elapsed' $ nextDelay elapsed' delay ri
|
||||
|
||||
-- This function allows action to toggle between slow and fast retry intervals.
|
||||
withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> ((RetryIntervalMode -> m ()) -> m ()) -> m ()
|
||||
withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
|
||||
callAction (0, initialInterval riSlow) (0, initialInterval riFast)
|
||||
where
|
||||
callAction :: (Int, Int) -> (Int, Int) -> m ()
|
||||
callAction slow fast = action loop
|
||||
where
|
||||
loop = \case
|
||||
RISlow -> run slow riSlow (`callAction` fast)
|
||||
RIFast -> run fast riFast (callAction slow)
|
||||
run (elapsed, delay) ri call = do
|
||||
wait delay
|
||||
let elapsed' = elapsed + delay
|
||||
call (elapsed', nextDelay elapsed' delay ri)
|
||||
wait delay = do
|
||||
waiting <- newTVarIO True
|
||||
_ <- liftIO . forkIO $ do
|
||||
threadDelay delay
|
||||
atomically $ whenM (readTVar waiting) $ void $ tryPutTMVar lock ()
|
||||
atomically $ do
|
||||
takeTMVar lock
|
||||
writeTVar waiting False
|
||||
|
||||
nextDelay :: Int -> Int -> RetryInterval -> Int
|
||||
nextDelay elapsed delay RetryInterval {increaseAfter, maxInterval} =
|
||||
if elapsed < increaseAfter || delay == maxInterval
|
||||
then delay
|
||||
else min (delay * 3 `div` 2) maxInterval
|
||||
|
||||
@@ -141,6 +141,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isPrint, isSpace)
|
||||
import Data.Functor (($>))
|
||||
import Data.Kind
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
@@ -303,12 +304,17 @@ data RcvMessage = RcvMessage
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | received message without server/recipient encryption
|
||||
data Message = Message
|
||||
{ msgId :: MsgId,
|
||||
msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: C.MaxLenBS MaxMessageLen
|
||||
}
|
||||
data Message
|
||||
= Message
|
||||
{ msgId :: MsgId,
|
||||
msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: C.MaxLenBS MaxMessageLen
|
||||
}
|
||||
| MessageQuota
|
||||
{ msgId :: MsgId,
|
||||
msgTs :: SystemTime
|
||||
}
|
||||
|
||||
instance StrEncoding RcvMessage where
|
||||
strEncode RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} =
|
||||
@@ -328,44 +334,72 @@ instance StrEncoding RcvMessage where
|
||||
newtype EncRcvMsgBody = EncRcvMsgBody ByteString
|
||||
deriving (Eq, Show)
|
||||
|
||||
data RcvMsgBody = RcvMsgBody
|
||||
{ msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: C.MaxLenBS MaxMessageLen
|
||||
}
|
||||
data RcvMsgBody
|
||||
= RcvMsgBody
|
||||
{ msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: C.MaxLenBS MaxMessageLen
|
||||
}
|
||||
| RcvMsgQuota
|
||||
{ msgTs :: SystemTime
|
||||
}
|
||||
|
||||
msgQuotaTag :: ByteString
|
||||
msgQuotaTag = "QUOTA"
|
||||
|
||||
encodeRcvMsgBody :: RcvMsgBody -> C.MaxLenBS MaxRcvMessageLen
|
||||
encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} =
|
||||
let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ')
|
||||
in C.appendMaxLenBS rcvMeta msgBody
|
||||
encodeRcvMsgBody = \case
|
||||
RcvMsgBody {msgTs, msgFlags, msgBody} ->
|
||||
let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ')
|
||||
in C.appendMaxLenBS rcvMeta msgBody
|
||||
RcvMsgQuota {msgTs} ->
|
||||
C.unsafeMaxLenBS $ msgQuotaTag <> " " <> smpEncode msgTs
|
||||
|
||||
data ClientRcvMsgBody = ClientRcvMsgBody
|
||||
{ msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: ByteString
|
||||
}
|
||||
data ClientRcvMsgBody
|
||||
= ClientRcvMsgBody
|
||||
{ msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: ByteString
|
||||
}
|
||||
| ClientRcvMsgQuota
|
||||
{ msgTs :: SystemTime
|
||||
}
|
||||
|
||||
clientRcvMsgBodyP :: Parser ClientRcvMsgBody
|
||||
clientRcvMsgBodyP = do
|
||||
msgTs <- smpP
|
||||
msgFlags <- smpP
|
||||
Tail msgBody <- _smpP
|
||||
pure ClientRcvMsgBody {msgTs, msgFlags, msgBody}
|
||||
clientRcvMsgBodyP = msgQuotaP <|> msgBodyP
|
||||
where
|
||||
msgQuotaP = A.string msgQuotaTag *> (ClientRcvMsgQuota <$> _smpP)
|
||||
msgBodyP = do
|
||||
msgTs <- smpP
|
||||
msgFlags <- smpP
|
||||
Tail msgBody <- _smpP
|
||||
pure ClientRcvMsgBody {msgTs, msgFlags, msgBody}
|
||||
|
||||
instance StrEncoding Message where
|
||||
strEncode Message {msgId, msgTs, msgFlags, msgBody} =
|
||||
B.unwords
|
||||
[ strEncode msgId,
|
||||
strEncode msgTs,
|
||||
"flags=" <> strEncode msgFlags,
|
||||
strEncode msgBody
|
||||
]
|
||||
strEncode = \case
|
||||
Message {msgId, msgTs, msgFlags, msgBody} ->
|
||||
B.unwords
|
||||
[ strEncode msgId,
|
||||
strEncode msgTs,
|
||||
"flags=" <> strEncode msgFlags,
|
||||
strEncode msgBody
|
||||
]
|
||||
MessageQuota {msgId, msgTs} ->
|
||||
B.unwords
|
||||
[ strEncode msgId,
|
||||
strEncode msgTs,
|
||||
"quota"
|
||||
]
|
||||
strP = do
|
||||
msgId <- strP_
|
||||
msgTs <- strP_
|
||||
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
|
||||
msgBody <- strP
|
||||
pure Message {msgId, msgTs, msgFlags, msgBody}
|
||||
msgQuotaP msgId msgTs <|> msgP msgId msgTs
|
||||
where
|
||||
msgQuotaP msgId msgTs = "quota" $> MessageQuota {msgId, msgTs}
|
||||
msgP msgId msgTs = do
|
||||
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
|
||||
msgBody <- strP
|
||||
pure Message {msgId, msgTs, msgFlags, msgBody}
|
||||
|
||||
type EncNMsgMeta = ByteString
|
||||
|
||||
@@ -377,7 +411,9 @@ data SMPMsgMeta = SMPMsgMeta
|
||||
deriving (Show)
|
||||
|
||||
rcvMessageMeta :: MsgId -> ClientRcvMsgBody -> SMPMsgMeta
|
||||
rcvMessageMeta msgId ClientRcvMsgBody {msgTs, msgFlags} = SMPMsgMeta {msgId, msgTs, msgFlags}
|
||||
rcvMessageMeta msgId = \case
|
||||
ClientRcvMsgBody {msgTs, msgFlags} -> SMPMsgMeta {msgId, msgTs, msgFlags}
|
||||
ClientRcvMsgQuota {msgTs} -> SMPMsgMeta {msgId, msgTs, msgFlags = noMsgFlags}
|
||||
|
||||
data NMsgMeta = NMsgMeta
|
||||
{ msgId :: MsgId,
|
||||
|
||||
@@ -538,20 +538,22 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
| otherwise = case status qr of
|
||||
QueueOff -> return $ err AUTH
|
||||
QueueActive ->
|
||||
mapM mkMessage (C.maxLenBS msgBody) >>= \case
|
||||
case C.maxLenBS msgBody of
|
||||
Left _ -> pure $ err LARGE_MSG
|
||||
Right msg -> do
|
||||
resp@(_, _, sent) <- time "SEND" $ do
|
||||
Right body -> do
|
||||
msg_ <- time "SEND" $ do
|
||||
q <- getStoreMsgQueue "SEND" $ recipientId qr
|
||||
expireMessages q
|
||||
atomically $ ifM (isFull q) (pure $ err QUOTA) (writeMsg q msg $> ok)
|
||||
when (sent == OK) . time "SEND ok" $ do
|
||||
when (notification msgFlags) $
|
||||
atomically . trySendNotification msg =<< asks idsDrg
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgSent stats) (+ 1)
|
||||
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
|
||||
pure resp
|
||||
atomically . writeMsg q =<< mkMessage body
|
||||
case msg_ of
|
||||
Nothing -> pure $ err QUOTA
|
||||
Just msg -> time "SEND ok" $ do
|
||||
when (notification msgFlags) $
|
||||
atomically . trySendNotification msg =<< asks idsDrg
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgSent stats) (+ 1)
|
||||
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
|
||||
pure ok
|
||||
where
|
||||
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
|
||||
mkMessage body = do
|
||||
@@ -572,12 +574,14 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
|
||||
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
|
||||
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
|
||||
unlessM (isFullTBQueue q) $ do
|
||||
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
|
||||
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
|
||||
unlessM (isFullTBQueue q) $ case msg of
|
||||
Message {msgId, msgTs} -> do
|
||||
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg
|
||||
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
|
||||
_ -> pure ()
|
||||
|
||||
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
|
||||
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
|
||||
mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
|
||||
mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg = do
|
||||
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
|
||||
let msgMeta = NMsgMeta {msgId, msgTs}
|
||||
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
|
||||
@@ -615,17 +619,22 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
time name = timed name queueId
|
||||
|
||||
encryptMsg :: QueueRec -> Message -> RcvMessage
|
||||
encryptMsg qr Message {msgId, msgTs, msgFlags, msgBody}
|
||||
| thVersion == 1 || thVersion == 2 = encrypt msgBody
|
||||
| otherwise = encrypt $ encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody}
|
||||
encryptMsg qr msg = case msg of
|
||||
Message {msgFlags, msgBody}
|
||||
| thVersion == 1 || thVersion == 2 -> encrypt msgFlags msgBody
|
||||
| otherwise -> encrypt msgFlags $ encodeRcvMsgBody RcvMsgBody {msgTs = msgTs', msgFlags, msgBody}
|
||||
MessageQuota {} ->
|
||||
encrypt noMsgFlags $ encodeRcvMsgBody (RcvMsgQuota msgTs')
|
||||
where
|
||||
encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage
|
||||
encrypt body =
|
||||
let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId) body
|
||||
in RcvMessage msgId msgTs msgFlags encBody
|
||||
encrypt :: KnownNat i => MsgFlags -> C.MaxLenBS i -> RcvMessage
|
||||
encrypt msgFlags body =
|
||||
let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body
|
||||
in RcvMessage msgId' msgTs' msgFlags encBody
|
||||
msgId' = msgId (msg :: Message)
|
||||
msgTs' = msgTs (msg :: Message)
|
||||
|
||||
setDelivered :: Sub -> Message -> STM Bool
|
||||
setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId
|
||||
setDelivered s msg = tryPutTMVar (delivered s) $ msgId (msg :: Message)
|
||||
|
||||
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
|
||||
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
|
||||
@@ -717,7 +726,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages
|
||||
addToMsgQueue rId msg = do
|
||||
full <- atomically $ do
|
||||
q <- getMsgQueue ms rId quota
|
||||
ifM (isFull q) (pure True) (writeMsg q msg $> False)
|
||||
isNothing <$> writeMsg q msg
|
||||
when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message))
|
||||
updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do
|
||||
let nonce = C.cbNonce msgId
|
||||
|
||||
@@ -39,7 +39,7 @@ data ServerConfig = ServerConfig
|
||||
{ transports :: [(ServiceName, ATransport)],
|
||||
tbqSize :: Natural,
|
||||
serverTbqSize :: Natural,
|
||||
msgQueueQuota :: Natural,
|
||||
msgQueueQuota :: Int,
|
||||
queueIdBytes :: Int,
|
||||
msgIdBytes :: Int,
|
||||
storeLogFile :: Maybe FilePath,
|
||||
|
||||
@@ -19,13 +19,12 @@ instance StrEncoding MsgLogRecord where
|
||||
strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP
|
||||
|
||||
class MonadMsgStore s q m | s -> q where
|
||||
getMsgQueue :: s -> RecipientId -> Natural -> m q
|
||||
getMsgQueue :: s -> RecipientId -> Int -> m q
|
||||
delMsgQueue :: s -> RecipientId -> m ()
|
||||
flushMsgQueue :: s -> RecipientId -> m [Message]
|
||||
|
||||
class MonadMsgQueue q m where
|
||||
isFull :: q -> m Bool
|
||||
writeMsg :: q -> Message -> m () -- non blocking
|
||||
writeMsg :: q -> Message -> m (Maybe Message) -- non blocking
|
||||
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
|
||||
peekMsg :: q -> m Message -- blocking
|
||||
tryDelMsg :: q -> MsgId -> m Bool -- non blocking
|
||||
|
||||
@@ -7,22 +7,31 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Server.MsgStore.STM where
|
||||
module Simplex.Messaging.Server.MsgStore.STM
|
||||
( STMMsgStore,
|
||||
MsgQueue,
|
||||
newMsgStore,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM.TBQueue (flushTBQueue)
|
||||
import Control.Concurrent.STM.TQueue (flushTQueue)
|
||||
import Control.Monad (when)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Time.Clock.System (SystemTime (systemSeconds))
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId)
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import UnliftIO.STM
|
||||
|
||||
newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message}
|
||||
data MsgQueue = MsgQueue
|
||||
{ msgQueue :: TQueue Message,
|
||||
quota :: Int,
|
||||
canWrite :: TVar Bool,
|
||||
size :: TVar Int
|
||||
}
|
||||
|
||||
type STMMsgStore = TMap RecipientId MsgQueue
|
||||
|
||||
@@ -30,54 +39,77 @@ newMsgStore :: STM STMMsgStore
|
||||
newMsgStore = TM.empty
|
||||
|
||||
instance MonadMsgStore STMMsgStore MsgQueue STM where
|
||||
getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue
|
||||
getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue
|
||||
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
|
||||
where
|
||||
newQ = do
|
||||
q <- MsgQueue <$> newTBQueue quota
|
||||
msgQueue <- newTQueue
|
||||
canWrite <- newTVar True
|
||||
size <- newTVar 0
|
||||
let q = MsgQueue {msgQueue, quota, canWrite, size}
|
||||
TM.insert rId q st
|
||||
return q
|
||||
pure q
|
||||
|
||||
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
|
||||
delMsgQueue st rId = TM.delete rId st
|
||||
|
||||
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
|
||||
flushMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (flushTBQueue . msgQueue)
|
||||
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
|
||||
|
||||
instance MonadMsgQueue MsgQueue STM where
|
||||
isFull :: MsgQueue -> STM Bool
|
||||
isFull = isFullTBQueue . msgQueue
|
||||
|
||||
writeMsg :: MsgQueue -> Message -> STM ()
|
||||
writeMsg = writeTBQueue . msgQueue
|
||||
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
|
||||
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
|
||||
canWrt <- readTVar canWrite
|
||||
empty <- isEmptyTQueue q
|
||||
if canWrt || empty
|
||||
then do
|
||||
canWrt' <- (quota >) <$> readTVar size
|
||||
writeTVar canWrite canWrt'
|
||||
modifyTVar' size (+ 1)
|
||||
if canWrt'
|
||||
then writeTQueue q msg $> Just msg
|
||||
else writeTQueue q msgQuota $> Nothing
|
||||
else pure Nothing
|
||||
where
|
||||
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
|
||||
|
||||
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
|
||||
tryPeekMsg = tryPeekTBQueue . msgQueue
|
||||
tryPeekMsg = tryPeekTQueue . msgQueue
|
||||
{-# INLINE tryPeekMsg #-}
|
||||
|
||||
peekMsg :: MsgQueue -> STM Message
|
||||
peekMsg = peekTBQueue . msgQueue
|
||||
peekMsg = peekTQueue . msgQueue
|
||||
{-# INLINE peekMsg #-}
|
||||
|
||||
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
|
||||
tryDelMsg (MsgQueue q) msgId' =
|
||||
tryPeekTBQueue q >>= \case
|
||||
Just Message {msgId}
|
||||
| msgId == msgId' || B.null msgId' -> tryReadTBQueue q $> True
|
||||
tryDelMsg mq msgId' =
|
||||
tryPeekMsg mq >>= \case
|
||||
Just msg
|
||||
| msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True
|
||||
| otherwise -> pure False
|
||||
_ -> pure False
|
||||
|
||||
-- atomic delete (== read) last and peek next message if available
|
||||
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
|
||||
tryDelPeekMsg (MsgQueue q) msgId' =
|
||||
tryPeekTBQueue q >>= \case
|
||||
msg_@(Just Message {msgId})
|
||||
| msgId == msgId' || B.null msgId' -> (True,) <$> (tryReadTBQueue q >> tryPeekTBQueue q)
|
||||
tryDelPeekMsg mq msgId' =
|
||||
tryPeekMsg mq >>= \case
|
||||
msg_@(Just msg)
|
||||
| msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq)
|
||||
| otherwise -> pure (False, msg_)
|
||||
_ -> pure (False, Nothing)
|
||||
|
||||
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
|
||||
deleteExpiredMsgs (MsgQueue q) old = loop
|
||||
deleteExpiredMsgs mq old = loop
|
||||
where
|
||||
loop = tryPeekTBQueue q >>= mapM_ delOldMsg
|
||||
delOldMsg Message {msgTs} =
|
||||
when (systemSeconds msgTs < old) $
|
||||
tryReadTBQueue q >> loop
|
||||
loop = tryPeekMsg mq >>= mapM_ delOldMsg
|
||||
delOldMsg = \case
|
||||
Message {msgTs} ->
|
||||
when (systemSeconds msgTs < old) $
|
||||
tryDeleteMsg mq >> loop
|
||||
_ -> pure ()
|
||||
|
||||
tryDeleteMsg :: MsgQueue -> STM ()
|
||||
tryDeleteMsg MsgQueue {msgQueue = q, size} =
|
||||
tryReadTQueue q >>= \case
|
||||
Just _ -> modifyTVar' size (subtract 1)
|
||||
_ -> pure ()
|
||||
|
||||
@@ -78,6 +78,8 @@ agentTests (ATransport t) = do
|
||||
smpAgentTest2_2_1 $ testConcurrentMsgDelivery t
|
||||
it "should deliver messages if one of connections has quota exceeded" $
|
||||
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
|
||||
it "should resume delivering messages after exceeding quota once all messages are received" $
|
||||
smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t
|
||||
|
||||
tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent)
|
||||
tGetAgent h = do
|
||||
@@ -430,6 +432,32 @@ testMsgDeliveryQuotaExceeded _ alice bob = do
|
||||
-- if delivery is blocked it won't go further
|
||||
alice <# ("", "bob2", SENT 4)
|
||||
|
||||
testResumeDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testResumeDeliveryQuotaExceeded _ alice bob = do
|
||||
connect (alice, "alice") (bob, "bob")
|
||||
forM_ [1 .. 4 :: Int] $ \i -> do
|
||||
let corrId = bshow i
|
||||
msg = "message " <> bshow i
|
||||
(_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg)
|
||||
alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False
|
||||
("5", "bob", Right (MID 8)) <- alice #: ("5", "bob", "SEND F :over quota")
|
||||
alice #:# "the last message not sent yet"
|
||||
bob <#= \case ("", "alice", Msg "message 1") -> True; _ -> False
|
||||
bob #: ("1", "alice", "ACK 4") #> ("1", "alice", OK)
|
||||
alice #:# "the last message not sent"
|
||||
bob <#= \case ("", "alice", Msg "message 2") -> True; _ -> False
|
||||
bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK)
|
||||
alice #:# "the last message not sent"
|
||||
bob <#= \case ("", "alice", Msg "message 3") -> True; _ -> False
|
||||
bob #: ("3", "alice", "ACK 6") #> ("3", "alice", OK)
|
||||
alice #:# "the last message not sent"
|
||||
bob <#= \case ("", "alice", Msg "message 4") -> True; _ -> False
|
||||
bob #: ("4", "alice", "ACK 7") #> ("4", "alice", OK)
|
||||
alice <# ("", "bob", SENT 8)
|
||||
bob <#= \case ("", "alice", Msg "over quota") -> True; _ -> False
|
||||
-- message 8 is skipped because of alice agent sending "QCONT" message
|
||||
bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK)
|
||||
|
||||
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = do
|
||||
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV")
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module CoreTests.RetryIntervalTests where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (when)
|
||||
import Data.Time.Clock (UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds)
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Test.Hspec
|
||||
|
||||
retryIntervalTests :: Spec
|
||||
retryIntervalTests = do
|
||||
describe "Retry interval with 2 modes and lock" $ do
|
||||
testRetryIntervalSameMode
|
||||
testRetryIntervalSwitchMode
|
||||
|
||||
testRI :: RetryInterval2
|
||||
testRI =
|
||||
RetryInterval2
|
||||
{ riSlow =
|
||||
RetryInterval
|
||||
{ initialInterval = 20000,
|
||||
increaseAfter = 40000,
|
||||
maxInterval = 40000
|
||||
},
|
||||
riFast =
|
||||
RetryInterval
|
||||
{ initialInterval = 10000,
|
||||
increaseAfter = 20000,
|
||||
maxInterval = 40000
|
||||
}
|
||||
}
|
||||
|
||||
testRetryIntervalSameMode :: Spec
|
||||
testRetryIntervalSameMode =
|
||||
it "should increase elapased time and interval when the mode stays the same" $ do
|
||||
lock <- newEmptyTMVarIO
|
||||
intervals <- newTVarIO []
|
||||
ts <- newTVarIO =<< getCurrentTime
|
||||
withRetryLock2 testRI lock $ \loop -> do
|
||||
ints <- addInterval intervals ts
|
||||
when (length ints < 9) $ loop RIFast
|
||||
(reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4, 4]
|
||||
|
||||
testRetryIntervalSwitchMode :: Spec
|
||||
testRetryIntervalSwitchMode =
|
||||
it "should increase elapased time and interval when the mode stays the same" $ do
|
||||
lock <- newEmptyTMVarIO
|
||||
intervals <- newTVarIO []
|
||||
ts <- newTVarIO =<< getCurrentTime
|
||||
withRetryLock2 testRI lock $ \loop -> do
|
||||
ints <- addInterval intervals ts
|
||||
when (length ints < 11) $ loop $ if length ints <= 5 then RIFast else RISlow
|
||||
(reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 2, 2, 3, 4, 4]
|
||||
|
||||
addInterval :: TVar [Int] -> TVar UTCTime -> IO [Int]
|
||||
addInterval intervals ts = do
|
||||
ts' <- getCurrentTime
|
||||
atomically $ do
|
||||
int :: Int <- truncate . (* 100) . nominalDiffTimeToSeconds <$> stateTVar ts (\t -> (diffUTCTime ts' t, ts'))
|
||||
stateTVar intervals $ \ints -> (int : ints, int : ints)
|
||||
+32
-3
@@ -49,6 +49,7 @@ serverTests t@(ATransport t') = do
|
||||
describe "switch subscription to another TCP connection" $ testSwitchSub t
|
||||
describe "GET command" $ testGetCommand t'
|
||||
describe "GET & SUB commands" $ testGetSubCommands t'
|
||||
describe "Exceeding queue quota" $ testExceedQueueQuota t'
|
||||
describe "Store log" $ testWithStoreLog t
|
||||
describe "Restore messages" $ testRestoreMessages t
|
||||
describe "Restore messages (v2)" $ testRestoreMessagesV2 t
|
||||
@@ -104,9 +105,11 @@ decryptMsgV2 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either C.Cry
|
||||
decryptMsgV2 dhShared = C.cbDecrypt dhShared . C.cbNonce
|
||||
|
||||
decryptMsgV3 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either String MsgBody
|
||||
decryptMsgV3 dhShared nonce body = do
|
||||
ClientRcvMsgBody {msgBody} <- parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body)
|
||||
pure msgBody
|
||||
decryptMsgV3 dhShared nonce body =
|
||||
case parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body) of
|
||||
Right ClientRcvMsgBody {msgBody} -> Right msgBody
|
||||
Right ClientRcvMsgQuota {} -> Left "ClientRcvMsgQuota"
|
||||
Left e -> Left e
|
||||
|
||||
testCreateSecureV2 :: forall c. Transport c => TProxy c -> Spec
|
||||
testCreateSecureV2 _ =
|
||||
@@ -494,6 +497,32 @@ testGetSubCommands t =
|
||||
Resp "12" _ OK <- signSendRecv rh2 rKey ("12", rId, GET)
|
||||
pure ()
|
||||
|
||||
testExceedQueueQuota :: forall c. Transport c => TProxy c -> Spec
|
||||
testExceedQueueQuota t =
|
||||
it "should reply with ERR QUOTA to sender and send QUOTA message to the recipient" $ do
|
||||
withSmpServerConfigOn (ATransport t) cfg {msgQueueQuota = 2} testPort $ \_ ->
|
||||
testSMPClient @c $ \sh -> testSMPClient @c $ \rh -> do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
(sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub
|
||||
let dec = decryptMsgV3 dhShared
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello 1")
|
||||
Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, _SEND "hello 2")
|
||||
Resp "3" _ (ERR QUOTA) <- signSendRecv sh sKey ("3", sId, _SEND "hello 3")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 rh
|
||||
(dec mId1 msg1, Right "hello 1") #== "hello 1"
|
||||
Resp "4" _ (Msg mId2 msg2) <- signSendRecv rh rKey ("4", rId, ACK mId1)
|
||||
(dec mId2 msg2, Right "hello 2") #== "hello 2"
|
||||
Resp "5" _ (ERR QUOTA) <- signSendRecv sh sKey ("5", sId, _SEND "hello 3")
|
||||
Resp "6" _ (Msg mId3 msg3) <- signSendRecv rh rKey ("6", rId, ACK mId2)
|
||||
(dec mId3 msg3, Left "ClientRcvMsgQuota") #== "ClientRcvMsgQuota"
|
||||
Resp "7" _ (ERR QUOTA) <- signSendRecv sh sKey ("7", sId, _SEND "hello 3")
|
||||
Resp "8" _ OK <- signSendRecv rh rKey ("8", rId, ACK mId3)
|
||||
Resp "9" _ OK <- signSendRecv sh sKey ("9", sId, _SEND "hello 3")
|
||||
Resp "" _ (Msg mId4 msg4) <- tGet1 rh
|
||||
(dec mId4 msg4, Right "hello 3") #== "hello 3"
|
||||
Resp "10" _ OK <- signSendRecv rh rKey ("10", rId, ACK mId4)
|
||||
pure ()
|
||||
|
||||
testWithStoreLog :: ATransport -> Spec
|
||||
testWithStoreLog at@(ATransport t) =
|
||||
it "should store simplex queues to log and restore them after server restart" $ do
|
||||
|
||||
@@ -6,6 +6,7 @@ import CLITests
|
||||
import CoreTests.CryptoTests
|
||||
import CoreTests.EncodingTests
|
||||
import CoreTests.ProtocolErrorTests
|
||||
import CoreTests.RetryIntervalTests
|
||||
import CoreTests.VersionRangeTests
|
||||
import NtfServerTests (ntfServerTests)
|
||||
import ServerTests
|
||||
@@ -31,6 +32,7 @@ main = do
|
||||
describe "Protocol error tests" protocolErrorTests
|
||||
describe "Version range" versionRangeTests
|
||||
describe "Encryption tests" cryptoTests
|
||||
describe "Retry interval tests" retryIntervalTests
|
||||
describe "SMP server via TLS" $ serverTests (transport @TLS)
|
||||
describe "SMP server via WebSockets" $ serverTests (transport @WS)
|
||||
describe "Notifications server" $ ntfServerTests (transport @TLS)
|
||||
|
||||
Reference in New Issue
Block a user