diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 6e50e76c7..340cfe8bf 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -5,7 +5,6 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -158,6 +157,7 @@ runXFTPRcvWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> TMVar runXFTPRcvWorker c srv doWork = do forever $ do void . atomically $ readTMVar doWork + -- TODO waitUntilNotSuspended agentOperationBracket c AORcvNetwork waitUntilActive runXftpOperation where noWorkToDo = void . atomically $ tryTakeTMVar doWork @@ -173,13 +173,19 @@ runXFTPRcvWorker c srv doWork = do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryInterval ri' $ \delay' loop -> downloadFileChunk fc replica - `catchError` \e -> retryOnError c AORcvNetwork "XFTP rcv worker" loop (retryMaintenance e delay') (retryDone e) e + `catchError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e where - retryMaintenance e replicaDelay = do - notifyOnRetry <- asks (xftpNotifyErrsOnRetry . config) - when notifyOnRetry $ notify c rcvFileEntityId $ RFERR e - closeXFTPServerClient c userId server replicaId - withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay + retryLoop loop e replicaDelay = do + flip catchError (\_ -> pure ()) $ do + notifyOnRetry <- asks (xftpNotifyErrsOnRetry . config) + when notifyOnRetry $ notify c rcvFileEntityId $ RFERR e + closeXFTPServerClient c userId server replicaId + withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay + -- TODO waitUntilNotSuspended + atomically $ endAgentOperation c AORcvNetwork + atomically $ throwWhenInactive c + atomically $ beginAgentOperation c AORcvNetwork + loop retryDone e = rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) (show e) downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> m () downloadFileChunk RcvFileChunk {userId, rcvFileId, rcvFileEntityId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath} replica = do @@ -205,19 +211,12 @@ runXFTPRcvWorker c srv doWork = do | otherwise = 0 chunkReceived RcvFileChunk {replicas} = any received replicas -retryOnError :: AgentMonad m => AgentClient -> AgentOperation -> Text -> m () -> m () -> m () -> AgentErrorType -> m () -retryOnError c agentOp name loop maintenance done e = do +retryOnError :: AgentMonad m => Text -> m a -> m a -> AgentErrorType -> m a +retryOnError name loop done e = do logError $ name <> " error: " <> tshow e if temporaryAgentError e - then retryLoop + then loop else done - where - retryLoop = do - maintenance `catchError` \_ -> pure () - atomically $ endAgentOperation c agentOp - atomically $ throwWhenInactive c - atomically $ beginAgentOperation c agentOp - loop rcvWorkerInternalError :: AgentMonad m => AgentClient -> DBRcvFileId -> RcvFileId -> Maybe FilePath -> String -> m () rcvWorkerInternalError c rcvFileId rcvFileEntityId tmpPath internalErrStr = do @@ -229,7 +228,7 @@ runXFTPRcvLocalWorker :: forall m. AgentMonad m => AgentClient -> TMVar () -> m runXFTPRcvLocalWorker c doWork = do forever $ do void . atomically $ readTMVar doWork - -- TODO agentOperationBracket? + -- TODO waitUntilNotSuspended runXftpOperation where runXftpOperation :: m () @@ -350,7 +349,7 @@ runXFTPSndPrepareWorker :: forall m. AgentMonad m => AgentClient -> TMVar () -> runXFTPSndPrepareWorker c doWork = do forever $ do void . atomically $ readTMVar doWork - -- TODO agentOperationBracket + -- TODO waitUntilNotSuspended runXftpOperation where runXftpOperation :: m () @@ -364,7 +363,7 @@ runXFTPSndPrepareWorker c doWork = do prepareFile :: SndFile -> m () prepareFile SndFile {prefixPath = Nothing} = throwError $ INTERNAL "no prefix path" - prepareFile sndFile@SndFile {sndFileId, prefixPath = Just ppath, status} = do + prepareFile sndFile@SndFile {sndFileId, userId, prefixPath = Just ppath, status} = do SndFile {numRecipients, chunks} <- if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting then do @@ -380,7 +379,7 @@ runXFTPSndPrepareWorker c doWork = do maxRecipients <- asks (xftpMaxRecipientsPerRequest . config) let numRecipients' = min numRecipients maxRecipients -- concurrently? - forM_ chunks $ createChunk numRecipients' + forM_ (filter (not . chunkCreated) chunks) $ createChunk numRecipients' withStore' c $ \db -> updateSndFileStatus db sndFileId SFSUploading where encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)]) @@ -398,17 +397,33 @@ runXFTPSndPrepareWorker c doWork = do let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes chunkDigests <- map FileDigest <$> mapM (liftIO . getChunkDigest) chunkSpecs pure (FileDigest digest, zip chunkSpecs chunkDigests) + chunkCreated :: SndFileChunk -> Bool + chunkCreated SndFileChunk {replicas} = + any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas createChunk :: Int -> SndFileChunk -> m () createChunk numRecipients' ch = do - srvAuth@(ProtoServerWithAuth srv _) <- getServer - replica <- agentXFTPNewChunk c ch numRecipients' srvAuth + -- TODO waitUntilNotSuspended + (replica, ProtoServerWithAuth srv _) <- agentOperationBracket c AOSndNetwork throwWhenInactive tryCreate withStore' c $ \db -> createSndFileReplica db ch replica addXFTPSndWorker c $ Just srv - getServer :: m XFTPServerWithAuth - getServer = do - -- TODO get user servers from config - -- TODO choose next server (per chunk? per file?) - undefined + where + tryCreate = do + ri <- asks $ messageRetryInterval . config + usedSrvs <- newTVarIO ([] :: [XFTPServer]) + withRetryInterval (riFast ri) $ \_ loop -> + createWithNextSrv usedSrvs + `catchError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop) (throwError e) e + where + retryLoop loop = do + -- TODO waitUntilNotSuspended + atomically $ endAgentOperation c AOSndNetwork + atomically $ throwWhenInactive c + atomically $ beginAgentOperation c AOSndNetwork + loop + createWithNextSrv usedSrvs = do + withNextSrv c userId usedSrvs [] $ \srvAuth -> do + replica <- agentXFTPNewChunk c ch numRecipients' srvAuth + pure (replica, srvAuth) sndWorkerInternalError :: AgentMonad m => AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> m () sndWorkerInternalError c sndFileId sndFileEntityId prefixPath internalErrStr = do @@ -420,6 +435,7 @@ runXFTPSndWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> TMVar runXFTPSndWorker c srv doWork = do forever $ do void . atomically $ readTMVar doWork + -- TODO waitUntilNotSuspended agentOperationBracket c AOSndNetwork throwWhenInactive runXftpOperation where noWorkToDo = void . atomically $ tryTakeTMVar doWork @@ -434,19 +450,26 @@ runXFTPSndWorker c srv doWork = do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryInterval ri' $ \delay' loop -> uploadFileChunk fc replica - `catchError` \e -> retryOnError c AOSndNetwork "XFTP snd worker" loop (retryMaintenance e delay') (retryDone e) e + `catchError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e where - retryMaintenance e replicaDelay = do - notifyOnRetry <- asks (xftpNotifyErrsOnRetry . config) - when notifyOnRetry $ notify c sndFileEntityId $ SFERR e - closeXFTPServerClient c userId server replicaId - withStore' c $ \db -> updateRcvChunkReplicaDelay db sndChunkReplicaId replicaDelay + retryLoop loop e replicaDelay = do + flip catchError (\_ -> pure ()) $ do + notifyOnRetry <- asks (xftpNotifyErrsOnRetry . config) + when notifyOnRetry $ notify c sndFileEntityId $ SFERR e + closeXFTPServerClient c userId server replicaId + withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay + -- TODO waitUntilNotSuspended + atomically $ endAgentOperation c AOSndNetwork + atomically $ throwWhenInactive c + atomically $ beginAgentOperation c AOSndNetwork + loop retryDone e = sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) (show e) uploadFileChunk :: SndFileChunk -> SndFileChunkReplica -> m () uploadFileChunk sndFileChunk@SndFileChunk {sndFileId, sndChunkId, userId, chunkSpec = chunkSpec@XFTPChunkSpec {filePath}} replica = do replica'@SndFileChunkReplica {sndChunkReplicaId} <- addRecipients sndFileChunk replica fsFilePath <- toFSFilePath filePath let chunkSpec' = chunkSpec {filePath = fsFilePath} :: XFTPChunkSpec + -- TODO waitUntilNotSuspended agentXFTPUploadChunk c userId sndChunkId replica' chunkSpec' sf@SndFile {sndFileEntityId, prefixPath, chunks} <- withStore c $ \db -> do updateSndChunkReplicaStatus db sndChunkReplicaId SFRSUploaded @@ -524,5 +547,6 @@ runXFTPSndWorker c srv doWork = do Just ch@FileChunk {replicas} -> ch {replicas = replica' : replicas} _ -> FileChunk {chunkNo, digest, chunkSize, replicas = [replica']} replica' = FileChunkReplica {server, replicaId, replicaKey} + chunkUploaded :: SndFileChunk -> Bool chunkUploaded SndFileChunk {replicas} = any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSUploaded) replicas diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 785e76d08..f4956a965 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -107,7 +107,7 @@ import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::)) import Data.Foldable (foldl') import Data.Functor (($>)) -import Data.List (deleteFirstsBy, find, (\\)) +import Data.List (find) import Data.List.NonEmpty (NonEmpty (..), (<|)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -140,12 +140,11 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, UserProtocol, XFTPServerWithAuth, protoServer, sameSrvAddr') +import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, UserProtocol, XFTPServerWithAuth) import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util import Simplex.Messaging.Version -import System.Random (randomR) import UnliftIO.Async (async, race_) import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay) import qualified UnliftIO.Exception as E @@ -351,6 +350,8 @@ xftpDeleteRcvFile c = withAgentEnv c .: deleteRcvFile c xftpSendFile :: AgentErrorMonad m => AgentClient -> UserId -> FilePath -> Int -> m SndFileId xftpSendFile c = withAgentEnv c .:. sendFileExperimental c +-- TODO rename setAgentForeground + -- | Activate operations activateAgent :: MonadUnliftIO m => AgentClient -> m () activateAgent c = withAgentEnv c $ activateAgent' c @@ -551,7 +552,7 @@ joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> C joinConn c userId connId asyncMode enableNtfs cReq cInfo = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> - getNextSMPServer c userId [qServer q] + getNextServer c userId [qServer q] _ -> getSMPServer c userId joinConnSrv c userId connId asyncMode enableNtfs cReq cInfo srv @@ -847,13 +848,13 @@ runCommandProcessing c@AgentClient {subQ} server_ = do AClientCommand (APC _ cmd) -> case cmd of NEW enableNtfs (ACM cMode) -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) - tryCommand . withNextSrv usedSrvs [] $ \srv -> do + tryCommand . withNextSrv c userId usedSrvs [] $ \srv -> do (_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing srv notify $ INV (ACR cMode cReq) JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do let initUsed = [qServer q] usedSrvs <- newTVarIO initUsed - tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do + tryCommand . withNextSrv c userId usedSrvs initUsed $ \srv -> do void $ joinConnSrv c userId connId True enableNtfs cReq connInfo srv notify OK LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK @@ -933,16 +934,6 @@ runCommandProcessing c@AgentClient {subQ} server_ = do cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId) notify :: forall e. AEntityI e => ACommand 'Agent e -> m () notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd) - withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServerWithAuth -> m ()) -> m () - withNextSrv usedSrvs initUsed action = do - used <- readTVarIO usedSrvs - srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId used - atomically $ do - srvs_ <- TM.lookup userId $ smpServers c - let unused = maybe [] ((\\ used) . map protoServer . L.toList) srvs_ - used' = if null unused then initUsed else srv : used - writeTVar usedSrvs $! used' - action srvAuth -- ^ ^ ^ async command processing / enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId @@ -1023,7 +1014,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl atomically $ throwWhenNoDelivery c sq msgId <- atomically $ readTQueue mq atomically $ beginAgentOperation c AOSndNetwork - atomically $ endAgentOperation c AOMsgDelivery + atomically $ endAgentOperation c AOMsgDelivery -- this operation begins in queuePendingMsgs let mId = unId msgId E.try (withStore c $ \db -> getPendingMsgData db connId msgId) >>= \case Left (e :: E.SomeException) -> @@ -1185,8 +1176,8 @@ switchConnection' c connId = withConnLock c connId "switchConnection" $ do DuplexConnection cData@ConnData {userId} rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do clientVRange <- asks $ smpClientVRange . config -- try to get the server that is different from all queues, or at least from the primary rcv queue - srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs) - srv' <- if srv == server then getNextSMPServer c userId [server] else pure srvAuth + srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs) + srv' <- if srv == server then getNextServer c userId [server] else pure srvAuth (q, qUri) <- newRcvQueue c userId connId srv' clientVRange let rq' = (q :: RcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} void . withStore c $ \db -> addConnRcvQueue db connId rq' @@ -1340,11 +1331,7 @@ connectionStats = \case -- | Change servers to be used for creating new queues, in Reader monad setProtocolServers' :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> m () -setProtocolServers' c userId srvs = servers >>= atomically . TM.insert userId srvs - where - servers = case protocolTypeI @p of - SPSMP -> pure $ smpServers c - SPXFTP -> pure $ xftpServers c +setProtocolServers' c userId srvs = atomically $ TM.insert userId srvs (userServers c) registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus registerNtfToken' c suppliedDeviceToken suppliedNtfMode = @@ -1590,25 +1577,6 @@ debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth getSMPServer c userId = withUserServers c userId pickServer -pickServer :: AgentMonad' m => NonEmpty SMPServerWithAuth -> m SMPServerWithAuth -pickServer = \case - srv :| [] -> pure srv - servers -> do - gen <- asks randomServer - atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1)) - -getNextSMPServer :: AgentMonad m => AgentClient -> UserId -> [SMPServer] -> m SMPServerWithAuth -getNextSMPServer c userId usedSrvs = withUserServers c userId $ \srvs -> - case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of - Just srvs' -> pickServer srvs' - _ -> pickServer srvs - -withUserServers :: AgentMonad m => AgentClient -> UserId -> (NonEmpty SMPServerWithAuth -> m a) -> m a -withUserServers c userId action = - atomically (TM.lookup userId $ smpServers c) >>= \case - Just srvs -> action srvs - _ -> throwError $ INTERNAL "unknown userId - no SMP servers" - subscriber :: AgentMonad' m => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do t <- atomically $ readTBQueue msgQ diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 83e059138..3c178842c 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -87,6 +87,11 @@ module Simplex.Messaging.Agent.Client withStore, withStore', storeError, + userServers, + pickServer, + getNextServer, + withUserServers, + withNextSrv, ) where @@ -109,7 +114,7 @@ import qualified Data.ByteString.Char8 as B import Data.Either (lefts, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) -import Data.List (foldl', partition) +import Data.List (deleteFirstsBy, foldl', partition, (\\)) import Data.List.NonEmpty (NonEmpty (..), (<|)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -160,6 +165,7 @@ import Simplex.Messaging.Protocol NtfPublicVerifyKey, NtfServer, ProtoServer, + ProtoServerWithAuth (..), Protocol (..), ProtocolServer (..), ProtocolTypeI (..), @@ -167,11 +173,13 @@ import Simplex.Messaging.Protocol QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, - RecipientId, SMPMsgMeta (..), + SProtocolType (..), SndPublicVerifyKey, + UserProtocol, XFTPServer, XFTPServerWithAuth, + sameSrvAddr', ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) @@ -179,6 +187,7 @@ import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util import Simplex.Messaging.Version +import System.Random (randomR) import System.Timeout (timeout) import UnliftIO (mapConcurrently) import UnliftIO.Directory (getTemporaryDirectory) @@ -256,7 +265,7 @@ agentOperations = [ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, data data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int} -data AgentState = ASActive | ASSuspending | ASSuspended +data AgentState = ASActive | ASSuspending | ASSuspended -- TODO rename ASActive -> ASForeground deriving (Eq, Show) data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String} @@ -1245,3 +1254,38 @@ incClientStatN c userId pc n cmd res = do atomically $ incStat c n statsKey where statsKey = AgentStatsKey {userId, host = strEncode $ clientTransportHost pc, clientTs = strEncode $ clientSessionTs pc, cmd, res} + +userServers :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> TMap UserId (NonEmpty (ProtoServerWithAuth p)) +userServers c = case protocolTypeI @p of + SPSMP -> smpServers c + SPXFTP -> xftpServers c + +pickServer :: forall p m. (AgentMonad' m) => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p) +pickServer = \case + srv :| [] -> pure srv + servers -> do + gen <- asks randomServer + atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1)) + +getNextServer :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> [ProtocolServer p] -> m (ProtoServerWithAuth p) +getNextServer c userId usedSrvs = withUserServers c userId $ \srvs -> + case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of + Just srvs' -> pickServer srvs' + _ -> pickServer srvs + +withUserServers :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> m a) -> m a +withUserServers c userId action = + atomically (TM.lookup userId $ userServers c) >>= \case + Just srvs -> action srvs + _ -> throwError $ INTERNAL "unknown userId - no user servers" + +withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> ((ProtoServerWithAuth p) -> m a) -> m a +withNextSrv c userId usedSrvs initUsed action = do + used <- readTVarIO usedSrvs + srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used + atomically $ do + srvs_ <- TM.lookup userId $ userServers c + let unused = maybe [] ((\\ used) . map protoServer . L.toList) srvs_ + used' = if null unused then initUsed else srv : used + writeTVar usedSrvs $! used' + action srvAuth diff --git a/src/Simplex/Messaging/Agent/RetryInterval.hs b/src/Simplex/Messaging/Agent/RetryInterval.hs index 7f88592fc..97d537a5a 100644 --- a/src/Simplex/Messaging/Agent/RetryInterval.hs +++ b/src/Simplex/Messaging/Agent/RetryInterval.hs @@ -48,10 +48,10 @@ updateRetryInterval2 RI2State {slowInterval, fastInterval} RetryInterval2 {riSlo data RetryIntervalMode = RISlow | RIFast deriving (Eq, Show) -withRetryInterval :: forall m. MonadIO m => RetryInterval -> (Int64 -> m () -> m ()) -> m () +withRetryInterval :: forall m a. MonadIO m => RetryInterval -> (Int64 -> m a -> m a) -> m a withRetryInterval ri action = callAction 0 $ initialInterval ri where - callAction :: Int64 -> Int64 -> m () + callAction :: Int64 -> Int64 -> m a callAction elapsed delay = action delay loop where loop = do