agent: return error and message absence differently when getting notification messages (#1535)

* agent: return error and message absence differently when getting notification messages

* fix test

* mapM

* inline nse functions, release lock on error or no message
This commit is contained in:
Evgeny
2025-05-06 16:20:01 +01:00
committed by GitHub
parent a632eea75b
commit cb59a449dd
5 changed files with 20 additions and 17 deletions
+13 -13
View File
@@ -224,7 +224,6 @@ import Simplex.RemoteControl.Client
import Simplex.RemoteControl.Invitation
import Simplex.RemoteControl.Types
import System.Mem.Weak (deRefWeak)
import UnliftIO.Async (mapConcurrently)
import UnliftIO.Concurrent (forkFinally, forkIO, killThread, mkWeakThreadId, threadDelay)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -440,7 +439,7 @@ subscribeConnections c = withAgentEnv c . subscribeConnections' c
{-# INLINE subscribeConnections #-}
-- | Get messages for connections (GET commands)
getConnectionMessages :: AgentClient -> NonEmpty ConnMsgReq -> IO (NonEmpty (Maybe SMPMsgMeta))
getConnectionMessages :: AgentClient -> NonEmpty ConnMsgReq -> IO (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
getConnectionMessages c = withAgentEnv' c . getConnectionMessages' c
{-# INLINE getConnectionMessages #-}
@@ -1277,26 +1276,26 @@ resubscribeConnections' c connIds = do
-- union is left-biased, so results returned by subscribeConnections' take precedence
(`M.union` r) <$> subscribeConnections' c connIds'
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Maybe SMPMsgMeta))
getConnectionMessages' c =
mapConcurrently $ \cmr ->
getConnectionMessage cmr `catchAgentError'` \e -> do
logError $ "Error loading message: " <> tshow e
pure Nothing
-- requesting messages sequentially, to reduce memory usage
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
where
getConnectionMessage :: ConnMsgReq -> AM (Maybe SMPMsgMeta)
getConnectionMessage (ConnMsgReq connId dbQueueId msgTs_) = do
whenM (atomically $ hasActiveSubscription c connId) . throwE $ CMD PROHIBITED "getConnectionMessage: subscribed"
SomeConn _ conn <- withStore c (`getConn` connId)
msg_ <- case conn of
DuplexConnection _ (rq :| _) _ -> getQueueMessage c rq
RcvConnection _ rq -> getQueueMessage c rq
ContactConnection _ rq -> getQueueMessage c rq
rq <- case conn of
DuplexConnection _ (rq :| _) _ -> pure rq
RcvConnection _ rq -> pure rq
ContactConnection _ rq -> pure rq
SndConnection _ _ -> throwE $ CONN SIMPLEX
NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection"
when (isNothing msg_) $
msg_ <- getQueueMessage c rq `catchAgentError` \e -> atomically (releaseGetLock c rq) >> throwError e
when (isNothing msg_) $ do
atomically $ releaseGetLock c rq
forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBQueueId dbQueueId) msgTs
pure msg_
{-# INLINE getConnectionMessages' #-}
getNotificationConns' :: AgentClient -> C.CbNonce -> ByteString -> AM (NonEmpty NotificationInfo)
getNotificationConns' c nonce encNtfInfo =
@@ -1330,6 +1329,7 @@ getNotificationConns' c nonce encNtfInfo =
Just SMP.NMsgMeta {msgTs}
| maybe True (systemToUTCTime msgTs >) lastBrokerTs_ -> Just ntfInfo
_ -> Nothing
{-# INLINE getNotificationConns' #-}
-- | Send message to the connection (SEND command) in Reader monad
sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption)
+3
View File
@@ -1654,6 +1654,7 @@ getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
l <- maybe (newTMVar ()) pure l_
takeTMVar l
pure $ Just l
{-# INLINE getQueueMessage #-}
decryptSMPMessage :: RcvQueue -> SMP.RcvMessage -> AM SMP.ClientRcvMsgBody
decryptSMPMessage rq SMP.RcvMessage {msgId, msgBody = SMP.EncRcvMsgBody body} =
@@ -1743,10 +1744,12 @@ sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId =
hasGetLock :: AgentClient -> RcvQueue -> IO Bool
hasGetLock c RcvQueue {server, rcvId} =
TM.memberIO (server, rcvId) $ getMsgLocks c
{-# INLINE hasGetLock #-}
releaseGetLock :: AgentClient -> RcvQueue -> STM ()
releaseGetLock c RcvQueue {server, rcvId} =
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
{-# INLINE releaseGetLock #-}
suspendQueue :: AgentClient -> RcvQueue -> AM ()
suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
+1
View File
@@ -801,6 +801,7 @@ getSMPMessage c rpKey rId =
OK -> pure Nothing
cmd@(MSG msg) -> liftIO (writeSMPMessage c rId cmd) $> Just msg
r -> throwE $ unexpectedResponse r
{-# INLINE getSMPMessage #-}
-- | Subscribe to the SMP queue notifications.
--
+1 -1
View File
@@ -1946,7 +1946,7 @@ testOnlyCreatePullSlowHandshake = withAgentClientsCfg2 agentProxyCfgV8 agentProx
getMsg :: AgentClient -> ConnId -> ExceptT AgentErrorType IO a -> ExceptT AgentErrorType IO a
getMsg c cId action = do
liftIO $ noMessages c "nothing should be delivered before GET"
[Just _] <- lift $ getConnectionMessages c [ConnMsgReq cId 1 Nothing]
[Right (Just _)] <- lift $ getConnectionMessages c [ConnMsgReq cId 1 Nothing]
action
getMSGNTF :: AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
+2 -3
View File
@@ -562,8 +562,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag
Right [NotificationInfo {ntfConnId = cId, ntfMsgMeta = Just NMsgMeta {msgTs}}] <- runExceptT $ getNotificationConns alice nonce message
cId `shouldBe` bobId
-- alice client already has subscription for the connection,
-- so get fails with CMD PROHIBITED (transformed into Nothing in catch)
[Nothing] <- getConnectionMessages alice [ConnMsgReq cId 1 $ Just $ systemToUTCTime msgTs]
[Left (CMD PROHIBITED _)] <- getConnectionMessages alice [ConnMsgReq cId 1 $ Just $ systemToUTCTime msgTs]
threadDelay 500000
suspendAgent alice 0
@@ -573,7 +572,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag
-- aliceNtf client doesn't have subscription and is allowed to get notification message
withAgent 3 aliceCfg initAgentServers testDB $ \aliceNtf -> do
(Just SMPMsgMeta {msgFlags = MsgFlags True}) :| _ <- getConnectionMessages aliceNtf [ConnMsgReq cId 1 $ Just $ systemToUTCTime msgTs]
(Right (Just SMPMsgMeta {msgFlags = MsgFlags True})) :| _ <- getConnectionMessages aliceNtf [ConnMsgReq cId 1 $ Just $ systemToUTCTime msgTs]
pure ()
threadDelay 1000000