agent: make agent workers usable from other contexts (#1614)

This commit is contained in:
Evgeny
2025-08-29 08:33:55 +01:00
committed by GitHub
parent a2d777bda0
commit beafac1f73
12 changed files with 119 additions and 160 deletions
+12 -12
View File
@@ -75,7 +75,7 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String (strDecode, strEncode)
import Simplex.Messaging.Protocol (ProtocolServer, ProtocolType (..), XFTPServer)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM)
import Simplex.Messaging.Util (allFinally, catchAll_, catchAllErrors, liftError, tshow, unlessM, whenM)
import System.FilePath (takeFileName, (</>))
import UnliftIO
import UnliftIO.Directory
@@ -198,10 +198,10 @@ runXFTPRcvWorker c srv Worker {doWork} = do
liftIO $ waitForUserNetwork c
atomically $ incXFTPServerStat c userId srv downloadAttempts
downloadFileChunk fc replica approvedRelays
`catchAgentError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
`catchAllErrors` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
where
retryLoop loop e replicaDelay = do
flip catchAgentError (\_ -> pure ()) $ do
flip catchAllErrors (\_ -> pure ()) $ do
when (serverHostError e) $ notify c (fromMaybe rcvFileEntityId redirectEntityId_) (RFWARN e)
liftIO $ closeXFTPServerClient c userId server digest
withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay
@@ -280,7 +280,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
runXFTPOperation AgentConfig {rcvFilesTTL} =
withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $
\f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath, redirect} ->
decryptFile f `catchAgentError` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath
decryptFile f `catchAllErrors` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath
decryptFile :: RcvFile -> AM ()
decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do
let CryptoFile savePath cfArgs = saveFile
@@ -307,7 +307,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
liftIO $ waitUntilForeground c
withStore' c (`updateRcvFileComplete` rcvFileId)
-- proceed with redirect
yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `allFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of
-- TODO switch to another error constructor
Left _ -> throwE . FILE $ REDIRECT "decode error"
@@ -399,7 +399,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
runXFTPOperation cfg@AgentConfig {sndFilesTTL} =
withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $
\f@SndFile {sndFileId, sndFileEntityId, prefixPath} ->
prepareFile cfg f `catchAgentError` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath
prepareFile cfg f `catchAllErrors` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath
prepareFile :: AgentConfig -> SndFile -> AM ()
prepareFile _ SndFile {prefixPath = Nothing} =
throwE $ INTERNAL "no prefix path"
@@ -468,11 +468,11 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
liftIO $ waitForUserNetwork c
let triedAllSrvs = n > userSrvCount
createWithNextSrv triedHosts
`catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e
`catchAllErrors` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e
where
-- we don't do closeXFTPServerClient here to not risk closing connection for concurrent chunk upload
retryLoop loop triedAllSrvs e = do
flip catchAgentError (\_ -> pure ()) $ do
flip catchAllErrors (\_ -> pure ()) $ do
when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e
liftIO $ assertAgentForeground c
loop
@@ -508,10 +508,10 @@ runXFTPSndWorker c srv Worker {doWork} = do
liftIO $ waitForUserNetwork c
atomically $ incXFTPServerStat c userId srv uploadAttempts
uploadFileChunk cfg fc replica
`catchAgentError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
`catchAllErrors` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
where
retryLoop loop e replicaDelay = do
flip catchAgentError (\_ -> pure ()) $ do
flip catchAllErrors (\_ -> pure ()) $ do
when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e
liftIO $ closeXFTPServerClient c userId server digest
withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay
@@ -681,10 +681,10 @@ runXFTPDelWorker c srv Worker {doWork} = do
liftIO $ waitForUserNetwork c
atomically $ incXFTPServerStat c userId srv deleteAttempts
deleteChunkReplica
`catchAgentError` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
`catchAllErrors` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
where
retryLoop loop e replicaDelay = do
flip catchAgentError (\_ -> pure ()) $ do
flip catchAllErrors (\_ -> pure ()) $ do
when (serverHostError e) $ notify c "" $ SFWARN e
liftIO $ closeXFTPServerClient c userId server chunkDigest
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
+23 -23
View File
@@ -284,13 +284,13 @@ saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats, ntfServ
xss <- mapM (liftIO . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats
nss <- mapM (liftIO . getAgentNtfServerStats) =<< readTVarIO ntfServersStats
let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}
tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case
tryAllErrors' (withStore' c (`updateServersStats` stats)) >>= \case
Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e)
Right () -> pure ()
restoreServersStats :: AgentClient -> AM' ()
restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt} = do
tryAgentError' (withStore c getServersStats) >>= \case
tryAllErrors' (withStore c getServersStats) >>= \case
Left e -> atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e)
Right (startedAt, Nothing) -> atomically $ writeTVar srvStatsStartedAt startedAt
Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}) -> do
@@ -774,7 +774,7 @@ acceptContactAsync' :: AgentClient -> UserId -> ACorrId -> Bool -> InvitationId
acceptContactAsync' c userId corrId enableNtfs invId ownConnInfo pqSupport subMode = do
Invitation {connReq} <- withStore c $ \db -> getInvitation db "acceptContactAsync'" invId
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAllErrors` \err -> do
withStore' c (`unacceptInvitation` invId)
throwE err
@@ -961,7 +961,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey
createRcvQueue nonce_ qd e2eKeys = do
AgentConfig {smpClientVRange = vr} <- asks config
ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing
(rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAgentError` \e -> liftIO (print e) >> throwE e
(rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e
atomically $ incSMPServerStat c userId srv connCreated
rq' <- withStore c $ \db -> updateNewConnRcv db connId rq
lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId
@@ -1351,7 +1351,7 @@ subscribeClientService' = undefined
-- requesting messages sequentially, to reduce memory usage
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
getConnectionMessages' c = mapM $ tryAllErrors' . getConnectionMessage
where
getConnectionMessage :: ConnMsgReq -> AM (Maybe SMPMsgMeta)
getConnectionMessage (ConnMsgReq connId dbQueueId msgTs_) = do
@@ -1363,7 +1363,7 @@ getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
ContactConnection _ rq -> pure rq
SndConnection _ _ -> throwE $ CONN SIMPLEX "getConnectionMessage"
NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection"
msg_ <- getQueueMessage c rq `catchAgentError` \e -> atomically (releaseGetLock c rq) >> throwError e
msg_ <- getQueueMessage c rq `catchAllErrors` \e -> atomically (releaseGetLock c rq) >> throwError e
when (isNothing msg_) $ do
atomically $ releaseGetLock c rq
forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBEntityId dbQueueId) msgTs
@@ -1534,7 +1534,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do
pure CCCompleted
-- duplex connection is matched to handle SKEY retries
DuplexConnection cData _ (sq :| _) -> do
tryAgentError (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case
tryAllErrors (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case
Right () -> pure CCCompleted
Left e
| temporaryOrHostError e && Just server /= server_ -> do
@@ -1621,7 +1621,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do
tryMoveableCommand action = withRetryInterval ri $ \_ loop -> do
liftIO $ waitWhileSuspended c
liftIO $ waitForUserNetwork c
tryAgentError action >>= \case
tryAllErrors action >>= \case
Left e
| temporaryOrHostError e -> retrySndOp c loop
| otherwise -> cmdError e
@@ -2065,7 +2065,7 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni
ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM ()
ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do
atomically $ incSMPServerStat c userId server ackAttempts
tryAgentError (sendAck c rq srvMsgId) >>= \case
tryAllErrors (sendAck c rq srvMsgId) >>= \case
Right _ -> sendMsgNtf ackMsgs
Left (SMP _ SMP.NO_MSG) -> sendMsgNtf ackNoMsgErrs
Left e -> do
@@ -2076,7 +2076,7 @@ ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do
atomically $ incSMPServerStat c userId server stat
whenM (liftIO $ hasGetLock c rq) $ do
atomically $ releaseGetLock c rq
brokerTs_ <- eitherToMaybe <$> tryAgentError (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId)
brokerTs_ <- eitherToMaybe <$> tryAllErrors (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId)
atomically $ writeTBQueue (subQ c) ("", connId, AEvt SAEConn $ MSGNTF srvMsgId brokerTs_)
-- | Suspend SMP agent connection (OFF command) in Reader monad
@@ -2307,7 +2307,7 @@ registerNtfToken' c nm suppliedDeviceToken suppliedNtfMode =
replaceToken :: NtfTokenId -> AM NtfTknStatus
replaceToken tknId = do
ns <- asks ntfSupervisor
tryReplace ns `catchAgentError` \e ->
tryReplace ns `catchAllErrors` \e ->
if temporaryOrHostError e
then throwE e
else do
@@ -2564,7 +2564,7 @@ cleanupManager c@AgentClient {subQ} = do
where
run :: forall e. AEntityI e => (AgentErrorType -> AEvent e) -> AM () -> AM' ()
run err a = do
waitActive . runExceptT $ a `catchAgentError` (notify "" . err)
waitActive . runExceptT $ a `catchAllErrors` (notify "" . err)
step <- asks $ cleanupStepInterval . config
liftIO $ threadDelay step
-- we are catching it to avoid CRITICAL errors in tests when this is the only remaining handle to active
@@ -2578,33 +2578,33 @@ cleanupManager c@AgentClient {subQ} = do
deleteRcvFilesExpired = do
rcvFilesTTL <- asks $ rcvFilesTTL . config
rcvExpired <- withStore' c (`getRcvFilesExpired` rcvFilesTTL)
forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
lift $ removePath =<< toFSFilePath p
withStore' c (`deleteRcvFile'` dbId)
deleteRcvFilesDeleted = do
rcvDeleted <- withStore' c getCleanupRcvFilesDeleted
forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
lift $ removePath =<< toFSFilePath p
withStore' c (`deleteRcvFile'` dbId)
deleteRcvFilesTmpPaths = do
rcvTmpPaths <- withStore' c getCleanupRcvFilesTmpPaths
forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
lift $ removePath =<< toFSFilePath p
withStore' c (`updateRcvFileNoTmpPath` dbId)
deleteSndFilesExpired = do
sndFilesTTL <- asks $ sndFilesTTL . config
sndExpired <- withStore' c (`getSndFilesExpired` sndFilesTTL)
forM_ sndExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
forM_ sndExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
lift . forM_ p $ removePath <=< toFSFilePath
withStore' c (`deleteSndFile'` dbId)
deleteSndFilesDeleted = do
sndDeleted <- withStore' c getCleanupSndFilesDeleted
forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
lift . forM_ p $ removePath <=< toFSFilePath
withStore' c (`deleteSndFile'` dbId)
deleteSndFilesPrefixPaths = do
sndPrefixPaths <- withStore' c getCleanupSndFilesPrefixPaths
forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
lift $ removePath =<< toFSFilePath p
withStore' c (`updateSndFileNoPrefixPath` dbId)
deleteExpiredReplicasForDeletion = do
@@ -2652,10 +2652,10 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
where
withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' ()
withRcvConn rId a = do
tryAgentError' (withStore c $ \db -> getRcvConn db srv rId) >>= \case
tryAllErrors' (withStore c $ \db -> getRcvConn db srv rId) >>= \case
Left e -> notify' "" (ERR e)
Right (rq@RcvQueue {connId}, SomeConn _ conn) ->
tryAgentError' (a rq conn) >>= \case
tryAllErrors' (a rq conn) >>= \case
Left e -> notify' connId (ERR e)
Right () -> pure ()
processSubOk :: RcvQueue -> TVar [ConnId] -> AM ()
@@ -2739,7 +2739,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
_ -> pure ()
let encryptedMsgHash = C.sha256Hash encAgentMessage
g <- asks random
tryAgentError (agentClientMsg g encryptedMsgHash) >>= \case
tryAllErrors (agentClientMsg g encryptedMsgHash) >>= \case
Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do
conn'' <- resetRatchetSync
case aMessage of
@@ -2848,7 +2848,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
ackDel :: InternalId -> AM ACKd
ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd
handleNotifyAck :: AM ACKd -> AM ACKd
handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack
handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack
SMP.END ->
atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False))
>>= notifyEnd
@@ -3006,7 +3006,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd
messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do
logServer "<--" c srv rId $ "MSG <RCPT>:" <> logSecret' srvMsgId
rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing
rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAllErrors` \e -> notify (ERR e) $> Nothing
case L.nonEmpty . catMaybes $ L.toList rs of
Just rs' -> notify (RCVD msgMeta rs') $> ACKPending
Nothing -> ack
+26 -21
View File
@@ -130,6 +130,7 @@ module Simplex.Messaging.Agent.Client
hasWorkToDo,
hasWorkToDo',
withWork,
withWork_,
withWorkItems,
agentOperations,
agentOperationBracket,
@@ -371,12 +372,12 @@ data SMPConnectedClient = SMPConnectedClient
type ProxiedRelayVar = SessionVar (Either AgentErrorType ProxiedRelay)
getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker
getAgentWorker :: (Ord k, Show k, AnyError e, MonadUnliftIO m) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT e m ()) -> m Worker
getAgentWorker = getAgentWorker' id pure
{-# INLINE getAgentWorker #-}
getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a
getAgentWorker' toW fromW name hasWork c key ws work = do
getAgentWorker' :: forall a k e m. (Ord k, Show k, AnyError e, MonadUnliftIO m) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT e m ()) -> m a
getAgentWorker' toW fromW name hasWork c@AgentClient {agentEnv} key ws work = do
atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w
where
getWorker = TM.lookup key ws
@@ -389,12 +390,12 @@ getAgentWorker' toW fromW name hasWork c key ws work = do
| otherwise = pure w
runWorker w = runWorkerAsync (toW w) runWork
where
runWork :: AM' ()
runWork = tryAgentError' (work w) >>= restartOrDelete
restartOrDelete :: Either AgentErrorType () -> AM' ()
runWork :: m ()
runWork = tryAllErrors' (work w) >>= restartOrDelete
restartOrDelete :: Either e () -> m ()
restartOrDelete e_ = do
t <- liftIO getSystemTime
maxRestarts <- asks $ maxWorkerRestartsPerMin . config
let maxRestarts = maxWorkerRestartsPerMin $ config agentEnv
-- worker may terminate because it was deleted from the map (getWorker returns Nothing), then it won't restart
restart <- atomically $ getWorker >>= maybe (pure False) (shouldRestart e_ (toW w) t maxRestarts)
when restart runWork
@@ -431,7 +432,7 @@ newWorker c = do
restarts <- newTVar $ RestartCount 0 0
pure Worker {workerId, doWork, action, restarts}
runWorkerAsync :: Worker -> AM' () -> AM' ()
runWorkerAsync :: MonadUnliftIO m => Worker -> m () -> m ()
runWorkerAsync Worker {action} work =
E.bracket
(atomically $ takeTMVar action) -- get current action, locking to avoid race conditions
@@ -665,7 +666,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
pure (clnt, sess)
newProxiedRelay :: SMPConnectedClient -> Maybe SMP.BasicAuth -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay)
newProxiedRelay (SMPConnectedClient smp prs) proxyAuth rv =
tryAgentError (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case
tryAllErrors (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case
Right sess -> do
atomically $ putTMVar (sessionVar rv) (Right sess)
pure $ Right sess
@@ -688,7 +689,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
smpConnectClient :: AgentClient -> NetworkRequestMode -> SMPTransportSession -> TMap SMPServer ProxiedRelayVar -> SMPClientVar -> AM SMPConnectedClient
smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm tSess@(_, srv, _) prs v =
newProtocolClient c tSess smpClients connectClient v
`catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwE e
`catchAllErrors` \e -> lift (resubscribeSMPSession c tSess) >> throwE e
where
connectClient :: SMPClientVar -> AM SMPConnectedClient
connectClient v' = do
@@ -866,7 +867,7 @@ newProtocolClient ::
ClientVar msg ->
AM (Client msg)
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
tryAgentError (connectClient v) >>= \case
tryAllErrors (connectClient v) >>= \case
Right client -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
atomically $ putTMVar (sessionVar v) (Right client)
@@ -1027,7 +1028,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> NetworkRequestMode -> TransportSession msg -> (Client msg -> AM a) -> AM a
withClient_ c nm tSess@(_, srv, _) action = do
cl <- getProtocolServerClient c nm tSess
action cl `catchAgentError` logServerError
action cl `catchAllErrors` logServerError
where
logServerError :: AgentErrorType -> AM a
logServerError e = do
@@ -1040,7 +1041,7 @@ withProxySession c nm proxySrv_ destSess@(_, destSrv, _) entId cmdStr action = d
logServer ("--> " <> proxySrv cl <> " >") c destSrv entId cmdStr
case sess_ of
Right sess -> do
r <- action (cl, sess) `catchAgentError` logServerError cl
r <- action (cl, sess) `catchAllErrors` logServerError cl
logServer ("<-- " <> proxySrv cl <> " <") c destSrv entId "OK"
pure r
Left e -> logServerError cl e
@@ -1117,7 +1118,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn
unknownServer = liftIO $ maybe True (\srvs -> all (`S.notMember` knownHosts srvs) destHosts) <$> TM.lookupIO userId (smpServers c)
sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer, a)
sendViaProxy proxySrv_ destSess@(_, _, connId_) = do
r <- tryAgentError . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do
r <- tryAllErrors . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do
r' <- liftClient SMP (clientServer smp) $ sendCmdViaProxy smp proxySess
let proxySrv = protocolClientServer' smp
case r' of
@@ -1164,7 +1165,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn
| otherwise -> throwE e
sendDirectly tSess =
withLogClient_ c nm tSess (unEntityId entId) ("SEND " <> cmdStr) $ \(SMPConnectedClient smp _) -> do
tryAgentError (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case
tryAllErrors (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case
Right r -> r <$ atomically (incSMPServerStat c userId destSrv sentDirect)
Left e -> throwE e
@@ -1562,7 +1563,7 @@ sendTSessionBatches statCmd toRQ action c nm qs =
in M.alter (Just . maybe [q] (q <|)) tSess m
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r)
sendClientBatch (tSess@(_, srv, _), qs') =
tryAgentError' (getSMPServerClient c nm tSess) >>= \case
tryAllErrors' (getSMPServerClient c nm tSess) >>= \case
Left e -> pure $ L.map (,Left e) qs'
Right (SMPConnectedClient smp _) -> liftIO $ do
logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd
@@ -1867,7 +1868,7 @@ withNtfBatch ::
AM' (NonEmpty (Either AgentErrorType r))
withNtfBatch cmdStr action c NtfToken {ntfServer, ntfPrivKey} subs = do
let tSess = (0, ntfServer, Nothing)
tryAgentError' (getNtfServerClient c NRMBackground tSess) >>= \case
tryAllErrors' (getNtfServerClient c NRMBackground tSess) >>= \case
Left e -> pure $ L.map (\_ -> Left e) subs
Right ntf -> liftIO $ do
logServer' "-->" c ntfServer (bshow (length subs) <> " subscriptions") cmdStr
@@ -1968,8 +1969,12 @@ waitForWork = void . atomically . readTMVar
{-# INLINE waitForWork #-}
withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM ()
withWork c doWork getWork action =
withStore' c getWork >>= \case
withWork c doWork = withWork_ c doWork . withStore' c
{-# INLINE withWork #-}
withWork_ :: MonadIO m => AgentClient -> TMVar () -> ExceptT e m (Either StoreError (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m ()
withWork_ c doWork getWork action =
getWork >>= \case
Right (Just r) -> action r
Right Nothing -> noWork
-- worker is stopped here (noWork) because the next iteration is likely to produce the same result
@@ -1979,9 +1984,9 @@ withWork c doWork getWork action =
noWork = liftIO $ noWorkToDo doWork
notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e)
withWorkItems :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError [Either StoreError a])) -> (NonEmpty a -> AM ()) -> AM ()
withWorkItems :: MonadIO m => AgentClient -> TMVar () -> ExceptT e m (Either StoreError [Either StoreError a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m ()
withWorkItems c doWork getWork action = do
withStore' c getWork >>= \case
getWork >>= \case
Right [] -> noWork
Right rs -> do
let (errs, items) = partitionEithers rs
-34
View File
@@ -27,11 +27,6 @@ module Simplex.Messaging.Agent.Env.SQLite
serverHosts,
defaultAgentConfig,
defaultReconnectInterval,
tryAgentError,
tryAgentError',
catchAgentError,
catchAgentError',
agentFinally,
Env (..),
newSMPAgentEnv,
createAgentStore,
@@ -45,7 +40,6 @@ module Simplex.Messaging.Agent.Env.SQLite
where
import Control.Concurrent (ThreadId)
import Control.Exception (BlockedIndefinitelyOnSTM (..), SomeException, fromException)
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
@@ -83,7 +77,6 @@ import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPVersion)
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors')
import System.Mem.Weak (Weak)
import System.Random (StdGen, newStdGen)
import UnliftIO.STM
@@ -312,33 +305,6 @@ newXFTPAgent = do
xftpDelWorkers <- TM.emptyIO
pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers}
tryAgentError :: AM a -> AM (Either AgentErrorType a)
tryAgentError = tryAllErrors mkInternal
{-# INLINE tryAgentError #-}
-- unlike runExceptT, this ensures we catch IO exceptions as well
tryAgentError' :: AM a -> AM' (Either AgentErrorType a)
tryAgentError' = tryAllErrors' mkInternal
{-# INLINE tryAgentError' #-}
catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a
catchAgentError = catchAllErrors mkInternal
{-# INLINE catchAgentError #-}
catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a
catchAgentError' = catchAllErrors' mkInternal
{-# INLINE catchAgentError' #-}
agentFinally :: AM a -> AM b -> AM a
agentFinally = allFinally mkInternal
{-# INLINE agentFinally #-}
mkInternal :: SomeException -> AgentErrorType
mkInternal e = case fromException e of
Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction"
_ -> INTERNAL $ show e
{-# INLINE mkInternal #-}
data Worker = Worker
{ workerId :: Int,
doWork :: TMVar (),
@@ -52,7 +52,7 @@ import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Protocol (NtfServer, sameSrvAddr)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Util (diffToMicroseconds, threadDelay', tshow, whenM)
import Simplex.Messaging.Util (catchAllErrors, diffToMicroseconds, threadDelay', tryAllErrors, tshow, whenM)
import System.Random (randomR)
import UnliftIO
import UnliftIO.Concurrent (forkIO)
@@ -217,7 +217,7 @@ runNtfWorker c srv Worker {doWork} =
runNtfOperation :: AM ()
runNtfOperation = do
ntfBatchSize <- asks $ ntfBatchSize . config
withWorkItems c doWork (\db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do
withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do
logInfo $ "runNtfWorker - length nextSubs = " <> tshow (length nextSubs)
currTs <- liftIO getCurrentTime
let (creates, checks, deletes, rotates) = splitActions currTs nextSubs
@@ -357,7 +357,7 @@ runNtfWorker c srv Worker {doWork} =
runCatching :: (NtfSubscription -> AM (Maybe NtfSubscription)) -> NtfSubscription -> AM' (Maybe NtfSubscription)
runCatching action sub@NtfSubscription {connId} =
fromRight Nothing
<$> runExceptT (action sub `catchAgentError` \e -> workerInternalError c connId (show e) $> Nothing)
<$> runExceptT (action sub `catchAllErrors` \e -> workerInternalError c connId (show e) $> Nothing)
-- deleteNtfSub is only used in NSADelete and NSARotate, so also deprecated
deleteNtfSub :: NtfSubscription -> AM () -> AM (Maybe NtfSubscription)
deleteNtfSub sub@NtfSubscription {userId, ntfSubId} continue = case ntfSubId of
@@ -365,7 +365,7 @@ runNtfWorker c srv Worker {doWork} =
lift getNtfToken >>= \case
Just tkn@NtfToken {ntfServer} -> do
atomically $ incNtfServerStat c userId ntfServer ntfDelAttempts
tryAgentError (agentNtfDeleteSubscription c nSubId tkn) >>= \case
tryAllErrors (agentNtfDeleteSubscription c nSubId tkn) >>= \case
Right _ -> do
atomically $ incNtfServerStat c userId ntfServer ntfDeleted
continue'
@@ -385,7 +385,7 @@ runNtfSMPWorker c srv Worker {doWork} = forever $ do
runNtfSMPOperation :: AM ()
runNtfSMPOperation = do
ntfBatchSize <- asks $ ntfBatchSize . config
withWorkItems c doWork (\db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do
withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do
logInfo $ "runNtfSMPWorker - length nextSubs = " <> tshow (length nextSubs)
let (creates, deletes) = splitActions nextSubs
retrySubActions c creates createNotifierKeys
@@ -567,7 +567,7 @@ runNtfTknDelWorker c srv Worker {doWork} =
withRetryInterval ri $ \_ loop -> do
liftIO $ waitWhileSuspended c
liftIO $ waitForUserNetwork c
processTknToDelete nextTknToDelete `catchAgentError` retryTmpError loop nextTknToDelete
processTknToDelete nextTknToDelete `catchAllErrors` retryTmpError loop nextTknToDelete
retryTmpError :: AM () -> NtfTokenToDelete -> AgentErrorType -> AM ()
retryTmpError loop (tknDbId, _, _) e = do
logError $ "ntf tkn del error: " <> tshow e
+7
View File
@@ -173,6 +173,7 @@ module Simplex.Messaging.Agent.Protocol
where
import Control.Applicative (optional, (<|>))
import Control.Exception (BlockedIndefinitelyOnSTM (..), fromException)
import Data.Aeson (FromJSON (..), ToJSON (..), Value (..), (.:), (.:?))
import qualified Data.Aeson as J'
import qualified Data.Aeson.Encoding as JE
@@ -1866,6 +1867,12 @@ data AgentErrorType
INACTIVE
deriving (Eq, Show, Exception)
instance AnyError AgentErrorType where
fromSomeException e = case fromException e of
Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction"
_ -> INTERNAL $ show e
{-# INLINE fromSomeException #-}
-- | SMP agent protocol command or response error.
data CommandErrorType
= -- | command is prohibited in this context
+5 -1
View File
@@ -29,6 +29,7 @@ import Data.Time (UTCTime)
import Data.Type.Equality
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval (RI2State)
import Simplex.Messaging.Agent.Store.Entity
import Simplex.Messaging.Agent.Store.Common
import Simplex.Messaging.Agent.Store.Interface (createDBStore)
import Simplex.Messaging.Agent.Store.Migrations.App (appMigrations)
@@ -52,7 +53,7 @@ import Simplex.Messaging.Protocol
VersionSMPC,
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Agent.Store.Entity
import Simplex.Messaging.Util (AnyError (..), bshow)
createStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore)
createStore dbOpts = createDBStore dbOpts appMigrations
@@ -696,3 +697,6 @@ data StoreError
| -- | Servers stats not found.
SEServersStatsNotFound
deriving (Eq, Show, Exception)
instance AnyError StoreError where
fromSomeException = SEInternal . bshow
@@ -975,13 +975,10 @@ getWorkItems itemName getIds getItem markFailed =
runExceptT $ handleWrkErr itemName "getIds" getIds >>= mapM (tryE . tryGetItem itemName getItem markFailed)
tryGetItem :: Show i => ByteString -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> i -> ExceptT StoreError IO a
tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchStoreError` \e -> mark >> throwE e
tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchAllErrors` \e -> mark >> throwE e
where
mark = handleWrkErr itemName ("markFailed ID " <> bshow itemId) $ markFailed itemId
catchStoreError :: ExceptT StoreError IO a -> (StoreError -> ExceptT StoreError IO a) -> ExceptT StoreError IO a
catchStoreError = catchAllErrors (SEInternal . bshow)
-- Errors caught by this function will suspend worker as if there is no more work,
handleWrkErr :: ByteString -> ByteString -> IO a -> ExceptT StoreError IO a
handleWrkErr itemName opName action = ExceptT $ first mkError <$> E.try action
+14 -12
View File
@@ -171,28 +171,30 @@ catchAll_ :: IO a -> IO a -> IO a
catchAll_ a = catchAll a . const
{-# INLINE catchAll_ #-}
tryAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m (Either e a)
tryAllErrors err action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . err)
class Show e => AnyError e where fromSomeException :: E.SomeException -> e
tryAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
tryAllErrors action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . fromSomeException)
{-# INLINE tryAllErrors #-}
tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a)
tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err)
tryAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
tryAllErrors' action = runExceptT action `UE.catch` (pure . Left . fromSomeException)
{-# INLINE tryAllErrors' #-}
catchAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
catchAllErrors err action handler = tryAllErrors err action >>= either handler pure
catchAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
catchAllErrors action handler = tryAllErrors action >>= either handler pure
{-# INLINE catchAllErrors #-}
catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a
catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure
catchAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
catchAllErrors' action handler = tryAllErrors' action >>= either handler pure
{-# INLINE catchAllErrors' #-}
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (E.SomeException -> e) -> ExceptT e m a
catchThrow action err = catchAllErrors err action throwE
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (SomeException -> e) -> ExceptT e m a
action `catchThrow` err = ExceptT $ runExceptT action `UE.catch` (pure . Left . err)
{-# INLINE catchThrow #-}
allFinally :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m b -> ExceptT e m a
allFinally err action final = tryAllErrors err action >>= \r -> final >> except r
allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> ExceptT e m a
allFinally action final = tryAllErrors action >>= \r -> final >> except r
{-# INLINE allFinally #-}
eitherToMaybe :: Either a b -> Maybe b
+2 -8
View File
@@ -306,14 +306,8 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
atomically $ takeTMVar endSession
logDebug "Session ended"
catchRCError :: ExceptT RCErrorType IO a -> (RCErrorType -> ExceptT RCErrorType IO a) -> ExceptT RCErrorType IO a
catchRCError = catchAllErrors $ \e -> case fromException e of
Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity
_ -> RCEException $ show e
{-# INLINE catchRCError #-}
putRCError :: ExceptT RCErrorType IO a -> TMVar (Either RCErrorType b) -> ExceptT RCErrorType IO a
a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
a `putRCError` r = a `catchAllErrors` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
sendRCPacket :: Encoding a => TLS p -> a -> ExceptT RCErrorType IO ()
sendRCPacket tls pkt = do
@@ -395,7 +389,7 @@ discoverRCCtrl subscribers pairings =
pure r
where
loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a
loop action = action `catchRCError` \e -> logError (tshow e) >> loop action
loop action = action `catchAllErrors` \e -> logError (tshow e) >> loop action
findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation)
findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do
+8 -1
View File
@@ -19,6 +19,7 @@ import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.Word (Word16)
import qualified Data.X509 as X
import qualified Network.TLS as TLS
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
import Simplex.Messaging.Encoding
@@ -26,7 +27,7 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON)
import Simplex.Messaging.Transport (TLS, TSbChainKeys, TransportPeer (..))
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util (safeDecodeUtf8)
import Simplex.Messaging.Util (AnyError (..), safeDecodeUtf8)
import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange)
import Simplex.Messaging.Version.Internal
import UnliftIO
@@ -50,6 +51,12 @@ data RCErrorType
| RCESyntax {syntaxErr :: String}
deriving (Eq, Show, Exception)
instance AnyError RCErrorType where
fromSomeException e = case fromException e of
Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity
_ -> RCEException $ show e
{-# INLINE fromSomeException #-}
instance StrEncoding RCErrorType where
strEncode = \case
RCEInternal err -> "INTERNAL" <> text err
+15 -38
View File
@@ -45,56 +45,32 @@ utilTests = do
runExceptT (throwTestException `catchError` handleCatch) `shouldThrow` (\(e :: IOError) -> show e == "user error (error)")
describe "tryAllErrors" $ do
it "should return ExceptT error as Left" $
runExceptT (tryAllErrors testErr throwTestError) `shouldReturn` Right (Left (TestError "error"))
runExceptT (tryAllErrors throwTestError) `shouldReturn` Right (Left (TestError "error"))
it "should return SomeException as Left" $
runExceptT (tryAllErrors testErr throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
runExceptT (tryAllErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
it "should return no errors as Right" $
runExceptT (tryAllErrors testErr noErrors) `shouldReturn` Right (Right "no errors")
describe "tryAllErrors specialized as tryTestError" $ do
let tryTestError = tryAllErrors testErr
it "should return ExceptT error as Left" $
runExceptT (tryTestError throwTestError) `shouldReturn` Right (Left (TestError "error"))
it "should return SomeException as Left" $
runExceptT (tryTestError throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
it "should return no errors as Right" $
runExceptT (tryTestError noErrors) `shouldReturn` Right (Right "no errors")
runExceptT (tryAllErrors noErrors) `shouldReturn` Right (Right "no errors")
describe "catchAllErrors" $ do
it "should catch ExceptT error" $
runExceptT (catchAllErrors testErr throwTestError handleCatch) `shouldReturn` Right "caught TestError \"error\""
runExceptT (throwTestError `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\""
it "should catch SomeException" $
runExceptT (catchAllErrors testErr throwTestException handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
runExceptT (throwTestException `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
it "should not throw if there are no errors" $
runExceptT (catchAllErrors testErr noErrors throwError) `shouldReturn` Right "no errors"
describe "catchAllErrors specialized as catchTestError" $ do
let catchTestError = catchAllErrors testErr
it "should catch ExceptT error" $
runExceptT (throwTestError `catchTestError` handleCatch) `shouldReturn` Right "caught TestError \"error\""
it "should catch SomeException" $
runExceptT (throwTestException `catchTestError` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
it "should not throw if there are no errors" $
runExceptT (noErrors `catchTestError` throwError) `shouldReturn` Right "no errors"
runExceptT (noErrors `catchAllErrors` throwError) `shouldReturn` Right "no errors"
describe "catchThrow" $ do
it "should re-throw ExceptT error" $
runExceptT (throwTestError `catchThrow` testErr) `shouldReturn` Left (TestError "error")
runExceptT (throwTestError `catchThrow` fromSomeException) `shouldReturn` Left (TestError "error")
it "should catch SomeException and throw as ExceptT error" $
runExceptT (throwTestException `catchThrow` testErr) `shouldReturn` Left (TestException "user error (error)")
runExceptT (throwTestException `catchThrow` fromSomeException) `shouldReturn` Left (TestException "user error (error)")
it "should not throw if there are no exceptions" $
runExceptT (noErrors `catchThrow` testErr) `shouldReturn` Right "no errors"
runExceptT (noErrors `catchThrow` fromSomeException) `shouldReturn` Right "no errors"
describe "allFinally should run final action" $ do
it "then throw ExceptT error" $ withFinal $ \final ->
runExceptT (allFinally testErr throwTestError final) `shouldReturn` Left (TestError "error")
runExceptT (throwTestError `allFinally` final) `shouldReturn` Left (TestError "error")
it "then throw SomeException as ExceptT error" $ withFinal $ \final ->
runExceptT (allFinally testErr throwTestException final) `shouldReturn` Left (TestException "user error (error)")
runExceptT (throwTestException `allFinally` final) `shouldReturn` Left (TestException "user error (error)")
it "and should not throw if there are no exceptions" $ withFinal $ \final ->
runExceptT (allFinally testErr noErrors final) `shouldReturn` Right "no errors"
describe "allFinally specialized as testFinally should run final action" $ do
let testFinally = allFinally testErr
it "then throw ExceptT error" $ withFinal $ \final ->
runExceptT (throwTestError `testFinally` final) `shouldReturn` Left (TestError "error")
it "then throw SomeException as ExceptT error" $ withFinal $ \final ->
runExceptT (throwTestException `testFinally` final) `shouldReturn` Left (TestException "user error (error)")
it "and should not throw if there are no exceptions" $ withFinal $ \final ->
runExceptT (noErrors `testFinally` final) `shouldReturn` Right "no errors"
runExceptT (noErrors `allFinally` final) `shouldReturn` Right "no errors"
where
throwTestError :: ExceptT TestError IO String
throwTestError = throwError $ TestError "error"
@@ -102,8 +78,6 @@ utilTests = do
throwTestException = liftIO $ throwIO $ userError "error"
noErrors :: ExceptT TestError IO String
noErrors = pure "no errors"
testErr :: SomeException -> TestError
testErr = TestException . show
handleCatch :: TestError -> ExceptT TestError IO String
handleCatch e = pure $ "caught " <> show e
handleException :: SomeException -> ExceptT TestError IO String
@@ -119,3 +93,6 @@ data TestError = TestError String | TestException String
deriving (Eq, Show)
instance Exception TestError
instance AnyError TestError where
fromSomeException = TestException . show