diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 82d5d2785..6737a3155 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1419,7 +1419,7 @@ subscribeQueues c qs = do checkQueue rq = do prohibited <- liftIO $ hasGetLock c rq pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED "subscribeQueues") else Right rq - subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ()) + subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError ()) subscribeQueues_ env session smp qs' = do let (userId, srv, _) = transportSession' smp atomically $ incSMPServerStat' c userId srv connSubAttempts $ length qs' @@ -1450,31 +1450,33 @@ activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClien Just (Right (SMPConnectedClient smp _)) -> sessId == sessionId (thParams smp) _ -> False -type BatchResponses e r = NonEmpty (RcvQueue, Either e r) +type BatchResponses q e r = NonEmpty (q, Either e r) -sendTSessionBatches :: forall q r. ByteString -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> AM' [(RcvQueue, Either AgentErrorType r)] +-- Please note: this function does not preserve order of results to be the same as the order of arguments, +-- it includes arguments in the results instead. +sendTSessionBatches :: forall q r. ByteString -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses q SMPClientError r)) -> AgentClient -> [q] -> AM' [(q, Either AgentErrorType r)] sendTSessionBatches statCmd toRQ action c qs = concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues) where batchQueues :: AM' [(SMPTransportSession, NonEmpty q)] batchQueues = do mode <- getSessionMode c - pure . M.assocs $ foldl' (batch mode) M.empty qs + pure . M.assocs $ foldr (batch mode) M.empty qs where - batch mode m q = + batch mode q m = let tSess = mkSMPTSession (toRQ q) mode in M.alter (Just . maybe [q] (q <|)) tSess m - sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses AgentErrorType r) + sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r) sendClientBatch (tSess@(_, srv, _), qs') = tryAgentError' (getSMPServerClient c tSess) >>= \case - Left e -> pure $ L.map ((,Left e) . toRQ) qs' + Left e -> pure $ L.map (,Left e) qs' Right (SMPConnectedClient smp _) -> liftIO $ do logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd L.map agentError <$> action smp qs' where agentError = second . first $ protocolClientError SMP $ clientServer smp -sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateAuthKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ()) +sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateAuthKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError ()) sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) @@ -1603,13 +1605,15 @@ enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtf withSMPClient c rq "NKEY " $ \smp -> enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey -enableQueuesNtfs :: AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> AM' [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))] +type RcvQueueNtf = (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) + +enableQueuesNtfs :: AgentClient -> [RcvQueueNtf] -> AM' [(RcvQueueNtf, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))] enableQueuesNtfs = sendTSessionBatches "NKEY" fst3 enableQueues_ where fst3 (x, _, _) = x - enableQueues_ :: SMPClient -> NonEmpty (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) -> IO (NonEmpty (RcvQueue, Either (ProtocolClientError ErrorType) (SMP.NotifierId, RcvNtfPublicDhKey))) - enableQueues_ smp qs' = L.zipWith ((,) . fst3) qs' <$> enableSMPQueuesNtfs smp (L.map queueCreds qs') - queueCreds :: (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) -> (SMP.RcvPrivateAuthKey, SMP.RecipientId, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) + enableQueues_ :: SMPClient -> NonEmpty RcvQueueNtf -> IO (NonEmpty (RcvQueueNtf, Either (ProtocolClientError ErrorType) (SMP.NotifierId, RcvNtfPublicDhKey))) + enableQueues_ smp qs' = L.zipWith (,) qs' <$> enableSMPQueuesNtfs smp (L.map queueCreds qs') + queueCreds :: RcvQueueNtf -> (SMP.RcvPrivateAuthKey, SMP.RecipientId, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) queueCreds (RcvQueue {rcvPrivateKey, rcvId}, notifierKey, rcvNtfPublicDhKey) = (rcvPrivateKey, rcvId, notifierKey, rcvNtfPublicDhKey) disableQueueNotifications :: AgentClient -> RcvQueue -> AM ()