agent: return full argument with batched results (#1332)

This commit is contained in:
Evgeny
2024-09-24 16:16:52 +01:00
committed by GitHub
parent e6e9f1d5a5
commit 8e7f3f7b27

View File

@@ -1419,7 +1419,7 @@ subscribeQueues c qs = do
checkQueue rq = do
prohibited <- liftIO $ hasGetLock c rq
pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED "subscribeQueues") else Right rq
subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ())
subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError ())
subscribeQueues_ env session smp qs' = do
let (userId, srv, _) = transportSession' smp
atomically $ incSMPServerStat' c userId srv connSubAttempts $ length qs'
@@ -1450,31 +1450,33 @@ activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClien
Just (Right (SMPConnectedClient smp _)) -> sessId == sessionId (thParams smp)
_ -> False
type BatchResponses e r = NonEmpty (RcvQueue, Either e r)
type BatchResponses q e r = NonEmpty (q, Either e r)
sendTSessionBatches :: forall q r. ByteString -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> AM' [(RcvQueue, Either AgentErrorType r)]
-- Please note: this function does not preserve order of results to be the same as the order of arguments,
-- it includes arguments in the results instead.
sendTSessionBatches :: forall q r. ByteString -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses q SMPClientError r)) -> AgentClient -> [q] -> AM' [(q, Either AgentErrorType r)]
sendTSessionBatches statCmd toRQ action c qs =
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
where
batchQueues :: AM' [(SMPTransportSession, NonEmpty q)]
batchQueues = do
mode <- getSessionMode c
pure . M.assocs $ foldl' (batch mode) M.empty qs
pure . M.assocs $ foldr (batch mode) M.empty qs
where
batch mode m q =
batch mode q m =
let tSess = mkSMPTSession (toRQ q) mode
in M.alter (Just . maybe [q] (q <|)) tSess m
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses AgentErrorType r)
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r)
sendClientBatch (tSess@(_, srv, _), qs') =
tryAgentError' (getSMPServerClient c tSess) >>= \case
Left e -> pure $ L.map ((,Left e) . toRQ) qs'
Left e -> pure $ L.map (,Left e) qs'
Right (SMPConnectedClient smp _) -> liftIO $ do
logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd
L.map agentError <$> action smp qs'
where
agentError = second . first $ protocolClientError SMP $ clientServer smp
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateAuthKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ())
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateAuthKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError ())
sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
where
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
@@ -1603,13 +1605,15 @@ enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtf
withSMPClient c rq "NKEY <nkey>" $ \smp ->
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
enableQueuesNtfs :: AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> AM' [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))]
type RcvQueueNtf = (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)
enableQueuesNtfs :: AgentClient -> [RcvQueueNtf] -> AM' [(RcvQueueNtf, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))]
enableQueuesNtfs = sendTSessionBatches "NKEY" fst3 enableQueues_
where
fst3 (x, _, _) = x
enableQueues_ :: SMPClient -> NonEmpty (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) -> IO (NonEmpty (RcvQueue, Either (ProtocolClientError ErrorType) (SMP.NotifierId, RcvNtfPublicDhKey)))
enableQueues_ smp qs' = L.zipWith ((,) . fst3) qs' <$> enableSMPQueuesNtfs smp (L.map queueCreds qs')
queueCreds :: (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) -> (SMP.RcvPrivateAuthKey, SMP.RecipientId, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)
enableQueues_ :: SMPClient -> NonEmpty RcvQueueNtf -> IO (NonEmpty (RcvQueueNtf, Either (ProtocolClientError ErrorType) (SMP.NotifierId, RcvNtfPublicDhKey)))
enableQueues_ smp qs' = L.zipWith (,) qs' <$> enableSMPQueuesNtfs smp (L.map queueCreds qs')
queueCreds :: RcvQueueNtf -> (SMP.RcvPrivateAuthKey, SMP.RecipientId, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)
queueCreds (RcvQueue {rcvPrivateKey, rcvId}, notifierKey, rcvNtfPublicDhKey) = (rcvPrivateKey, rcvId, notifierKey, rcvNtfPublicDhKey)
disableQueueNotifications :: AgentClient -> RcvQueue -> AM ()