unlock next GET with ACK (#418)

This commit is contained in:
Evgeny Poberezkin
2022-06-22 08:12:18 +01:00
committed by GitHub
parent a7c3133c35
commit 0d9d549cea
2 changed files with 22 additions and 19 deletions

View File

@@ -390,8 +390,8 @@ getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMs
getConnectionMessage' c connId = do
whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> getQueueMessage c rq connId
SomeConn _ (RcvConnection _ rq) -> getQueueMessage c rq connId
SomeConn _ (DuplexConnection _ rq _) -> getQueueMessage c rq
SomeConn _ (RcvConnection _ rq) -> getQueueMessage c rq
SomeConn _ ContactConnection {} -> throwError $ CMD PROHIBITED
SomeConn _ SndConnection {} -> throwError $ CONN SIMPLEX

View File

@@ -117,7 +117,7 @@ data AgentClient = AgentClient
connMsgsQueued :: TMap ConnId Bool,
smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId),
smpQueueMsgDeliveries :: TMap (ConnId, SMPServer, SMP.SenderId) (Async ()),
getMsgLocks :: TMap (ConnId, SMPServer, SMP.RecipientId) (TMVar ()),
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
reconnections :: TVar [Async ()],
asyncClients :: TVar [Async ()],
clientId :: Int,
@@ -205,6 +205,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
whenM (readTVarIO active) $ do
let conns = M.keys cs
unless (null conns) . notifySub "" $ DOWN srv conns
atomically $ mapM_ (releaseGetLock c) cs
unliftIO u reconnectServer
reconnectServer :: m ()
@@ -335,6 +336,7 @@ closeAgentClient c = liftIO $ do
clear subscrConns
clear connMsgsQueued
clear smpQueueMsgQueues
clear getMsgLocks
where
clientTimeout sel = tcpTimeout . sel . config $ agentEnv c
clear :: (AgentClient -> TMap k a) -> IO ()
@@ -427,7 +429,7 @@ newRcvQueue_ a c srv = do
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
whenM (atomically . TM.member (connId, server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ addPendingSubscription c rq connId
withLogClient c server rcvId "SUB" $ \smp -> do
liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case
@@ -506,22 +508,18 @@ sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) co
agentCbEncryptOnce dhPublicKey . smpEncode $
SMP.ClientMessage SMP.PHEmpty $ smpEncode agentEnvelope
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m (Maybe SMPMsgMeta)
getQueueMessage c@AgentClient {getMsgLocks} RcvQueue {server, rcvId, rcvPrivateKey} connId =
E.bracket (atomically createTakeLock) (atomically . (`putTMVar` ())) $ \_ ->
withLogClient c server rcvId "GET" $ \smp ->
getSMPMessage smp rcvPrivateKey rcvId
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
getQueueMessage c RcvQueue {server, rcvId, rcvPrivateKey} = do
atomically createTakeGetLock
withLogClient c server rcvId "GET" $ \smp ->
getSMPMessage smp rcvPrivateKey rcvId
where
k = (connId, server, rcvId)
createTakeLock = do
l <- TM.lookup k getMsgLocks >>= maybe newLock pure
takeTMVar l
pure l
createTakeGetLock = TM.alterF takeLock (server, rcvId) $ getMsgLocks c
where
newLock = do
l <- newTMVar ()
TM.insert k l getMsgLocks
pure l
takeLock l_ = do
l <- maybe (newTMVar ()) pure l_
takeTMVar l
pure $ Just l
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
@@ -534,9 +532,14 @@ enableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} notifierKey r
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
sendAck c RcvQueue {server, rcvId, rcvPrivateKey} msgId =
sendAck c rq@RcvQueue {server, rcvId, rcvPrivateKey} msgId = do
withLogClient c server rcvId "ACK" $ \smp ->
ackSMPMessage smp rcvPrivateKey rcvId msgId
atomically $ releaseGetLock c rq
releaseGetLock :: AgentClient -> RcvQueue -> STM ()
releaseGetLock c RcvQueue {server, rcvId} =
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (void . (`tryPutTMVar` ()))
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =