mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-13 20:53:13 +00:00
agent: make agent workers usable from other contexts (#1614)
This commit is contained in:
@@ -75,7 +75,7 @@ import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String (strDecode, strEncode)
|
||||
import Simplex.Messaging.Protocol (ProtocolServer, ProtocolType (..), XFTPServer)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM)
|
||||
import Simplex.Messaging.Util (allFinally, catchAll_, catchAllErrors, liftError, tshow, unlessM, whenM)
|
||||
import System.FilePath (takeFileName, (</>))
|
||||
import UnliftIO
|
||||
import UnliftIO.Directory
|
||||
@@ -198,10 +198,10 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
atomically $ incXFTPServerStat c userId srv downloadAttempts
|
||||
downloadFileChunk fc replica approvedRelays
|
||||
`catchAgentError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (serverHostError e) $ notify c (fromMaybe rcvFileEntityId redirectEntityId_) (RFWARN e)
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay
|
||||
@@ -280,7 +280,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL} =
|
||||
withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $
|
||||
\f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath, redirect} ->
|
||||
decryptFile f `catchAgentError` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath
|
||||
decryptFile f `catchAllErrors` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath
|
||||
decryptFile :: RcvFile -> AM ()
|
||||
decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do
|
||||
let CryptoFile savePath cfArgs = saveFile
|
||||
@@ -307,7 +307,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
liftIO $ waitUntilForeground c
|
||||
withStore' c (`updateRcvFileComplete` rcvFileId)
|
||||
-- proceed with redirect
|
||||
yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
|
||||
yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `allFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
|
||||
next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of
|
||||
-- TODO switch to another error constructor
|
||||
Left _ -> throwE . FILE $ REDIRECT "decode error"
|
||||
@@ -399,7 +399,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
runXFTPOperation cfg@AgentConfig {sndFilesTTL} =
|
||||
withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $
|
||||
\f@SndFile {sndFileId, sndFileEntityId, prefixPath} ->
|
||||
prepareFile cfg f `catchAgentError` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath
|
||||
prepareFile cfg f `catchAllErrors` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath
|
||||
prepareFile :: AgentConfig -> SndFile -> AM ()
|
||||
prepareFile _ SndFile {prefixPath = Nothing} =
|
||||
throwE $ INTERNAL "no prefix path"
|
||||
@@ -468,11 +468,11 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
let triedAllSrvs = n > userSrvCount
|
||||
createWithNextSrv triedHosts
|
||||
`catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e
|
||||
where
|
||||
-- we don't do closeXFTPServerClient here to not risk closing connection for concurrent chunk upload
|
||||
retryLoop loop triedAllSrvs e = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e
|
||||
liftIO $ assertAgentForeground c
|
||||
loop
|
||||
@@ -508,10 +508,10 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
atomically $ incXFTPServerStat c userId srv uploadAttempts
|
||||
uploadFileChunk cfg fc replica
|
||||
`catchAgentError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay
|
||||
@@ -681,10 +681,10 @@ runXFTPDelWorker c srv Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
atomically $ incXFTPServerStat c userId srv deleteAttempts
|
||||
deleteChunkReplica
|
||||
`catchAgentError` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (serverHostError e) $ notify c "" $ SFWARN e
|
||||
liftIO $ closeXFTPServerClient c userId server chunkDigest
|
||||
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
|
||||
|
||||
@@ -284,13 +284,13 @@ saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats, ntfServ
|
||||
xss <- mapM (liftIO . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats
|
||||
nss <- mapM (liftIO . getAgentNtfServerStats) =<< readTVarIO ntfServersStats
|
||||
let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}
|
||||
tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case
|
||||
tryAllErrors' (withStore' c (`updateServersStats` stats)) >>= \case
|
||||
Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e)
|
||||
Right () -> pure ()
|
||||
|
||||
restoreServersStats :: AgentClient -> AM' ()
|
||||
restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt} = do
|
||||
tryAgentError' (withStore c getServersStats) >>= \case
|
||||
tryAllErrors' (withStore c getServersStats) >>= \case
|
||||
Left e -> atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e)
|
||||
Right (startedAt, Nothing) -> atomically $ writeTVar srvStatsStartedAt startedAt
|
||||
Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}) -> do
|
||||
@@ -774,7 +774,7 @@ acceptContactAsync' :: AgentClient -> UserId -> ACorrId -> Bool -> InvitationId
|
||||
acceptContactAsync' c userId corrId enableNtfs invId ownConnInfo pqSupport subMode = do
|
||||
Invitation {connReq} <- withStore c $ \db -> getInvitation db "acceptContactAsync'" invId
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do
|
||||
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAllErrors` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwE err
|
||||
|
||||
@@ -961,7 +961,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey
|
||||
createRcvQueue nonce_ qd e2eKeys = do
|
||||
AgentConfig {smpClientVRange = vr} <- asks config
|
||||
ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing
|
||||
(rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAgentError` \e -> liftIO (print e) >> throwE e
|
||||
(rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e
|
||||
atomically $ incSMPServerStat c userId srv connCreated
|
||||
rq' <- withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId
|
||||
@@ -1351,7 +1351,7 @@ subscribeClientService' = undefined
|
||||
|
||||
-- requesting messages sequentially, to reduce memory usage
|
||||
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
|
||||
getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
|
||||
getConnectionMessages' c = mapM $ tryAllErrors' . getConnectionMessage
|
||||
where
|
||||
getConnectionMessage :: ConnMsgReq -> AM (Maybe SMPMsgMeta)
|
||||
getConnectionMessage (ConnMsgReq connId dbQueueId msgTs_) = do
|
||||
@@ -1363,7 +1363,7 @@ getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
|
||||
ContactConnection _ rq -> pure rq
|
||||
SndConnection _ _ -> throwE $ CONN SIMPLEX "getConnectionMessage"
|
||||
NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection"
|
||||
msg_ <- getQueueMessage c rq `catchAgentError` \e -> atomically (releaseGetLock c rq) >> throwError e
|
||||
msg_ <- getQueueMessage c rq `catchAllErrors` \e -> atomically (releaseGetLock c rq) >> throwError e
|
||||
when (isNothing msg_) $ do
|
||||
atomically $ releaseGetLock c rq
|
||||
forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBEntityId dbQueueId) msgTs
|
||||
@@ -1534,7 +1534,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do
|
||||
pure CCCompleted
|
||||
-- duplex connection is matched to handle SKEY retries
|
||||
DuplexConnection cData _ (sq :| _) -> do
|
||||
tryAgentError (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case
|
||||
tryAllErrors (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case
|
||||
Right () -> pure CCCompleted
|
||||
Left e
|
||||
| temporaryOrHostError e && Just server /= server_ -> do
|
||||
@@ -1621,7 +1621,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do
|
||||
tryMoveableCommand action = withRetryInterval ri $ \_ loop -> do
|
||||
liftIO $ waitWhileSuspended c
|
||||
liftIO $ waitForUserNetwork c
|
||||
tryAgentError action >>= \case
|
||||
tryAllErrors action >>= \case
|
||||
Left e
|
||||
| temporaryOrHostError e -> retrySndOp c loop
|
||||
| otherwise -> cmdError e
|
||||
@@ -2065,7 +2065,7 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni
|
||||
ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM ()
|
||||
ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do
|
||||
atomically $ incSMPServerStat c userId server ackAttempts
|
||||
tryAgentError (sendAck c rq srvMsgId) >>= \case
|
||||
tryAllErrors (sendAck c rq srvMsgId) >>= \case
|
||||
Right _ -> sendMsgNtf ackMsgs
|
||||
Left (SMP _ SMP.NO_MSG) -> sendMsgNtf ackNoMsgErrs
|
||||
Left e -> do
|
||||
@@ -2076,7 +2076,7 @@ ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do
|
||||
atomically $ incSMPServerStat c userId server stat
|
||||
whenM (liftIO $ hasGetLock c rq) $ do
|
||||
atomically $ releaseGetLock c rq
|
||||
brokerTs_ <- eitherToMaybe <$> tryAgentError (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId)
|
||||
brokerTs_ <- eitherToMaybe <$> tryAllErrors (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId)
|
||||
atomically $ writeTBQueue (subQ c) ("", connId, AEvt SAEConn $ MSGNTF srvMsgId brokerTs_)
|
||||
|
||||
-- | Suspend SMP agent connection (OFF command) in Reader monad
|
||||
@@ -2307,7 +2307,7 @@ registerNtfToken' c nm suppliedDeviceToken suppliedNtfMode =
|
||||
replaceToken :: NtfTokenId -> AM NtfTknStatus
|
||||
replaceToken tknId = do
|
||||
ns <- asks ntfSupervisor
|
||||
tryReplace ns `catchAgentError` \e ->
|
||||
tryReplace ns `catchAllErrors` \e ->
|
||||
if temporaryOrHostError e
|
||||
then throwE e
|
||||
else do
|
||||
@@ -2564,7 +2564,7 @@ cleanupManager c@AgentClient {subQ} = do
|
||||
where
|
||||
run :: forall e. AEntityI e => (AgentErrorType -> AEvent e) -> AM () -> AM' ()
|
||||
run err a = do
|
||||
waitActive . runExceptT $ a `catchAgentError` (notify "" . err)
|
||||
waitActive . runExceptT $ a `catchAllErrors` (notify "" . err)
|
||||
step <- asks $ cleanupStepInterval . config
|
||||
liftIO $ threadDelay step
|
||||
-- we are catching it to avoid CRITICAL errors in tests when this is the only remaining handle to active
|
||||
@@ -2578,33 +2578,33 @@ cleanupManager c@AgentClient {subQ} = do
|
||||
deleteRcvFilesExpired = do
|
||||
rcvFilesTTL <- asks $ rcvFilesTTL . config
|
||||
rcvExpired <- withStore' c (`getRcvFilesExpired` rcvFilesTTL)
|
||||
forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
|
||||
forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`deleteRcvFile'` dbId)
|
||||
deleteRcvFilesDeleted = do
|
||||
rcvDeleted <- withStore' c getCleanupRcvFilesDeleted
|
||||
forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
|
||||
forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`deleteRcvFile'` dbId)
|
||||
deleteRcvFilesTmpPaths = do
|
||||
rcvTmpPaths <- withStore' c getCleanupRcvFilesTmpPaths
|
||||
forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
|
||||
forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`updateRcvFileNoTmpPath` dbId)
|
||||
deleteSndFilesExpired = do
|
||||
sndFilesTTL <- asks $ sndFilesTTL . config
|
||||
sndExpired <- withStore' c (`getSndFilesExpired` sndFilesTTL)
|
||||
forM_ sndExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
|
||||
forM_ sndExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
|
||||
lift . forM_ p $ removePath <=< toFSFilePath
|
||||
withStore' c (`deleteSndFile'` dbId)
|
||||
deleteSndFilesDeleted = do
|
||||
sndDeleted <- withStore' c getCleanupSndFilesDeleted
|
||||
forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
|
||||
forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
|
||||
lift . forM_ p $ removePath <=< toFSFilePath
|
||||
withStore' c (`deleteSndFile'` dbId)
|
||||
deleteSndFilesPrefixPaths = do
|
||||
sndPrefixPaths <- withStore' c getCleanupSndFilesPrefixPaths
|
||||
forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
|
||||
forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`updateSndFileNoPrefixPath` dbId)
|
||||
deleteExpiredReplicasForDeletion = do
|
||||
@@ -2652,10 +2652,10 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
where
|
||||
withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' ()
|
||||
withRcvConn rId a = do
|
||||
tryAgentError' (withStore c $ \db -> getRcvConn db srv rId) >>= \case
|
||||
tryAllErrors' (withStore c $ \db -> getRcvConn db srv rId) >>= \case
|
||||
Left e -> notify' "" (ERR e)
|
||||
Right (rq@RcvQueue {connId}, SomeConn _ conn) ->
|
||||
tryAgentError' (a rq conn) >>= \case
|
||||
tryAllErrors' (a rq conn) >>= \case
|
||||
Left e -> notify' connId (ERR e)
|
||||
Right () -> pure ()
|
||||
processSubOk :: RcvQueue -> TVar [ConnId] -> AM ()
|
||||
@@ -2739,7 +2739,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
_ -> pure ()
|
||||
let encryptedMsgHash = C.sha256Hash encAgentMessage
|
||||
g <- asks random
|
||||
tryAgentError (agentClientMsg g encryptedMsgHash) >>= \case
|
||||
tryAllErrors (agentClientMsg g encryptedMsgHash) >>= \case
|
||||
Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do
|
||||
conn'' <- resetRatchetSync
|
||||
case aMessage of
|
||||
@@ -2848,7 +2848,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
ackDel :: InternalId -> AM ACKd
|
||||
ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd
|
||||
handleNotifyAck :: AM ACKd -> AM ACKd
|
||||
handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack
|
||||
handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack
|
||||
SMP.END ->
|
||||
atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False))
|
||||
>>= notifyEnd
|
||||
@@ -3006,7 +3006,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd
|
||||
messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do
|
||||
logServer "<--" c srv rId $ "MSG <RCPT>:" <> logSecret' srvMsgId
|
||||
rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing
|
||||
rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAllErrors` \e -> notify (ERR e) $> Nothing
|
||||
case L.nonEmpty . catMaybes $ L.toList rs of
|
||||
Just rs' -> notify (RCVD msgMeta rs') $> ACKPending
|
||||
Nothing -> ack
|
||||
|
||||
@@ -130,6 +130,7 @@ module Simplex.Messaging.Agent.Client
|
||||
hasWorkToDo,
|
||||
hasWorkToDo',
|
||||
withWork,
|
||||
withWork_,
|
||||
withWorkItems,
|
||||
agentOperations,
|
||||
agentOperationBracket,
|
||||
@@ -371,12 +372,12 @@ data SMPConnectedClient = SMPConnectedClient
|
||||
|
||||
type ProxiedRelayVar = SessionVar (Either AgentErrorType ProxiedRelay)
|
||||
|
||||
getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker
|
||||
getAgentWorker :: (Ord k, Show k, AnyError e, MonadUnliftIO m) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT e m ()) -> m Worker
|
||||
getAgentWorker = getAgentWorker' id pure
|
||||
{-# INLINE getAgentWorker #-}
|
||||
|
||||
getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a
|
||||
getAgentWorker' toW fromW name hasWork c key ws work = do
|
||||
getAgentWorker' :: forall a k e m. (Ord k, Show k, AnyError e, MonadUnliftIO m) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT e m ()) -> m a
|
||||
getAgentWorker' toW fromW name hasWork c@AgentClient {agentEnv} key ws work = do
|
||||
atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w
|
||||
where
|
||||
getWorker = TM.lookup key ws
|
||||
@@ -389,12 +390,12 @@ getAgentWorker' toW fromW name hasWork c key ws work = do
|
||||
| otherwise = pure w
|
||||
runWorker w = runWorkerAsync (toW w) runWork
|
||||
where
|
||||
runWork :: AM' ()
|
||||
runWork = tryAgentError' (work w) >>= restartOrDelete
|
||||
restartOrDelete :: Either AgentErrorType () -> AM' ()
|
||||
runWork :: m ()
|
||||
runWork = tryAllErrors' (work w) >>= restartOrDelete
|
||||
restartOrDelete :: Either e () -> m ()
|
||||
restartOrDelete e_ = do
|
||||
t <- liftIO getSystemTime
|
||||
maxRestarts <- asks $ maxWorkerRestartsPerMin . config
|
||||
let maxRestarts = maxWorkerRestartsPerMin $ config agentEnv
|
||||
-- worker may terminate because it was deleted from the map (getWorker returns Nothing), then it won't restart
|
||||
restart <- atomically $ getWorker >>= maybe (pure False) (shouldRestart e_ (toW w) t maxRestarts)
|
||||
when restart runWork
|
||||
@@ -431,7 +432,7 @@ newWorker c = do
|
||||
restarts <- newTVar $ RestartCount 0 0
|
||||
pure Worker {workerId, doWork, action, restarts}
|
||||
|
||||
runWorkerAsync :: Worker -> AM' () -> AM' ()
|
||||
runWorkerAsync :: MonadUnliftIO m => Worker -> m () -> m ()
|
||||
runWorkerAsync Worker {action} work =
|
||||
E.bracket
|
||||
(atomically $ takeTMVar action) -- get current action, locking to avoid race conditions
|
||||
@@ -665,7 +666,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
|
||||
pure (clnt, sess)
|
||||
newProxiedRelay :: SMPConnectedClient -> Maybe SMP.BasicAuth -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay)
|
||||
newProxiedRelay (SMPConnectedClient smp prs) proxyAuth rv =
|
||||
tryAgentError (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case
|
||||
tryAllErrors (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case
|
||||
Right sess -> do
|
||||
atomically $ putTMVar (sessionVar rv) (Right sess)
|
||||
pure $ Right sess
|
||||
@@ -688,7 +689,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
|
||||
smpConnectClient :: AgentClient -> NetworkRequestMode -> SMPTransportSession -> TMap SMPServer ProxiedRelayVar -> SMPClientVar -> AM SMPConnectedClient
|
||||
smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm tSess@(_, srv, _) prs v =
|
||||
newProtocolClient c tSess smpClients connectClient v
|
||||
`catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwE e
|
||||
`catchAllErrors` \e -> lift (resubscribeSMPSession c tSess) >> throwE e
|
||||
where
|
||||
connectClient :: SMPClientVar -> AM SMPConnectedClient
|
||||
connectClient v' = do
|
||||
@@ -866,7 +867,7 @@ newProtocolClient ::
|
||||
ClientVar msg ->
|
||||
AM (Client msg)
|
||||
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
|
||||
tryAgentError (connectClient v) >>= \case
|
||||
tryAllErrors (connectClient v) >>= \case
|
||||
Right client -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
|
||||
atomically $ putTMVar (sessionVar v) (Right client)
|
||||
@@ -1027,7 +1028,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
|
||||
withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> NetworkRequestMode -> TransportSession msg -> (Client msg -> AM a) -> AM a
|
||||
withClient_ c nm tSess@(_, srv, _) action = do
|
||||
cl <- getProtocolServerClient c nm tSess
|
||||
action cl `catchAgentError` logServerError
|
||||
action cl `catchAllErrors` logServerError
|
||||
where
|
||||
logServerError :: AgentErrorType -> AM a
|
||||
logServerError e = do
|
||||
@@ -1040,7 +1041,7 @@ withProxySession c nm proxySrv_ destSess@(_, destSrv, _) entId cmdStr action = d
|
||||
logServer ("--> " <> proxySrv cl <> " >") c destSrv entId cmdStr
|
||||
case sess_ of
|
||||
Right sess -> do
|
||||
r <- action (cl, sess) `catchAgentError` logServerError cl
|
||||
r <- action (cl, sess) `catchAllErrors` logServerError cl
|
||||
logServer ("<-- " <> proxySrv cl <> " <") c destSrv entId "OK"
|
||||
pure r
|
||||
Left e -> logServerError cl e
|
||||
@@ -1117,7 +1118,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn
|
||||
unknownServer = liftIO $ maybe True (\srvs -> all (`S.notMember` knownHosts srvs) destHosts) <$> TM.lookupIO userId (smpServers c)
|
||||
sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer, a)
|
||||
sendViaProxy proxySrv_ destSess@(_, _, connId_) = do
|
||||
r <- tryAgentError . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do
|
||||
r <- tryAllErrors . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do
|
||||
r' <- liftClient SMP (clientServer smp) $ sendCmdViaProxy smp proxySess
|
||||
let proxySrv = protocolClientServer' smp
|
||||
case r' of
|
||||
@@ -1164,7 +1165,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn
|
||||
| otherwise -> throwE e
|
||||
sendDirectly tSess =
|
||||
withLogClient_ c nm tSess (unEntityId entId) ("SEND " <> cmdStr) $ \(SMPConnectedClient smp _) -> do
|
||||
tryAgentError (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case
|
||||
tryAllErrors (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case
|
||||
Right r -> r <$ atomically (incSMPServerStat c userId destSrv sentDirect)
|
||||
Left e -> throwE e
|
||||
|
||||
@@ -1562,7 +1563,7 @@ sendTSessionBatches statCmd toRQ action c nm qs =
|
||||
in M.alter (Just . maybe [q] (q <|)) tSess m
|
||||
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r)
|
||||
sendClientBatch (tSess@(_, srv, _), qs') =
|
||||
tryAgentError' (getSMPServerClient c nm tSess) >>= \case
|
||||
tryAllErrors' (getSMPServerClient c nm tSess) >>= \case
|
||||
Left e -> pure $ L.map (,Left e) qs'
|
||||
Right (SMPConnectedClient smp _) -> liftIO $ do
|
||||
logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd
|
||||
@@ -1867,7 +1868,7 @@ withNtfBatch ::
|
||||
AM' (NonEmpty (Either AgentErrorType r))
|
||||
withNtfBatch cmdStr action c NtfToken {ntfServer, ntfPrivKey} subs = do
|
||||
let tSess = (0, ntfServer, Nothing)
|
||||
tryAgentError' (getNtfServerClient c NRMBackground tSess) >>= \case
|
||||
tryAllErrors' (getNtfServerClient c NRMBackground tSess) >>= \case
|
||||
Left e -> pure $ L.map (\_ -> Left e) subs
|
||||
Right ntf -> liftIO $ do
|
||||
logServer' "-->" c ntfServer (bshow (length subs) <> " subscriptions") cmdStr
|
||||
@@ -1968,8 +1969,12 @@ waitForWork = void . atomically . readTMVar
|
||||
{-# INLINE waitForWork #-}
|
||||
|
||||
withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM ()
|
||||
withWork c doWork getWork action =
|
||||
withStore' c getWork >>= \case
|
||||
withWork c doWork = withWork_ c doWork . withStore' c
|
||||
{-# INLINE withWork #-}
|
||||
|
||||
withWork_ :: MonadIO m => AgentClient -> TMVar () -> ExceptT e m (Either StoreError (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m ()
|
||||
withWork_ c doWork getWork action =
|
||||
getWork >>= \case
|
||||
Right (Just r) -> action r
|
||||
Right Nothing -> noWork
|
||||
-- worker is stopped here (noWork) because the next iteration is likely to produce the same result
|
||||
@@ -1979,9 +1984,9 @@ withWork c doWork getWork action =
|
||||
noWork = liftIO $ noWorkToDo doWork
|
||||
notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e)
|
||||
|
||||
withWorkItems :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError [Either StoreError a])) -> (NonEmpty a -> AM ()) -> AM ()
|
||||
withWorkItems :: MonadIO m => AgentClient -> TMVar () -> ExceptT e m (Either StoreError [Either StoreError a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m ()
|
||||
withWorkItems c doWork getWork action = do
|
||||
withStore' c getWork >>= \case
|
||||
getWork >>= \case
|
||||
Right [] -> noWork
|
||||
Right rs -> do
|
||||
let (errs, items) = partitionEithers rs
|
||||
|
||||
@@ -27,11 +27,6 @@ module Simplex.Messaging.Agent.Env.SQLite
|
||||
serverHosts,
|
||||
defaultAgentConfig,
|
||||
defaultReconnectInterval,
|
||||
tryAgentError,
|
||||
tryAgentError',
|
||||
catchAgentError,
|
||||
catchAgentError',
|
||||
agentFinally,
|
||||
Env (..),
|
||||
newSMPAgentEnv,
|
||||
createAgentStore,
|
||||
@@ -45,7 +40,6 @@ module Simplex.Messaging.Agent.Env.SQLite
|
||||
where
|
||||
|
||||
import Control.Concurrent (ThreadId)
|
||||
import Control.Exception (BlockedIndefinitelyOnSTM (..), SomeException, fromException)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
@@ -83,7 +77,6 @@ import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPVersion)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors')
|
||||
import System.Mem.Weak (Weak)
|
||||
import System.Random (StdGen, newStdGen)
|
||||
import UnliftIO.STM
|
||||
@@ -312,33 +305,6 @@ newXFTPAgent = do
|
||||
xftpDelWorkers <- TM.emptyIO
|
||||
pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers}
|
||||
|
||||
tryAgentError :: AM a -> AM (Either AgentErrorType a)
|
||||
tryAgentError = tryAllErrors mkInternal
|
||||
{-# INLINE tryAgentError #-}
|
||||
|
||||
-- unlike runExceptT, this ensures we catch IO exceptions as well
|
||||
tryAgentError' :: AM a -> AM' (Either AgentErrorType a)
|
||||
tryAgentError' = tryAllErrors' mkInternal
|
||||
{-# INLINE tryAgentError' #-}
|
||||
|
||||
catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a
|
||||
catchAgentError = catchAllErrors mkInternal
|
||||
{-# INLINE catchAgentError #-}
|
||||
|
||||
catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a
|
||||
catchAgentError' = catchAllErrors' mkInternal
|
||||
{-# INLINE catchAgentError' #-}
|
||||
|
||||
agentFinally :: AM a -> AM b -> AM a
|
||||
agentFinally = allFinally mkInternal
|
||||
{-# INLINE agentFinally #-}
|
||||
|
||||
mkInternal :: SomeException -> AgentErrorType
|
||||
mkInternal e = case fromException e of
|
||||
Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction"
|
||||
_ -> INTERNAL $ show e
|
||||
{-# INLINE mkInternal #-}
|
||||
|
||||
data Worker = Worker
|
||||
{ workerId :: Int,
|
||||
doWork :: TMVar (),
|
||||
|
||||
@@ -52,7 +52,7 @@ import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Protocol (NtfServer, sameSrvAddr)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util (diffToMicroseconds, threadDelay', tshow, whenM)
|
||||
import Simplex.Messaging.Util (catchAllErrors, diffToMicroseconds, threadDelay', tryAllErrors, tshow, whenM)
|
||||
import System.Random (randomR)
|
||||
import UnliftIO
|
||||
import UnliftIO.Concurrent (forkIO)
|
||||
@@ -217,7 +217,7 @@ runNtfWorker c srv Worker {doWork} =
|
||||
runNtfOperation :: AM ()
|
||||
runNtfOperation = do
|
||||
ntfBatchSize <- asks $ ntfBatchSize . config
|
||||
withWorkItems c doWork (\db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
logInfo $ "runNtfWorker - length nextSubs = " <> tshow (length nextSubs)
|
||||
currTs <- liftIO getCurrentTime
|
||||
let (creates, checks, deletes, rotates) = splitActions currTs nextSubs
|
||||
@@ -357,7 +357,7 @@ runNtfWorker c srv Worker {doWork} =
|
||||
runCatching :: (NtfSubscription -> AM (Maybe NtfSubscription)) -> NtfSubscription -> AM' (Maybe NtfSubscription)
|
||||
runCatching action sub@NtfSubscription {connId} =
|
||||
fromRight Nothing
|
||||
<$> runExceptT (action sub `catchAgentError` \e -> workerInternalError c connId (show e) $> Nothing)
|
||||
<$> runExceptT (action sub `catchAllErrors` \e -> workerInternalError c connId (show e) $> Nothing)
|
||||
-- deleteNtfSub is only used in NSADelete and NSARotate, so also deprecated
|
||||
deleteNtfSub :: NtfSubscription -> AM () -> AM (Maybe NtfSubscription)
|
||||
deleteNtfSub sub@NtfSubscription {userId, ntfSubId} continue = case ntfSubId of
|
||||
@@ -365,7 +365,7 @@ runNtfWorker c srv Worker {doWork} =
|
||||
lift getNtfToken >>= \case
|
||||
Just tkn@NtfToken {ntfServer} -> do
|
||||
atomically $ incNtfServerStat c userId ntfServer ntfDelAttempts
|
||||
tryAgentError (agentNtfDeleteSubscription c nSubId tkn) >>= \case
|
||||
tryAllErrors (agentNtfDeleteSubscription c nSubId tkn) >>= \case
|
||||
Right _ -> do
|
||||
atomically $ incNtfServerStat c userId ntfServer ntfDeleted
|
||||
continue'
|
||||
@@ -385,7 +385,7 @@ runNtfSMPWorker c srv Worker {doWork} = forever $ do
|
||||
runNtfSMPOperation :: AM ()
|
||||
runNtfSMPOperation = do
|
||||
ntfBatchSize <- asks $ ntfBatchSize . config
|
||||
withWorkItems c doWork (\db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
logInfo $ "runNtfSMPWorker - length nextSubs = " <> tshow (length nextSubs)
|
||||
let (creates, deletes) = splitActions nextSubs
|
||||
retrySubActions c creates createNotifierKeys
|
||||
@@ -567,7 +567,7 @@ runNtfTknDelWorker c srv Worker {doWork} =
|
||||
withRetryInterval ri $ \_ loop -> do
|
||||
liftIO $ waitWhileSuspended c
|
||||
liftIO $ waitForUserNetwork c
|
||||
processTknToDelete nextTknToDelete `catchAgentError` retryTmpError loop nextTknToDelete
|
||||
processTknToDelete nextTknToDelete `catchAllErrors` retryTmpError loop nextTknToDelete
|
||||
retryTmpError :: AM () -> NtfTokenToDelete -> AgentErrorType -> AM ()
|
||||
retryTmpError loop (tknDbId, _, _) e = do
|
||||
logError $ "ntf tkn del error: " <> tshow e
|
||||
|
||||
@@ -173,6 +173,7 @@ module Simplex.Messaging.Agent.Protocol
|
||||
where
|
||||
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Exception (BlockedIndefinitelyOnSTM (..), fromException)
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..), Value (..), (.:), (.:?))
|
||||
import qualified Data.Aeson as J'
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
@@ -1866,6 +1867,12 @@ data AgentErrorType
|
||||
INACTIVE
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
instance AnyError AgentErrorType where
|
||||
fromSomeException e = case fromException e of
|
||||
Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction"
|
||||
_ -> INTERNAL $ show e
|
||||
{-# INLINE fromSomeException #-}
|
||||
|
||||
-- | SMP agent protocol command or response error.
|
||||
data CommandErrorType
|
||||
= -- | command is prohibited in this context
|
||||
|
||||
@@ -29,6 +29,7 @@ import Data.Time (UTCTime)
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval (RI2State)
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.Agent.Store.Common
|
||||
import Simplex.Messaging.Agent.Store.Interface (createDBStore)
|
||||
import Simplex.Messaging.Agent.Store.Migrations.App (appMigrations)
|
||||
@@ -52,7 +53,7 @@ import Simplex.Messaging.Protocol
|
||||
VersionSMPC,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.Util (AnyError (..), bshow)
|
||||
|
||||
createStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore)
|
||||
createStore dbOpts = createDBStore dbOpts appMigrations
|
||||
@@ -696,3 +697,6 @@ data StoreError
|
||||
| -- | Servers stats not found.
|
||||
SEServersStatsNotFound
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
instance AnyError StoreError where
|
||||
fromSomeException = SEInternal . bshow
|
||||
|
||||
@@ -975,13 +975,10 @@ getWorkItems itemName getIds getItem markFailed =
|
||||
runExceptT $ handleWrkErr itemName "getIds" getIds >>= mapM (tryE . tryGetItem itemName getItem markFailed)
|
||||
|
||||
tryGetItem :: Show i => ByteString -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> i -> ExceptT StoreError IO a
|
||||
tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchStoreError` \e -> mark >> throwE e
|
||||
tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchAllErrors` \e -> mark >> throwE e
|
||||
where
|
||||
mark = handleWrkErr itemName ("markFailed ID " <> bshow itemId) $ markFailed itemId
|
||||
|
||||
catchStoreError :: ExceptT StoreError IO a -> (StoreError -> ExceptT StoreError IO a) -> ExceptT StoreError IO a
|
||||
catchStoreError = catchAllErrors (SEInternal . bshow)
|
||||
|
||||
-- Errors caught by this function will suspend worker as if there is no more work,
|
||||
handleWrkErr :: ByteString -> ByteString -> IO a -> ExceptT StoreError IO a
|
||||
handleWrkErr itemName opName action = ExceptT $ first mkError <$> E.try action
|
||||
|
||||
@@ -171,28 +171,30 @@ catchAll_ :: IO a -> IO a -> IO a
|
||||
catchAll_ a = catchAll a . const
|
||||
{-# INLINE catchAll_ #-}
|
||||
|
||||
tryAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m (Either e a)
|
||||
tryAllErrors err action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . err)
|
||||
class Show e => AnyError e where fromSomeException :: E.SomeException -> e
|
||||
|
||||
tryAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
|
||||
tryAllErrors action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . fromSomeException)
|
||||
{-# INLINE tryAllErrors #-}
|
||||
|
||||
tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a)
|
||||
tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err)
|
||||
tryAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
|
||||
tryAllErrors' action = runExceptT action `UE.catch` (pure . Left . fromSomeException)
|
||||
{-# INLINE tryAllErrors' #-}
|
||||
|
||||
catchAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
|
||||
catchAllErrors err action handler = tryAllErrors err action >>= either handler pure
|
||||
catchAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
|
||||
catchAllErrors action handler = tryAllErrors action >>= either handler pure
|
||||
{-# INLINE catchAllErrors #-}
|
||||
|
||||
catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a
|
||||
catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure
|
||||
catchAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
|
||||
catchAllErrors' action handler = tryAllErrors' action >>= either handler pure
|
||||
{-# INLINE catchAllErrors' #-}
|
||||
|
||||
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (E.SomeException -> e) -> ExceptT e m a
|
||||
catchThrow action err = catchAllErrors err action throwE
|
||||
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (SomeException -> e) -> ExceptT e m a
|
||||
action `catchThrow` err = ExceptT $ runExceptT action `UE.catch` (pure . Left . err)
|
||||
{-# INLINE catchThrow #-}
|
||||
|
||||
allFinally :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m b -> ExceptT e m a
|
||||
allFinally err action final = tryAllErrors err action >>= \r -> final >> except r
|
||||
allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> ExceptT e m a
|
||||
allFinally action final = tryAllErrors action >>= \r -> final >> except r
|
||||
{-# INLINE allFinally #-}
|
||||
|
||||
eitherToMaybe :: Either a b -> Maybe b
|
||||
|
||||
@@ -306,14 +306,8 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
|
||||
atomically $ takeTMVar endSession
|
||||
logDebug "Session ended"
|
||||
|
||||
catchRCError :: ExceptT RCErrorType IO a -> (RCErrorType -> ExceptT RCErrorType IO a) -> ExceptT RCErrorType IO a
|
||||
catchRCError = catchAllErrors $ \e -> case fromException e of
|
||||
Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity
|
||||
_ -> RCEException $ show e
|
||||
{-# INLINE catchRCError #-}
|
||||
|
||||
putRCError :: ExceptT RCErrorType IO a -> TMVar (Either RCErrorType b) -> ExceptT RCErrorType IO a
|
||||
a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
|
||||
a `putRCError` r = a `catchAllErrors` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
|
||||
|
||||
sendRCPacket :: Encoding a => TLS p -> a -> ExceptT RCErrorType IO ()
|
||||
sendRCPacket tls pkt = do
|
||||
@@ -395,7 +389,7 @@ discoverRCCtrl subscribers pairings =
|
||||
pure r
|
||||
where
|
||||
loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a
|
||||
loop action = action `catchRCError` \e -> logError (tshow e) >> loop action
|
||||
loop action = action `catchAllErrors` \e -> logError (tshow e) >> loop action
|
||||
|
||||
findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation)
|
||||
findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do
|
||||
|
||||
@@ -19,6 +19,7 @@ import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.Word (Word16)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Network.TLS as TLS
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
|
||||
import Simplex.Messaging.Encoding
|
||||
@@ -26,7 +27,7 @@ import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON)
|
||||
import Simplex.Messaging.Transport (TLS, TSbChainKeys, TransportPeer (..))
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8)
|
||||
import Simplex.Messaging.Util (AnyError (..), safeDecodeUtf8)
|
||||
import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange)
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import UnliftIO
|
||||
@@ -50,6 +51,12 @@ data RCErrorType
|
||||
| RCESyntax {syntaxErr :: String}
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
instance AnyError RCErrorType where
|
||||
fromSomeException e = case fromException e of
|
||||
Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity
|
||||
_ -> RCEException $ show e
|
||||
{-# INLINE fromSomeException #-}
|
||||
|
||||
instance StrEncoding RCErrorType where
|
||||
strEncode = \case
|
||||
RCEInternal err -> "INTERNAL" <> text err
|
||||
|
||||
@@ -45,56 +45,32 @@ utilTests = do
|
||||
runExceptT (throwTestException `catchError` handleCatch) `shouldThrow` (\(e :: IOError) -> show e == "user error (error)")
|
||||
describe "tryAllErrors" $ do
|
||||
it "should return ExceptT error as Left" $
|
||||
runExceptT (tryAllErrors testErr throwTestError) `shouldReturn` Right (Left (TestError "error"))
|
||||
runExceptT (tryAllErrors throwTestError) `shouldReturn` Right (Left (TestError "error"))
|
||||
it "should return SomeException as Left" $
|
||||
runExceptT (tryAllErrors testErr throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
|
||||
runExceptT (tryAllErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
|
||||
it "should return no errors as Right" $
|
||||
runExceptT (tryAllErrors testErr noErrors) `shouldReturn` Right (Right "no errors")
|
||||
describe "tryAllErrors specialized as tryTestError" $ do
|
||||
let tryTestError = tryAllErrors testErr
|
||||
it "should return ExceptT error as Left" $
|
||||
runExceptT (tryTestError throwTestError) `shouldReturn` Right (Left (TestError "error"))
|
||||
it "should return SomeException as Left" $
|
||||
runExceptT (tryTestError throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
|
||||
it "should return no errors as Right" $
|
||||
runExceptT (tryTestError noErrors) `shouldReturn` Right (Right "no errors")
|
||||
runExceptT (tryAllErrors noErrors) `shouldReturn` Right (Right "no errors")
|
||||
describe "catchAllErrors" $ do
|
||||
it "should catch ExceptT error" $
|
||||
runExceptT (catchAllErrors testErr throwTestError handleCatch) `shouldReturn` Right "caught TestError \"error\""
|
||||
runExceptT (throwTestError `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\""
|
||||
it "should catch SomeException" $
|
||||
runExceptT (catchAllErrors testErr throwTestException handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
|
||||
runExceptT (throwTestException `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
|
||||
it "should not throw if there are no errors" $
|
||||
runExceptT (catchAllErrors testErr noErrors throwError) `shouldReturn` Right "no errors"
|
||||
describe "catchAllErrors specialized as catchTestError" $ do
|
||||
let catchTestError = catchAllErrors testErr
|
||||
it "should catch ExceptT error" $
|
||||
runExceptT (throwTestError `catchTestError` handleCatch) `shouldReturn` Right "caught TestError \"error\""
|
||||
it "should catch SomeException" $
|
||||
runExceptT (throwTestException `catchTestError` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
|
||||
it "should not throw if there are no errors" $
|
||||
runExceptT (noErrors `catchTestError` throwError) `shouldReturn` Right "no errors"
|
||||
runExceptT (noErrors `catchAllErrors` throwError) `shouldReturn` Right "no errors"
|
||||
describe "catchThrow" $ do
|
||||
it "should re-throw ExceptT error" $
|
||||
runExceptT (throwTestError `catchThrow` testErr) `shouldReturn` Left (TestError "error")
|
||||
runExceptT (throwTestError `catchThrow` fromSomeException) `shouldReturn` Left (TestError "error")
|
||||
it "should catch SomeException and throw as ExceptT error" $
|
||||
runExceptT (throwTestException `catchThrow` testErr) `shouldReturn` Left (TestException "user error (error)")
|
||||
runExceptT (throwTestException `catchThrow` fromSomeException) `shouldReturn` Left (TestException "user error (error)")
|
||||
it "should not throw if there are no exceptions" $
|
||||
runExceptT (noErrors `catchThrow` testErr) `shouldReturn` Right "no errors"
|
||||
runExceptT (noErrors `catchThrow` fromSomeException) `shouldReturn` Right "no errors"
|
||||
describe "allFinally should run final action" $ do
|
||||
it "then throw ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (allFinally testErr throwTestError final) `shouldReturn` Left (TestError "error")
|
||||
runExceptT (throwTestError `allFinally` final) `shouldReturn` Left (TestError "error")
|
||||
it "then throw SomeException as ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (allFinally testErr throwTestException final) `shouldReturn` Left (TestException "user error (error)")
|
||||
runExceptT (throwTestException `allFinally` final) `shouldReturn` Left (TestException "user error (error)")
|
||||
it "and should not throw if there are no exceptions" $ withFinal $ \final ->
|
||||
runExceptT (allFinally testErr noErrors final) `shouldReturn` Right "no errors"
|
||||
describe "allFinally specialized as testFinally should run final action" $ do
|
||||
let testFinally = allFinally testErr
|
||||
it "then throw ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (throwTestError `testFinally` final) `shouldReturn` Left (TestError "error")
|
||||
it "then throw SomeException as ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (throwTestException `testFinally` final) `shouldReturn` Left (TestException "user error (error)")
|
||||
it "and should not throw if there are no exceptions" $ withFinal $ \final ->
|
||||
runExceptT (noErrors `testFinally` final) `shouldReturn` Right "no errors"
|
||||
runExceptT (noErrors `allFinally` final) `shouldReturn` Right "no errors"
|
||||
where
|
||||
throwTestError :: ExceptT TestError IO String
|
||||
throwTestError = throwError $ TestError "error"
|
||||
@@ -102,8 +78,6 @@ utilTests = do
|
||||
throwTestException = liftIO $ throwIO $ userError "error"
|
||||
noErrors :: ExceptT TestError IO String
|
||||
noErrors = pure "no errors"
|
||||
testErr :: SomeException -> TestError
|
||||
testErr = TestException . show
|
||||
handleCatch :: TestError -> ExceptT TestError IO String
|
||||
handleCatch e = pure $ "caught " <> show e
|
||||
handleException :: SomeException -> ExceptT TestError IO String
|
||||
@@ -119,3 +93,6 @@ data TestError = TestError String | TestException String
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Exception TestError
|
||||
|
||||
instance AnyError TestError where
|
||||
fromSomeException = TestException . show
|
||||
|
||||
Reference in New Issue
Block a user