diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index cfee308bc..c3723e7a9 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -75,7 +75,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String (strDecode, strEncode) import Simplex.Messaging.Protocol (ProtocolServer, ProtocolType (..), XFTPServer) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM) +import Simplex.Messaging.Util (allFinally, catchAll_, catchAllErrors, liftError, tshow, unlessM, whenM) import System.FilePath (takeFileName, ()) import UnliftIO import UnliftIO.Directory @@ -198,10 +198,10 @@ runXFTPRcvWorker c srv Worker {doWork} = do liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv downloadAttempts downloadFileChunk fc replica approvedRelays - `catchAgentError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e + `catchAllErrors` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e where retryLoop loop e replicaDelay = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (serverHostError e) $ notify c (fromMaybe rcvFileEntityId redirectEntityId_) (RFWARN e) liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay @@ -280,7 +280,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do runXFTPOperation AgentConfig {rcvFilesTTL} = withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $ \f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath, redirect} -> - decryptFile f `catchAgentError` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath + decryptFile f `catchAllErrors` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath decryptFile :: RcvFile -> AM () decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do let CryptoFile savePath cfArgs = saveFile @@ -307,7 +307,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do liftIO $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) -- proceed with redirect - yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath) + yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `allFinally` (lift $ toFSFilePath fsSavePath >>= removePath) next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of -- TODO switch to another error constructor Left _ -> throwE . FILE $ REDIRECT "decode error" @@ -399,7 +399,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do runXFTPOperation cfg@AgentConfig {sndFilesTTL} = withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $ \f@SndFile {sndFileId, sndFileEntityId, prefixPath} -> - prepareFile cfg f `catchAgentError` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath + prepareFile cfg f `catchAllErrors` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath prepareFile :: AgentConfig -> SndFile -> AM () prepareFile _ SndFile {prefixPath = Nothing} = throwE $ INTERNAL "no prefix path" @@ -468,11 +468,11 @@ runXFTPSndPrepareWorker c Worker {doWork} = do liftIO $ waitForUserNetwork c let triedAllSrvs = n > userSrvCount createWithNextSrv triedHosts - `catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e + `catchAllErrors` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e where -- we don't do closeXFTPServerClient here to not risk closing connection for concurrent chunk upload retryLoop loop triedAllSrvs e = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e liftIO $ assertAgentForeground c loop @@ -508,10 +508,10 @@ runXFTPSndWorker c srv Worker {doWork} = do liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv uploadAttempts uploadFileChunk cfg fc replica - `catchAgentError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e + `catchAllErrors` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e where retryLoop loop e replicaDelay = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay @@ -681,10 +681,10 @@ runXFTPDelWorker c srv Worker {doWork} = do liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv deleteAttempts deleteChunkReplica - `catchAgentError` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e + `catchAllErrors` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e where retryLoop loop e replicaDelay = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (serverHostError e) $ notify c "" $ SFWARN e liftIO $ closeXFTPServerClient c userId server chunkDigest withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 60ea5d69c..982c3099a 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -284,13 +284,13 @@ saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats, ntfServ xss <- mapM (liftIO . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats nss <- mapM (liftIO . getAgentNtfServerStats) =<< readTVarIO ntfServersStats let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss} - tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case + tryAllErrors' (withStore' c (`updateServersStats` stats)) >>= \case Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) Right () -> pure () restoreServersStats :: AgentClient -> AM' () restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt} = do - tryAgentError' (withStore c getServersStats) >>= \case + tryAllErrors' (withStore c getServersStats) >>= \case Left e -> atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) Right (startedAt, Nothing) -> atomically $ writeTVar srvStatsStartedAt startedAt Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}) -> do @@ -774,7 +774,7 @@ acceptContactAsync' :: AgentClient -> UserId -> ACorrId -> Bool -> InvitationId acceptContactAsync' c userId corrId enableNtfs invId ownConnInfo pqSupport subMode = do Invitation {connReq} <- withStore c $ \db -> getInvitation db "acceptContactAsync'" invId withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do + joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAllErrors` \err -> do withStore' c (`unacceptInvitation` invId) throwE err @@ -961,7 +961,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey createRcvQueue nonce_ qd e2eKeys = do AgentConfig {smpClientVRange = vr} <- asks config ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing - (rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAgentError` \e -> liftIO (print e) >> throwE e + (rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e atomically $ incSMPServerStat c userId srv connCreated rq' <- withStore c $ \db -> updateNewConnRcv db connId rq lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId @@ -1351,7 +1351,7 @@ subscribeClientService' = undefined -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) -getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage +getConnectionMessages' c = mapM $ tryAllErrors' . getConnectionMessage where getConnectionMessage :: ConnMsgReq -> AM (Maybe SMPMsgMeta) getConnectionMessage (ConnMsgReq connId dbQueueId msgTs_) = do @@ -1363,7 +1363,7 @@ getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage ContactConnection _ rq -> pure rq SndConnection _ _ -> throwE $ CONN SIMPLEX "getConnectionMessage" NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection" - msg_ <- getQueueMessage c rq `catchAgentError` \e -> atomically (releaseGetLock c rq) >> throwError e + msg_ <- getQueueMessage c rq `catchAllErrors` \e -> atomically (releaseGetLock c rq) >> throwError e when (isNothing msg_) $ do atomically $ releaseGetLock c rq forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBEntityId dbQueueId) msgTs @@ -1534,7 +1534,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do pure CCCompleted -- duplex connection is matched to handle SKEY retries DuplexConnection cData _ (sq :| _) -> do - tryAgentError (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case + tryAllErrors (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case Right () -> pure CCCompleted Left e | temporaryOrHostError e && Just server /= server_ -> do @@ -1621,7 +1621,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do tryMoveableCommand action = withRetryInterval ri $ \_ loop -> do liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c - tryAgentError action >>= \case + tryAllErrors action >>= \case Left e | temporaryOrHostError e -> retrySndOp c loop | otherwise -> cmdError e @@ -2065,7 +2065,7 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM () ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do atomically $ incSMPServerStat c userId server ackAttempts - tryAgentError (sendAck c rq srvMsgId) >>= \case + tryAllErrors (sendAck c rq srvMsgId) >>= \case Right _ -> sendMsgNtf ackMsgs Left (SMP _ SMP.NO_MSG) -> sendMsgNtf ackNoMsgErrs Left e -> do @@ -2076,7 +2076,7 @@ ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do atomically $ incSMPServerStat c userId server stat whenM (liftIO $ hasGetLock c rq) $ do atomically $ releaseGetLock c rq - brokerTs_ <- eitherToMaybe <$> tryAgentError (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId) + brokerTs_ <- eitherToMaybe <$> tryAllErrors (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId) atomically $ writeTBQueue (subQ c) ("", connId, AEvt SAEConn $ MSGNTF srvMsgId brokerTs_) -- | Suspend SMP agent connection (OFF command) in Reader monad @@ -2307,7 +2307,7 @@ registerNtfToken' c nm suppliedDeviceToken suppliedNtfMode = replaceToken :: NtfTokenId -> AM NtfTknStatus replaceToken tknId = do ns <- asks ntfSupervisor - tryReplace ns `catchAgentError` \e -> + tryReplace ns `catchAllErrors` \e -> if temporaryOrHostError e then throwE e else do @@ -2564,7 +2564,7 @@ cleanupManager c@AgentClient {subQ} = do where run :: forall e. AEntityI e => (AgentErrorType -> AEvent e) -> AM () -> AM' () run err a = do - waitActive . runExceptT $ a `catchAgentError` (notify "" . err) + waitActive . runExceptT $ a `catchAllErrors` (notify "" . err) step <- asks $ cleanupStepInterval . config liftIO $ threadDelay step -- we are catching it to avoid CRITICAL errors in tests when this is the only remaining handle to active @@ -2578,33 +2578,33 @@ cleanupManager c@AgentClient {subQ} = do deleteRcvFilesExpired = do rcvFilesTTL <- asks $ rcvFilesTTL . config rcvExpired <- withStore' c (`getRcvFilesExpired` rcvFilesTTL) - forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do + forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`deleteRcvFile'` dbId) deleteRcvFilesDeleted = do rcvDeleted <- withStore' c getCleanupRcvFilesDeleted - forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do + forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`deleteRcvFile'` dbId) deleteRcvFilesTmpPaths = do rcvTmpPaths <- withStore' c getCleanupRcvFilesTmpPaths - forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do + forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`updateRcvFileNoTmpPath` dbId) deleteSndFilesExpired = do sndFilesTTL <- asks $ sndFilesTTL . config sndExpired <- withStore' c (`getSndFilesExpired` sndFilesTTL) - forM_ sndExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do + forM_ sndExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do lift . forM_ p $ removePath <=< toFSFilePath withStore' c (`deleteSndFile'` dbId) deleteSndFilesDeleted = do sndDeleted <- withStore' c getCleanupSndFilesDeleted - forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do + forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do lift . forM_ p $ removePath <=< toFSFilePath withStore' c (`deleteSndFile'` dbId) deleteSndFilesPrefixPaths = do sndPrefixPaths <- withStore' c getCleanupSndFilesPrefixPaths - forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do + forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`updateSndFileNoPrefixPath` dbId) deleteExpiredReplicasForDeletion = do @@ -2652,10 +2652,10 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId where withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' () withRcvConn rId a = do - tryAgentError' (withStore c $ \db -> getRcvConn db srv rId) >>= \case + tryAllErrors' (withStore c $ \db -> getRcvConn db srv rId) >>= \case Left e -> notify' "" (ERR e) Right (rq@RcvQueue {connId}, SomeConn _ conn) -> - tryAgentError' (a rq conn) >>= \case + tryAllErrors' (a rq conn) >>= \case Left e -> notify' connId (ERR e) Right () -> pure () processSubOk :: RcvQueue -> TVar [ConnId] -> AM () @@ -2739,7 +2739,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId _ -> pure () let encryptedMsgHash = C.sha256Hash encAgentMessage g <- asks random - tryAgentError (agentClientMsg g encryptedMsgHash) >>= \case + tryAllErrors (agentClientMsg g encryptedMsgHash) >>= \case Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do conn'' <- resetRatchetSync case aMessage of @@ -2848,7 +2848,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId ackDel :: InternalId -> AM ACKd ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd handleNotifyAck :: AM ACKd -> AM ACKd - handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack + handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack SMP.END -> atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False)) >>= notifyEnd @@ -3006,7 +3006,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do logServer "<--" c srv rId $ "MSG :" <> logSecret' srvMsgId - rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing + rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAllErrors` \e -> notify (ERR e) $> Nothing case L.nonEmpty . catMaybes $ L.toList rs of Just rs' -> notify (RCVD msgMeta rs') $> ACKPending Nothing -> ack diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 39b3534c0..1b3d8c7ef 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -130,6 +130,7 @@ module Simplex.Messaging.Agent.Client hasWorkToDo, hasWorkToDo', withWork, + withWork_, withWorkItems, agentOperations, agentOperationBracket, @@ -371,12 +372,12 @@ data SMPConnectedClient = SMPConnectedClient type ProxiedRelayVar = SessionVar (Either AgentErrorType ProxiedRelay) -getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker +getAgentWorker :: (Ord k, Show k, AnyError e, MonadUnliftIO m) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT e m ()) -> m Worker getAgentWorker = getAgentWorker' id pure {-# INLINE getAgentWorker #-} -getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a -getAgentWorker' toW fromW name hasWork c key ws work = do +getAgentWorker' :: forall a k e m. (Ord k, Show k, AnyError e, MonadUnliftIO m) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT e m ()) -> m a +getAgentWorker' toW fromW name hasWork c@AgentClient {agentEnv} key ws work = do atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w where getWorker = TM.lookup key ws @@ -389,12 +390,12 @@ getAgentWorker' toW fromW name hasWork c key ws work = do | otherwise = pure w runWorker w = runWorkerAsync (toW w) runWork where - runWork :: AM' () - runWork = tryAgentError' (work w) >>= restartOrDelete - restartOrDelete :: Either AgentErrorType () -> AM' () + runWork :: m () + runWork = tryAllErrors' (work w) >>= restartOrDelete + restartOrDelete :: Either e () -> m () restartOrDelete e_ = do t <- liftIO getSystemTime - maxRestarts <- asks $ maxWorkerRestartsPerMin . config + let maxRestarts = maxWorkerRestartsPerMin $ config agentEnv -- worker may terminate because it was deleted from the map (getWorker returns Nothing), then it won't restart restart <- atomically $ getWorker >>= maybe (pure False) (shouldRestart e_ (toW w) t maxRestarts) when restart runWork @@ -431,7 +432,7 @@ newWorker c = do restarts <- newTVar $ RestartCount 0 0 pure Worker {workerId, doWork, action, restarts} -runWorkerAsync :: Worker -> AM' () -> AM' () +runWorkerAsync :: MonadUnliftIO m => Worker -> m () -> m () runWorkerAsync Worker {action} work = E.bracket (atomically $ takeTMVar action) -- get current action, locking to avoid race conditions @@ -665,7 +666,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq pure (clnt, sess) newProxiedRelay :: SMPConnectedClient -> Maybe SMP.BasicAuth -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay) newProxiedRelay (SMPConnectedClient smp prs) proxyAuth rv = - tryAgentError (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case + tryAllErrors (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case Right sess -> do atomically $ putTMVar (sessionVar rv) (Right sess) pure $ Right sess @@ -688,7 +689,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq smpConnectClient :: AgentClient -> NetworkRequestMode -> SMPTransportSession -> TMap SMPServer ProxiedRelayVar -> SMPClientVar -> AM SMPConnectedClient smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm tSess@(_, srv, _) prs v = newProtocolClient c tSess smpClients connectClient v - `catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwE e + `catchAllErrors` \e -> lift (resubscribeSMPSession c tSess) >> throwE e where connectClient :: SMPClientVar -> AM SMPConnectedClient connectClient v' = do @@ -866,7 +867,7 @@ newProtocolClient :: ClientVar msg -> AM (Client msg) newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = - tryAgentError (connectClient v) >>= \case + tryAllErrors (connectClient v) >>= \case Right client -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")" atomically $ putTMVar (sessionVar v) (Right client) @@ -1027,7 +1028,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> NetworkRequestMode -> TransportSession msg -> (Client msg -> AM a) -> AM a withClient_ c nm tSess@(_, srv, _) action = do cl <- getProtocolServerClient c nm tSess - action cl `catchAgentError` logServerError + action cl `catchAllErrors` logServerError where logServerError :: AgentErrorType -> AM a logServerError e = do @@ -1040,7 +1041,7 @@ withProxySession c nm proxySrv_ destSess@(_, destSrv, _) entId cmdStr action = d logServer ("--> " <> proxySrv cl <> " >") c destSrv entId cmdStr case sess_ of Right sess -> do - r <- action (cl, sess) `catchAgentError` logServerError cl + r <- action (cl, sess) `catchAllErrors` logServerError cl logServer ("<-- " <> proxySrv cl <> " <") c destSrv entId "OK" pure r Left e -> logServerError cl e @@ -1117,7 +1118,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn unknownServer = liftIO $ maybe True (\srvs -> all (`S.notMember` knownHosts srvs) destHosts) <$> TM.lookupIO userId (smpServers c) sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer, a) sendViaProxy proxySrv_ destSess@(_, _, connId_) = do - r <- tryAgentError . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do + r <- tryAllErrors . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do r' <- liftClient SMP (clientServer smp) $ sendCmdViaProxy smp proxySess let proxySrv = protocolClientServer' smp case r' of @@ -1164,7 +1165,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn | otherwise -> throwE e sendDirectly tSess = withLogClient_ c nm tSess (unEntityId entId) ("SEND " <> cmdStr) $ \(SMPConnectedClient smp _) -> do - tryAgentError (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case + tryAllErrors (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case Right r -> r <$ atomically (incSMPServerStat c userId destSrv sentDirect) Left e -> throwE e @@ -1562,7 +1563,7 @@ sendTSessionBatches statCmd toRQ action c nm qs = in M.alter (Just . maybe [q] (q <|)) tSess m sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r) sendClientBatch (tSess@(_, srv, _), qs') = - tryAgentError' (getSMPServerClient c nm tSess) >>= \case + tryAllErrors' (getSMPServerClient c nm tSess) >>= \case Left e -> pure $ L.map (,Left e) qs' Right (SMPConnectedClient smp _) -> liftIO $ do logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd @@ -1867,7 +1868,7 @@ withNtfBatch :: AM' (NonEmpty (Either AgentErrorType r)) withNtfBatch cmdStr action c NtfToken {ntfServer, ntfPrivKey} subs = do let tSess = (0, ntfServer, Nothing) - tryAgentError' (getNtfServerClient c NRMBackground tSess) >>= \case + tryAllErrors' (getNtfServerClient c NRMBackground tSess) >>= \case Left e -> pure $ L.map (\_ -> Left e) subs Right ntf -> liftIO $ do logServer' "-->" c ntfServer (bshow (length subs) <> " subscriptions") cmdStr @@ -1968,8 +1969,12 @@ waitForWork = void . atomically . readTMVar {-# INLINE waitForWork #-} withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM () -withWork c doWork getWork action = - withStore' c getWork >>= \case +withWork c doWork = withWork_ c doWork . withStore' c +{-# INLINE withWork #-} + +withWork_ :: MonadIO m => AgentClient -> TMVar () -> ExceptT e m (Either StoreError (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m () +withWork_ c doWork getWork action = + getWork >>= \case Right (Just r) -> action r Right Nothing -> noWork -- worker is stopped here (noWork) because the next iteration is likely to produce the same result @@ -1979,9 +1984,9 @@ withWork c doWork getWork action = noWork = liftIO $ noWorkToDo doWork notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e) -withWorkItems :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError [Either StoreError a])) -> (NonEmpty a -> AM ()) -> AM () +withWorkItems :: MonadIO m => AgentClient -> TMVar () -> ExceptT e m (Either StoreError [Either StoreError a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m () withWorkItems c doWork getWork action = do - withStore' c getWork >>= \case + getWork >>= \case Right [] -> noWork Right rs -> do let (errs, items) = partitionEithers rs diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 0c10d8cd4..393f07a93 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -27,11 +27,6 @@ module Simplex.Messaging.Agent.Env.SQLite serverHosts, defaultAgentConfig, defaultReconnectInterval, - tryAgentError, - tryAgentError', - catchAgentError, - catchAgentError', - agentFinally, Env (..), newSMPAgentEnv, createAgentStore, @@ -45,7 +40,6 @@ module Simplex.Messaging.Agent.Env.SQLite where import Control.Concurrent (ThreadId) -import Control.Exception (BlockedIndefinitelyOnSTM (..), SomeException, fromException) import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader @@ -83,7 +77,6 @@ import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors') import System.Mem.Weak (Weak) import System.Random (StdGen, newStdGen) import UnliftIO.STM @@ -312,33 +305,6 @@ newXFTPAgent = do xftpDelWorkers <- TM.emptyIO pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} -tryAgentError :: AM a -> AM (Either AgentErrorType a) -tryAgentError = tryAllErrors mkInternal -{-# INLINE tryAgentError #-} - --- unlike runExceptT, this ensures we catch IO exceptions as well -tryAgentError' :: AM a -> AM' (Either AgentErrorType a) -tryAgentError' = tryAllErrors' mkInternal -{-# INLINE tryAgentError' #-} - -catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a -catchAgentError = catchAllErrors mkInternal -{-# INLINE catchAgentError #-} - -catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a -catchAgentError' = catchAllErrors' mkInternal -{-# INLINE catchAgentError' #-} - -agentFinally :: AM a -> AM b -> AM a -agentFinally = allFinally mkInternal -{-# INLINE agentFinally #-} - -mkInternal :: SomeException -> AgentErrorType -mkInternal e = case fromException e of - Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction" - _ -> INTERNAL $ show e -{-# INLINE mkInternal #-} - data Worker = Worker { workerId :: Int, doWork :: TMVar (), diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 7546c03dd..85fa45c49 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -52,7 +52,7 @@ import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Protocol (NtfServer, sameSrvAddr) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Util (diffToMicroseconds, threadDelay', tshow, whenM) +import Simplex.Messaging.Util (catchAllErrors, diffToMicroseconds, threadDelay', tryAllErrors, tshow, whenM) import System.Random (randomR) import UnliftIO import UnliftIO.Concurrent (forkIO) @@ -217,7 +217,7 @@ runNtfWorker c srv Worker {doWork} = runNtfOperation :: AM () runNtfOperation = do ntfBatchSize <- asks $ ntfBatchSize . config - withWorkItems c doWork (\db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do + withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do logInfo $ "runNtfWorker - length nextSubs = " <> tshow (length nextSubs) currTs <- liftIO getCurrentTime let (creates, checks, deletes, rotates) = splitActions currTs nextSubs @@ -357,7 +357,7 @@ runNtfWorker c srv Worker {doWork} = runCatching :: (NtfSubscription -> AM (Maybe NtfSubscription)) -> NtfSubscription -> AM' (Maybe NtfSubscription) runCatching action sub@NtfSubscription {connId} = fromRight Nothing - <$> runExceptT (action sub `catchAgentError` \e -> workerInternalError c connId (show e) $> Nothing) + <$> runExceptT (action sub `catchAllErrors` \e -> workerInternalError c connId (show e) $> Nothing) -- deleteNtfSub is only used in NSADelete and NSARotate, so also deprecated deleteNtfSub :: NtfSubscription -> AM () -> AM (Maybe NtfSubscription) deleteNtfSub sub@NtfSubscription {userId, ntfSubId} continue = case ntfSubId of @@ -365,7 +365,7 @@ runNtfWorker c srv Worker {doWork} = lift getNtfToken >>= \case Just tkn@NtfToken {ntfServer} -> do atomically $ incNtfServerStat c userId ntfServer ntfDelAttempts - tryAgentError (agentNtfDeleteSubscription c nSubId tkn) >>= \case + tryAllErrors (agentNtfDeleteSubscription c nSubId tkn) >>= \case Right _ -> do atomically $ incNtfServerStat c userId ntfServer ntfDeleted continue' @@ -385,7 +385,7 @@ runNtfSMPWorker c srv Worker {doWork} = forever $ do runNtfSMPOperation :: AM () runNtfSMPOperation = do ntfBatchSize <- asks $ ntfBatchSize . config - withWorkItems c doWork (\db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do + withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do logInfo $ "runNtfSMPWorker - length nextSubs = " <> tshow (length nextSubs) let (creates, deletes) = splitActions nextSubs retrySubActions c creates createNotifierKeys @@ -567,7 +567,7 @@ runNtfTknDelWorker c srv Worker {doWork} = withRetryInterval ri $ \_ loop -> do liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c - processTknToDelete nextTknToDelete `catchAgentError` retryTmpError loop nextTknToDelete + processTknToDelete nextTknToDelete `catchAllErrors` retryTmpError loop nextTknToDelete retryTmpError :: AM () -> NtfTokenToDelete -> AgentErrorType -> AM () retryTmpError loop (tknDbId, _, _) e = do logError $ "ntf tkn del error: " <> tshow e diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index f1d5f2ec8..d4d302df7 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -173,6 +173,7 @@ module Simplex.Messaging.Agent.Protocol where import Control.Applicative (optional, (<|>)) +import Control.Exception (BlockedIndefinitelyOnSTM (..), fromException) import Data.Aeson (FromJSON (..), ToJSON (..), Value (..), (.:), (.:?)) import qualified Data.Aeson as J' import qualified Data.Aeson.Encoding as JE @@ -1866,6 +1867,12 @@ data AgentErrorType INACTIVE deriving (Eq, Show, Exception) +instance AnyError AgentErrorType where + fromSomeException e = case fromException e of + Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction" + _ -> INTERNAL $ show e + {-# INLINE fromSomeException #-} + -- | SMP agent protocol command or response error. data CommandErrorType = -- | command is prohibited in this context diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 71e35e7ec..8dee07037 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -29,6 +29,7 @@ import Data.Time (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval (RI2State) +import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.Store.Common import Simplex.Messaging.Agent.Store.Interface (createDBStore) import Simplex.Messaging.Agent.Store.Migrations.App (appMigrations) @@ -52,7 +53,7 @@ import Simplex.Messaging.Protocol VersionSMPC, ) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Agent.Store.Entity +import Simplex.Messaging.Util (AnyError (..), bshow) createStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore) createStore dbOpts = createDBStore dbOpts appMigrations @@ -696,3 +697,6 @@ data StoreError | -- | Servers stats not found. SEServersStatsNotFound deriving (Eq, Show, Exception) + +instance AnyError StoreError where + fromSomeException = SEInternal . bshow diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index e10f48c8f..59175e942 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -975,13 +975,10 @@ getWorkItems itemName getIds getItem markFailed = runExceptT $ handleWrkErr itemName "getIds" getIds >>= mapM (tryE . tryGetItem itemName getItem markFailed) tryGetItem :: Show i => ByteString -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> i -> ExceptT StoreError IO a -tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchStoreError` \e -> mark >> throwE e +tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchAllErrors` \e -> mark >> throwE e where mark = handleWrkErr itemName ("markFailed ID " <> bshow itemId) $ markFailed itemId -catchStoreError :: ExceptT StoreError IO a -> (StoreError -> ExceptT StoreError IO a) -> ExceptT StoreError IO a -catchStoreError = catchAllErrors (SEInternal . bshow) - -- Errors caught by this function will suspend worker as if there is no more work, handleWrkErr :: ByteString -> ByteString -> IO a -> ExceptT StoreError IO a handleWrkErr itemName opName action = ExceptT $ first mkError <$> E.try action diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 1a7dedef5..f93119b3c 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -171,28 +171,30 @@ catchAll_ :: IO a -> IO a -> IO a catchAll_ a = catchAll a . const {-# INLINE catchAll_ #-} -tryAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m (Either e a) -tryAllErrors err action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . err) +class Show e => AnyError e where fromSomeException :: E.SomeException -> e + +tryAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a) +tryAllErrors action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . fromSomeException) {-# INLINE tryAllErrors #-} -tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a) -tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err) +tryAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a) +tryAllErrors' action = runExceptT action `UE.catch` (pure . Left . fromSomeException) {-# INLINE tryAllErrors' #-} -catchAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a -catchAllErrors err action handler = tryAllErrors err action >>= either handler pure +catchAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a +catchAllErrors action handler = tryAllErrors action >>= either handler pure {-# INLINE catchAllErrors #-} -catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a -catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure +catchAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a +catchAllErrors' action handler = tryAllErrors' action >>= either handler pure {-# INLINE catchAllErrors' #-} -catchThrow :: MonadUnliftIO m => ExceptT e m a -> (E.SomeException -> e) -> ExceptT e m a -catchThrow action err = catchAllErrors err action throwE +catchThrow :: MonadUnliftIO m => ExceptT e m a -> (SomeException -> e) -> ExceptT e m a +action `catchThrow` err = ExceptT $ runExceptT action `UE.catch` (pure . Left . err) {-# INLINE catchThrow #-} -allFinally :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m b -> ExceptT e m a -allFinally err action final = tryAllErrors err action >>= \r -> final >> except r +allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> ExceptT e m a +allFinally action final = tryAllErrors action >>= \r -> final >> except r {-# INLINE allFinally #-} eitherToMaybe :: Either a b -> Maybe b diff --git a/src/Simplex/RemoteControl/Client.hs b/src/Simplex/RemoteControl/Client.hs index bde72fb23..a9970c273 100644 --- a/src/Simplex/RemoteControl/Client.hs +++ b/src/Simplex/RemoteControl/Client.hs @@ -306,14 +306,8 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca, atomically $ takeTMVar endSession logDebug "Session ended" -catchRCError :: ExceptT RCErrorType IO a -> (RCErrorType -> ExceptT RCErrorType IO a) -> ExceptT RCErrorType IO a -catchRCError = catchAllErrors $ \e -> case fromException e of - Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity - _ -> RCEException $ show e -{-# INLINE catchRCError #-} - putRCError :: ExceptT RCErrorType IO a -> TMVar (Either RCErrorType b) -> ExceptT RCErrorType IO a -a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e +a `putRCError` r = a `catchAllErrors` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e sendRCPacket :: Encoding a => TLS p -> a -> ExceptT RCErrorType IO () sendRCPacket tls pkt = do @@ -395,7 +389,7 @@ discoverRCCtrl subscribers pairings = pure r where loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a - loop action = action `catchRCError` \e -> logError (tshow e) >> loop action + loop action = action `catchAllErrors` \e -> logError (tshow e) >> loop action findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation) findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do diff --git a/src/Simplex/RemoteControl/Types.hs b/src/Simplex/RemoteControl/Types.hs index 7b8638e67..93f0c92c7 100644 --- a/src/Simplex/RemoteControl/Types.hs +++ b/src/Simplex/RemoteControl/Types.hs @@ -19,6 +19,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Word (Word16) import qualified Data.X509 as X +import qualified Network.TLS as TLS import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Encoding @@ -26,7 +27,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport (TLS, TSbChainKeys, TransportPeer (..)) import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (safeDecodeUtf8) +import Simplex.Messaging.Util (AnyError (..), safeDecodeUtf8) import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange) import Simplex.Messaging.Version.Internal import UnliftIO @@ -50,6 +51,12 @@ data RCErrorType | RCESyntax {syntaxErr :: String} deriving (Eq, Show, Exception) +instance AnyError RCErrorType where + fromSomeException e = case fromException e of + Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity + _ -> RCEException $ show e + {-# INLINE fromSomeException #-} + instance StrEncoding RCErrorType where strEncode = \case RCEInternal err -> "INTERNAL" <> text err diff --git a/tests/CoreTests/UtilTests.hs b/tests/CoreTests/UtilTests.hs index 4159f25e1..946902358 100644 --- a/tests/CoreTests/UtilTests.hs +++ b/tests/CoreTests/UtilTests.hs @@ -45,56 +45,32 @@ utilTests = do runExceptT (throwTestException `catchError` handleCatch) `shouldThrow` (\(e :: IOError) -> show e == "user error (error)") describe "tryAllErrors" $ do it "should return ExceptT error as Left" $ - runExceptT (tryAllErrors testErr throwTestError) `shouldReturn` Right (Left (TestError "error")) + runExceptT (tryAllErrors throwTestError) `shouldReturn` Right (Left (TestError "error")) it "should return SomeException as Left" $ - runExceptT (tryAllErrors testErr throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) + runExceptT (tryAllErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) it "should return no errors as Right" $ - runExceptT (tryAllErrors testErr noErrors) `shouldReturn` Right (Right "no errors") - describe "tryAllErrors specialized as tryTestError" $ do - let tryTestError = tryAllErrors testErr - it "should return ExceptT error as Left" $ - runExceptT (tryTestError throwTestError) `shouldReturn` Right (Left (TestError "error")) - it "should return SomeException as Left" $ - runExceptT (tryTestError throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) - it "should return no errors as Right" $ - runExceptT (tryTestError noErrors) `shouldReturn` Right (Right "no errors") + runExceptT (tryAllErrors noErrors) `shouldReturn` Right (Right "no errors") describe "catchAllErrors" $ do it "should catch ExceptT error" $ - runExceptT (catchAllErrors testErr throwTestError handleCatch) `shouldReturn` Right "caught TestError \"error\"" + runExceptT (throwTestError `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\"" it "should catch SomeException" $ - runExceptT (catchAllErrors testErr throwTestException handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\"" + runExceptT (throwTestException `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\"" it "should not throw if there are no errors" $ - runExceptT (catchAllErrors testErr noErrors throwError) `shouldReturn` Right "no errors" - describe "catchAllErrors specialized as catchTestError" $ do - let catchTestError = catchAllErrors testErr - it "should catch ExceptT error" $ - runExceptT (throwTestError `catchTestError` handleCatch) `shouldReturn` Right "caught TestError \"error\"" - it "should catch SomeException" $ - runExceptT (throwTestException `catchTestError` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\"" - it "should not throw if there are no errors" $ - runExceptT (noErrors `catchTestError` throwError) `shouldReturn` Right "no errors" + runExceptT (noErrors `catchAllErrors` throwError) `shouldReturn` Right "no errors" describe "catchThrow" $ do it "should re-throw ExceptT error" $ - runExceptT (throwTestError `catchThrow` testErr) `shouldReturn` Left (TestError "error") + runExceptT (throwTestError `catchThrow` fromSomeException) `shouldReturn` Left (TestError "error") it "should catch SomeException and throw as ExceptT error" $ - runExceptT (throwTestException `catchThrow` testErr) `shouldReturn` Left (TestException "user error (error)") + runExceptT (throwTestException `catchThrow` fromSomeException) `shouldReturn` Left (TestException "user error (error)") it "should not throw if there are no exceptions" $ - runExceptT (noErrors `catchThrow` testErr) `shouldReturn` Right "no errors" + runExceptT (noErrors `catchThrow` fromSomeException) `shouldReturn` Right "no errors" describe "allFinally should run final action" $ do it "then throw ExceptT error" $ withFinal $ \final -> - runExceptT (allFinally testErr throwTestError final) `shouldReturn` Left (TestError "error") + runExceptT (throwTestError `allFinally` final) `shouldReturn` Left (TestError "error") it "then throw SomeException as ExceptT error" $ withFinal $ \final -> - runExceptT (allFinally testErr throwTestException final) `shouldReturn` Left (TestException "user error (error)") + runExceptT (throwTestException `allFinally` final) `shouldReturn` Left (TestException "user error (error)") it "and should not throw if there are no exceptions" $ withFinal $ \final -> - runExceptT (allFinally testErr noErrors final) `shouldReturn` Right "no errors" - describe "allFinally specialized as testFinally should run final action" $ do - let testFinally = allFinally testErr - it "then throw ExceptT error" $ withFinal $ \final -> - runExceptT (throwTestError `testFinally` final) `shouldReturn` Left (TestError "error") - it "then throw SomeException as ExceptT error" $ withFinal $ \final -> - runExceptT (throwTestException `testFinally` final) `shouldReturn` Left (TestException "user error (error)") - it "and should not throw if there are no exceptions" $ withFinal $ \final -> - runExceptT (noErrors `testFinally` final) `shouldReturn` Right "no errors" + runExceptT (noErrors `allFinally` final) `shouldReturn` Right "no errors" where throwTestError :: ExceptT TestError IO String throwTestError = throwError $ TestError "error" @@ -102,8 +78,6 @@ utilTests = do throwTestException = liftIO $ throwIO $ userError "error" noErrors :: ExceptT TestError IO String noErrors = pure "no errors" - testErr :: SomeException -> TestError - testErr = TestException . show handleCatch :: TestError -> ExceptT TestError IO String handleCatch e = pure $ "caught " <> show e handleException :: SomeException -> ExceptT TestError IO String @@ -119,3 +93,6 @@ data TestError = TestError String | TestException String deriving (Eq, Show) instance Exception TestError + +instance AnyError TestError where + fromSomeException = TestException . show