add msgId to ACK to avoid the risks of losing messages with concurrent delivery (in app/NSE) (#387)

* add msgId to ACK to avoid the risks of losing messages with concurrent delivery (in app/NSE)

* update ACK to only remove message and update stats if msgId matches

* add tests, fix

* rename sameMsgId/msgDeleted
This commit is contained in:
Evgeny Poberezkin
2022-06-07 10:18:40 +01:00
committed by GitHub
parent 4b3d04bd27
commit 60294521f4
11 changed files with 238 additions and 131 deletions
+7 -8
View File
@@ -508,8 +508,8 @@ ackMessage' c connId msgId = do
ack :: RcvQueue -> m ()
ack rq = do
let mId = InternalId msgId
withStore $ \st -> checkRcvMsg st connId mId
sendAck c rq
srvMsgId <- withStore $ \st -> checkRcvMsg st connId mId
sendAck c rq srvMsgId
withStore $ \st -> deleteMsg st connId mId
-- | Suspend SMP agent connection (OFF command) in Reader monad
@@ -711,6 +711,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, sessId, rId, cmd)
_ -> prohibited >> ack
_ -> prohibited >> ack
_ -> prohibited >> ack
where
ack :: m ()
ack = sendAck c rq srvMsgId
handleNotifyAck :: m () -> m ()
handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack
SMP.END ->
atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND)
>>= logServer "<--" c srv rId
@@ -731,15 +736,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, sessId, rId, cmd)
notify :: ACommand 'Agent -> m ()
notify msg = atomically $ writeTBQueue subQ ("", connId, msg)
handleNotifyAck :: m () -> m ()
handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack
prohibited :: m ()
prohibited = notify . ERR $ AGENT A_PROHIBITED
ack :: m ()
ack = sendAck c rq
decryptClientMessage :: C.DhSecretX25519 -> SMP.ClientMsgEnvelope -> m (SMP.PrivHeader, AgentMsgEnvelope)
decryptClientMessage e2eDh SMP.ClientMsgEnvelope {cmNonce, cmEncBody} = do
clientMsg <- agentCbDecrypt e2eDh cmNonce cmEncBody
+4 -4
View File
@@ -71,7 +71,7 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
@@ -486,10 +486,10 @@ secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
withLogClient c server rcvId "KEY <key>" $ \smp ->
secureSMPQueue smp rcvPrivateKey rcvId senderKey
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> m ()
sendAck c RcvQueue {server, rcvId, rcvPrivateKey} =
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
sendAck c RcvQueue {server, rcvId, rcvPrivateKey} msgId =
withLogClient c server rcvId "ACK" $ \smp ->
ackSMPMessage smp rcvPrivateKey rcvId
ackSMPMessage smp rcvPrivateKey rcvId msgId
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
+1 -1
View File
@@ -67,7 +67,7 @@ class Monad m => MonadAgentStore s m where
createSndMsg :: s -> ConnId -> SndMsgData -> m ()
getPendingMsgData :: s -> ConnId -> InternalId -> m (Maybe RcvQueue, PendingMsgData)
getPendingMsgs :: s -> ConnId -> m [InternalId]
checkRcvMsg :: s -> ConnId -> InternalId -> m ()
checkRcvMsg :: s -> ConnId -> InternalId -> m MsgId
deleteMsg :: s -> ConnId -> InternalId -> m ()
-- Double ratchet persistence
+4 -7
View File
@@ -485,21 +485,18 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
map fromOnly
<$> DB.query db "SELECT internal_id FROM snd_messages WHERE conn_id = ?" (Only connId)
checkRcvMsg :: SQLiteStore -> ConnId -> InternalId -> m ()
checkRcvMsg :: SQLiteStore -> ConnId -> InternalId -> m SMP.MsgId
checkRcvMsg st connId msgId =
liftIOEither . withTransaction st $ \db ->
hasMsg
<$> DB.query
firstRow fromOnly SEMsgNotFound $
DB.query
db
[sql|
SELECT conn_id, internal_id
SELECT broker_id
FROM rcv_messages
WHERE conn_id = ? AND internal_id = ?
|]
(connId, msgId)
where
hasMsg :: [(ConnId, InternalId)] -> Either StoreError ()
hasMsg r = if null r then Left SEMsgNotFound else Right ()
deleteMsg :: SQLiteStore -> ConnId -> InternalId -> m ()
deleteMsg st connId msgId =
+3 -3
View File
@@ -317,9 +317,9 @@ sendSMPMessage c spKey sId flags msg =
-- | Acknowledge message delivery (server deletes the message).
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
ackSMPMessage c@ProtocolClient {protocolServer, sessionId, msgQ} rpKey rId =
sendSMPCommand c (Just rpKey) rId ACK >>= \case
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT ProtocolClientError IO ()
ackSMPMessage c@ProtocolClient {protocolServer, sessionId, msgQ} rpKey rId msgId =
sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case
OK -> return ()
cmd@MSG {} ->
lift . atomically $ mapM_ (`writeTBQueue` (protocolServer, sessionId, rId, cmd)) msgQ
+9 -3
View File
@@ -217,7 +217,9 @@ data Command (p :: Party) where
KEY :: SndPublicVerifyKey -> Command Recipient
NKEY :: NtfPublicVerifyKey -> Command Recipient
GET :: Command Recipient
ACK :: Command Recipient
-- ACK v1 has to be supported for encoding/decoding
-- ACK :: Command Recipient
ACK :: MsgId -> Command Recipient
OFF :: Command Recipient
DEL :: Command Recipient
-- SMP sender commands
@@ -608,7 +610,9 @@ instance PartyI p => ProtocolEncoding (Command p) where
KEY k -> e (KEY_, ' ', k)
NKEY k -> e (NKEY_, ' ', k)
GET -> e GET_
ACK -> e ACK_
ACK msgId
| v == 1 -> e ACK_
| otherwise -> e (ACK_, ' ', msgId)
OFF -> e OFF_
DEL -> e DEL_
SEND flags msg
@@ -653,7 +657,9 @@ instance ProtocolEncoding Cmd where
KEY_ -> KEY <$> _smpP
NKEY_ -> NKEY <$> _smpP
GET_ -> pure GET
ACK_ -> pure ACK
ACK_
| v == 1 -> pure $ ACK ""
| otherwise -> ACK <$> _smpP
OFF_ -> pure OFF
DEL_ -> pure DEL
CT SSender tag ->
+98 -72
View File
@@ -207,10 +207,11 @@ clientDisconnected c@Client {subscriptions, connected} = do
sameClientSession :: Client -> Client -> Bool
sameClientSession Client {sessionId} Client {sessionId = s'} = sessionId == s'
cancelSub :: MonadUnliftIO m => Sub -> m ()
cancelSub = \case
Sub {subThread = SubThread t} -> killThread t
_ -> return ()
cancelSub :: MonadUnliftIO m => TVar Sub -> m ()
cancelSub sub =
readTVarIO sub >>= \case
Sub {subThread = SubThread t} -> killThread t
_ -> return ()
receive :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
receive th Client {rcvQ, sndQ, activeAt} = forever $ do
@@ -310,7 +311,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
(pure (corrId, queueId, ERR AUTH))
SUB -> subscribeQueue queueId
GET -> getMessage
ACK -> acknowledgeMsg
ACK msgId -> acknowledgeMsg msgId
KEY sKey -> secureQueue_ st sKey
NKEY nKey -> addQueueNotifier_ st nKey
OFF -> suspendQueue_ st
@@ -387,42 +388,55 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
subscribeQueue :: RecipientId -> m (Transmission BrokerMsg)
subscribeQueue rId =
atomically (getSubscription rId) >>= \case
Just s -> deliverMessage tryPeekMsg rId s
-- cannot use SUB in the same connection where GET was used
_ -> pure (corrId, rId, ERR $ CMD PROHIBITED)
getSubscription :: RecipientId -> STM (Maybe Sub)
getSubscription rId = do
TM.lookup rId subscriptions >>= \case
Just Sub {subThread = ProhibitSub} -> pure Nothing
Just s -> tryTakeTMVar (delivered s) $> Just s
Nothing -> do
atomically (TM.lookup rId subscriptions) >>= \case
Nothing ->
atomically newSub >>= deliver
Just sub ->
readTVarIO sub >>= \case
Sub {subThread = ProhibitSub} ->
-- cannot use SUB in the same connection where GET was used
pure (corrId, rId, ERR $ CMD PROHIBITED)
s ->
atomically (tryTakeTMVar $ delivered s) >> deliver sub
where
newSub :: STM (TVar Sub)
newSub = do
writeTBQueue subscribedQ (rId, clnt)
s <- newSubscription
TM.insert rId s subscriptions
pure $ Just s
sub <- newTVar =<< newSubscription NoSub
TM.insert rId sub subscriptions
pure sub
deliver :: TVar Sub -> m (Transmission BrokerMsg)
deliver sub = do
q <- getStoreMsgQueue rId
msg_ <- atomically $ tryPeekMsg q
deliverMessage rId sub q msg_
getMessage :: m (Transmission BrokerMsg)
getMessage =
atomically getProhibitedSub >>= \case
Just s -> do
atomically (TM.lookup queueId subscriptions) >>= \case
Nothing ->
atomically newSub >>= getMessage_
Just sub ->
readTVarIO sub >>= \case
s@Sub {subThread = ProhibitSub} ->
atomically (tryTakeTMVar $ delivered s)
>> getMessage_ s
-- cannot use GET in the same connection where there is an active subscription
_ -> pure (corrId, queueId, ERR $ CMD PROHIBITED)
where
newSub :: STM Sub
newSub = do
s <- newSubscription ProhibitSub
sub <- newTVar s
TM.insert queueId sub subscriptions
pure s
getMessage_ :: Sub -> m (Transmission BrokerMsg)
getMessage_ s = do
q <- getStoreMsgQueue queueId
atomically $
tryPeekMsg q >>= \case
Just msg -> tryPutTMVar (delivered s) () $> (corrId, queueId, msgCmd msg)
Just msg -> setDelivered s msg $> (corrId, queueId, msgCmd msg)
_ -> pure (corrId, queueId, ERR NO_MSG)
_ -> pure (corrId, queueId, ERR $ CMD PROHIBITED) -- cannot use GET in the same connection where there is an active subscription
where
getProhibitedSub :: STM (Maybe Sub)
getProhibitedSub =
TM.lookup queueId subscriptions >>= \case
Just s@Sub {subThread = ProhibitSub} -> tryTakeTMVar (delivered s) $> Just s
Just _ -> pure Nothing
Nothing -> do
s <- prohibitedSubscription
TM.insert queueId s subscriptions
pure $ Just s
subscribeNotifications :: m (Transmission BrokerMsg)
subscribeNotifications = atomically $ do
@@ -431,23 +445,37 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
TM.insert queueId () ntfSubscriptions
pure ok
acknowledgeMsg :: m (Transmission BrokerMsg)
acknowledgeMsg =
atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s))
>>= \case
Just (Just s) -> do
stats <- asks serverStats
atomically $ modifyTVar (msgRecv stats) (+ 1)
atomically $ modifyTVar (msgQueues stats) (S.insert queueId)
case s of
Sub {subThread = ProhibitSub} ->
(getStoreMsgQueue queueId >>= atomically . tryDelMsg) $> ok
_ ->
deliverMessage tryDelPeekMsg queueId s
_ -> return $ err NO_MSG
withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a)
withSub rId f = mapM f =<< TM.lookup rId subscriptions
acknowledgeMsg :: MsgId -> m (Transmission BrokerMsg)
acknowledgeMsg msgId = do
atomically (TM.lookup queueId subscriptions) >>= \case
Nothing -> pure $ err NO_MSG
Just sub ->
atomically (getDelivered sub) >>= \case
Just s -> do
q <- getStoreMsgQueue queueId
case s of
Sub {subThread = ProhibitSub} -> do
msgDeleted <- atomically $ tryDelMsg q msgId
when msgDeleted updateStats
pure ok
_ -> do
(msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId
when msgDeleted updateStats
deliverMessage queueId sub q msg_
_ -> pure $ err NO_MSG
where
getDelivered :: TVar Sub -> STM (Maybe Sub)
getDelivered sub = do
s@Sub {delivered} <- readTVar sub
tryTakeTMVar delivered $>>= \msgId' ->
if B.null msgId || msgId == msgId'
then pure $ Just s
else putTMVar delivered msgId' $> Nothing
updateStats :: m ()
updateStats = do
stats <- asks serverStats
atomically $ modifyTVar (msgRecv stats) (+ 1)
atomically $ modifyTVar (msgQueues stats) (S.insert queueId)
sendMessage :: QueueStore -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
sendMessage st flags msgBody
@@ -496,35 +524,33 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
unlessM (isFullTBQueue sndQ) $
writeTBQueue q (CorrId "", nId, NMSG)
deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> Sub -> m (Transmission BrokerMsg)
deliverMessage tryPeek rId = \case
Sub {subThread = NoSub} -> do
q <- getStoreMsgQueue rId
atomically (tryPeek q) >>= \case
Nothing -> forkSub q $> ok
Just msg -> atomically setDelivered $> (corrId, rId, msgCmd msg)
_ -> pure ok
deliverMessage :: RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
deliverMessage rId sub q msg_ =
readTVarIO sub >>= \case
s@Sub {subThread = NoSub} ->
case msg_ of
Just msg -> atomically (setDelivered s msg) $> (corrId, rId, msgCmd msg)
_ -> forkSub $> ok
_ -> pure ok
where
forkSub :: MsgQueue -> m ()
forkSub q = do
atomically . setSub $ \s -> s {subThread = SubPending}
t <- forkIO $ subscriber q
atomically . setSub $ \case
forkSub :: m ()
forkSub = do
atomically . modifyTVar sub $ \s -> s {subThread = SubPending}
t <- forkIO subscriber
atomically . modifyTVar sub $ \case
s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
s -> s
subscriber :: MsgQueue -> m ()
subscriber q = atomically $ do
subscriber :: m ()
subscriber = atomically $ do
msg <- peekMsg q
writeTBQueue sndQ (CorrId "", rId, msgCmd msg)
setSub (\s -> s {subThread = NoSub})
void setDelivered
s <- readTVar sub
void $ setDelivered s msg
writeTVar sub s {subThread = NoSub}
setSub :: (Sub -> Sub) -> STM ()
setSub f = TM.adjust f rId subscriptions
setDelivered :: STM (Maybe Bool)
setDelivered = withSub rId $ \s -> tryPutTMVar (delivered s) ()
setDelivered :: Sub -> Message -> STM Bool
setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId
getStoreMsgQueue :: RecipientId -> m MsgQueue
getStoreMsgQueue rId = do
+4 -10
View File
@@ -94,7 +94,7 @@ data Server = Server
}
data Client = Client
{ subscriptions :: TMap RecipientId Sub,
{ subscriptions :: TMap RecipientId (TVar Sub),
ntfSubscriptions :: TMap NotifierId (),
rcvQ :: TBQueue (Transmission Cmd),
sndQ :: TBQueue (Transmission BrokerMsg),
@@ -117,7 +117,7 @@ data SubscriptionThread = NoSub | SubPending | SubThread ThreadId | ProhibitSub
data Sub = Sub
{ subThread :: SubscriptionThread,
delivered :: TMVar ()
delivered :: TMVar MsgId
}
newServer :: Natural -> STM Server
@@ -149,14 +149,8 @@ newServerStats ts = do
fromTime <- newTVar ts
pure ServerStats {qCreated, qSecured, qDeleted, msgSent, msgRecv, msgQueues, fromTime}
newSubscription :: STM Sub
newSubscription = newSubscription_ NoSub
prohibitedSubscription :: STM Sub
prohibitedSubscription = newSubscription_ ProhibitSub
newSubscription_ :: SubscriptionThread -> STM Sub
newSubscription_ subThread = do
newSubscription :: SubscriptionThread -> STM Sub
newSubscription subThread = do
delivered <- newEmptyTMVar
return Sub {subThread, delivered}
+2 -2
View File
@@ -23,6 +23,6 @@ class MonadMsgQueue q m where
writeMsg :: q -> Message -> m () -- non blocking
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
peekMsg :: q -> m Message -- blocking
tryDelMsg :: q -> m () -- non blocking
tryDelPeekMsg :: q -> m (Maybe Message) -- atomic delete (== read) last and peek next message, if available
tryDelMsg :: q -> MsgId -> m Bool -- non blocking
tryDelPeekMsg :: q -> MsgId -> m (Bool, Maybe Message) -- atomic delete (== read) last and peek next message, if available
deleteExpiredMsgs :: q -> Int64 -> m ()
+19 -6
View File
@@ -2,16 +2,19 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TupleSections #-}
module Simplex.Messaging.Server.MsgStore.STM where
import Control.Monad (void, when)
import Control.Monad (when)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.Time.Clock.System (SystemTime (systemSeconds))
import Numeric.Natural
import Simplex.Messaging.Protocol (RecipientId)
import Simplex.Messaging.Protocol (MsgId, RecipientId)
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
@@ -49,12 +52,22 @@ instance MonadMsgQueue MsgQueue STM where
peekMsg :: MsgQueue -> STM Message
peekMsg = peekTBQueue . msgQueue
tryDelMsg :: MsgQueue -> STM ()
tryDelMsg = void . tryReadTBQueue . msgQueue
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
tryDelMsg (MsgQueue q) msgId' =
tryPeekTBQueue q >>= \case
Just Message {msgId}
| msgId == msgId' -> tryReadTBQueue q $> True
| otherwise -> pure False
_ -> pure False
-- atomic delete (== read) last and peek next message if available
tryDelPeekMsg :: MsgQueue -> STM (Maybe Message)
tryDelPeekMsg (MsgQueue q) = tryReadTBQueue q >> tryPeekTBQueue q
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
tryDelPeekMsg (MsgQueue q) msgId' =
tryPeekTBQueue q >>= \case
msg_@(Just Message {msgId})
| msgId == msgId' -> (True,) <$> (tryReadTBQueue q >> tryPeekTBQueue q)
| otherwise -> pure (False, msg_)
_ -> pure (False, Nothing)
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
deleteExpiredMsgs (MsgQueue q) old = loop