Files
simplexmq/src/Simplex/Messaging/Util.hs
Evgeny @ SimpleX Chat e5dbe97e1d spec references in code
2026-03-11 09:06:05 +00:00

432 lines
14 KiB
Haskell

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MonadComprehensions #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- spec: spec/modules/Simplex/Messaging/Util.md
module Simplex.Messaging.Util
( AnyError (..),
(<$?>),
($>>),
(<$$),
(<$$>),
raceAny_,
bshow,
tshow,
maybeWord,
liftError,
liftError',
liftEitherWith,
ifM,
whenM,
unlessM,
anyM,
($>>=),
mapME,
bindRight,
forME,
mapAccumLM,
packZipWith,
tryWriteTBQueue,
catchAll,
catchAll_,
tryAllErrors,
tryAllErrors',
catchAllErrors,
catchAllErrors',
catchThrow,
allFinally,
isOwnException,
isAsyncCancellation,
catchOwn',
catchOwn,
tryAllOwnErrors,
tryAllOwnErrors',
catchAllOwnErrors,
catchAllOwnErrors',
eitherToMaybe,
listToEither,
firstRow,
maybeFirstRow,
maybeFirstRow',
firstRow',
groupOn,
groupOn',
eqOn,
groupAllOn,
toChunks,
safeDecodeUtf8,
timeoutThrow,
threadDelay',
diffToMicroseconds,
diffToMilliseconds,
labelMyThread,
atomicModifyIORef'_,
encodeJSON,
decodeJSON,
traverseWithKey_,
) where
import Control.Exception (AllocationLimitExceeded (..), AsyncException (..))
import qualified Control.Exception as E
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Trans.Except
import Control.Monad.Trans.State.Strict (StateT (..))
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import Data.Bifunctor (first, second)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.ByteString.Internal (toForeignPtr, unsafeCreate)
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.IORef
import Data.Int (Int64)
import Data.List (groupBy, sortOn)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (listToMaybe)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With, encodeUtf8)
import Data.Time (NominalDiffTime)
import Data.Tuple (swap)
import Data.Word (Word8)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (peekByteOff, pokeByteOff)
import GHC.Conc (labelThread, myThreadId, threadDelay)
import UnliftIO hiding (atomicModifyIORef')
import qualified UnliftIO.Exception as UE
-- | Start every action in its own 'Async' (nested through 'withAsync'), then wait
-- with 'waitAnyCancel' and discard whatever result it returns.
-- NOTE(review): an empty list ends up in @waitAnyCancel []@ - confirm callers never pass one.
raceAny_ :: MonadUnliftIO m => [m a] -> m ()
raceAny_ actions = go actions []
  where
    -- collect the handles (most recently started first) until every action is running
    go (act : rest) running = withAsync act $ \h -> go rest (h : running)
    go [] running = void (waitAnyCancel running)
infixl 4 <$$>, <$$, <$?>

-- | Map a function over the values held inside two nested functors.
(<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b)
(<$$>) f = fmap (fmap f)
{-# INLINE (<$$>) #-}

-- | Replace every value held inside two nested functors with a constant.
(<$$) :: (Functor f, Functor g) => b -> f (g a) -> f (g b)
(<$$) b = fmap (fmap (const b))
{-# INLINE (<$$) #-}

-- | Run the action, then apply a function that may reject its result:
-- 'Left' is turned into 'fail' with that message, 'Right' is returned.
(<$?>) :: MonadFail m => (a -> Either String b) -> m a -> m b
f <$?> m =
  m >>= \a -> case f a of
    Left err -> fail err
    Right b -> pure b
{-# INLINE (<$?>) #-}
-- | Render a value with 'show' and pack the resulting string into a 'ByteString'.
bshow :: Show a => a -> ByteString
bshow x = B.pack (show x)
{-# INLINE bshow #-}
-- | Render a value with 'show' and pack the resulting string into 'Text'.
tshow :: Show a => a -> Text
tshow x = T.pack (show x)
{-# INLINE tshow #-}
-- | Encode an optional value as a word preceded by a single space;
-- 'Nothing' produces the empty string.
maybeWord :: (a -> ByteString) -> Maybe a -> ByteString
maybeWord _ Nothing = ""
maybeWord f (Just a) = ' ' `B.cons` f a
{-# INLINE maybeWord #-}
-- | Run an @ExceptT e IO@ computation inside @ExceptT e' m@, converting its error with @f@.
liftError :: MonadIO m => (e -> e') -> ExceptT e IO a -> ExceptT e' m a
liftError f action = liftError' f (runExceptT action)
{-# INLINE liftError #-}
-- | Lift an IO action returning 'Either' into @ExceptT e' m@, converting a 'Left' with @f@.
liftError' :: MonadIO m => (e -> e') -> IO (Either e a) -> ExceptT e' m a
liftError' f io = ExceptT $ first f <$> liftIO io
{-# INLINE liftError' #-}
-- | Lift a pure 'Either' into 'ExceptT', converting a 'Left' with @f@.
-- No IO is performed here, so only 'Monad' is required of the base monad;
-- every caller that supplies a 'MonadIO' monad still type-checks.
liftEitherWith :: Monad m => (e -> e') -> Either e a -> ExceptT e' m a
liftEitherWith f = liftEither . first f
{-# INLINE liftEitherWith #-}
-- | Monadic @if@: run the condition, then run exactly one of the two branches.
ifM :: Monad m => m Bool -> m a -> m a -> m a
ifM cond onTrue onFalse = do
  b <- cond
  if b then onTrue else onFalse
{-# INLINE ifM #-}
-- | Run the action only when the monadic condition yields 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM cond action = cond >>= \b -> when b action
{-# INLINE whenM #-}
-- | Run the action only when the monadic condition yields 'False'.
unlessM :: Monad m => m Bool -> m () -> m ()
unlessM cond action = cond >>= \b -> unless b action
{-# INLINE unlessM #-}
-- | Run the predicates left to right until one yields 'True'; predicates after
-- the first 'True' are not executed. An empty list yields 'False'.
anyM :: Monad m => [m Bool] -> m Bool
anyM = foldM step False
  where
    -- once the answer is known, the remaining actions are skipped
    step True _ = pure True
    -- force the result so that no thunk is carried through the fold
    step False check = check >>= \b -> pure $! b
{-# INLINE anyM #-}
infixl 1 $>>, $>>=

-- | Bind through an inner monad @f@ (such as 'Maybe' or 'Either') wrapped in an
-- outer monad @m@: the continuation runs for the values held in @f@ and the
-- nested result is flattened with 'join'.
($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b)
m $>>= k = do
  fa <- m
  join <$> mapM k fa
{-# INLINE ($>>=) #-}

-- | Like '$>>=', but the second action ignores the value produced by the first.
($>>) :: (Monad m, Monad f, Traversable f) => m (f a) -> m (f b) -> m (f b)
m $>> next = m $>>= const next
{-# INLINE ($>>) #-}
-- | Apply a monadic function to every 'Right' in the structure;
-- 'Left' elements are passed through without running the function.
mapME :: (Monad m, Traversable t) => (a -> m (Either e b)) -> t (Either e a) -> m (t (Either e b))
mapME f = mapM step
  where
    step = either (pure . Left) f
{-# INLINE mapME #-}
-- | Run the function on a 'Right' value; return a 'Left' unchanged without running it.
bindRight :: Monad m => (a -> m (Either e b)) -> Either e a -> m (Either e b)
bindRight _ (Left e) = pure (Left e)
bindRight f (Right a) = f a
{-# INLINE bindRight #-}
-- | 'mapME' with its arguments flipped: the structure comes first, the function second.
forME :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) -> m (t (Either e b))
forME xs f = mapM (either (pure . Left) f) xs
{-# INLINE forME #-}
-- | Monadic version of mapAccumL: threads an accumulator from left to right
-- through a traversable structure, running the combining function in @m@.
-- Copied from ghc-9.6.3 package: https://hackage.haskell.org/package/ghc-9.12.1/docs/GHC-Utils-Monad.html#v:mapAccumLM
-- for backward compatibility with 8.10.7.
mapAccumLM ::
  (Monad m, Traversable t) =>
  -- | combining function
  (acc -> x -> m (acc, y)) ->
  -- | initial state
  acc ->
  -- | inputs
  t x ->
  -- | final state, outputs
  m (acc, t y)
{-# INLINE [1] mapAccumLM #-}
-- INLINE pragma. mapAccumLM is called in inner loops. Like 'map',
-- we inline it so that we can take advantage of knowing 'f'.
-- This makes a few percent difference (in compiler allocations)
-- when compiling perf/compiler/T9675
mapAccumLM f s = fmap swap . (`runStateT` s) . traverse step
  where
    -- StateT keeps the state in the second tuple component, hence the swaps
    step x = StateT $ \acc -> swap <$> f acc x
{-# RULES "mapAccumLM/List" mapAccumLM = mapAccumLM_List #-}
{-# RULES "mapAccumLM/NonEmpty" mapAccumLM = mapAccumLM_NonEmpty #-}
-- | List specialisation of 'mapAccumLM': the combining function is run on the
-- elements from left to right, each call receiving the state left by the previous one.
mapAccumLM_List ::
  Monad m =>
  (acc -> x -> m (acc, y)) ->
  acc ->
  [x] ->
  m (acc, [y])
{-# INLINE mapAccumLM_List #-}
mapAccumLM_List step = loop
  where
    loop acc [] = pure (acc, [])
    loop acc (x : rest) =
      step acc x >>= \(acc', y) ->
        loop acc' rest >>= \(accFinal, ys) ->
          pure (accFinal, y : ys)
-- | 'NonEmpty' specialisation of 'mapAccumLM': processes the head, then hands
-- the tail and the updated state to 'mapAccumLM_List'.
mapAccumLM_NonEmpty ::
  Monad m =>
  (acc -> x -> m (acc, y)) ->
  acc ->
  NonEmpty x ->
  m (acc, NonEmpty y)
{-# INLINE mapAccumLM_NonEmpty #-}
mapAccumLM_NonEmpty f s (x :| xs) = do
  (s1, y) <- f s x
  (s2, ys) <- mapAccumLM_List f s1 xs
  return (s2, y :| ys)
-- | Optimized from bytestring package for GHC 8.10.7 compatibility.
-- Combines the bytes at equal positions of two strings with @f@; the result is
-- as long as the shorter input.
packZipWith :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString -> ByteString
packZipWith f s1 s2 =
  unsafeCreate outLen $ \dst ->
    withForeignPtr fp1 $ \src1 ->
      withForeignPtr fp2 $ \src2 ->
        let loop :: Int -> IO ()
            loop !i
              | i < outLen = do
                  -- the inputs are read at their own offsets, the output is written from 0
                  a <- peekByteOff src1 (off1 + i)
                  b <- peekByteOff src2 (off2 + i)
                  pokeByteOff dst i (f a b)
                  loop (i + 1)
              | otherwise = pure ()
         in loop 0
  where
    (fp1, off1, len1) = toForeignPtr s1
    (fp2, off2, len2) = toForeignPtr s2
    outLen = min len1 len2
-- | Write to a bounded queue without blocking: returns 'False' and leaves the
-- queue untouched when it is full, otherwise writes the value and returns 'True'.
tryWriteTBQueue :: TBQueue a -> a -> STM Bool
tryWriteTBQueue q a =
  isFullTBQueue q >>= \isFull ->
    if isFull
      then pure False
      else True <$ writeTBQueue q a
{-# INLINE tryWriteTBQueue #-}
-- | 'E.catch' specialised to 'E.SomeException'.
catchAll :: IO a -> (E.SomeException -> IO a) -> IO a
catchAll action handler = E.catch action handler
{-# INLINE catchAll #-}
-- | Run the fallback action if the first action throws; the exception itself is ignored.
catchAll_ :: IO a -> IO a -> IO a
catchAll_ action fallback = action `E.catch` \(_ :: E.SomeException) -> fallback
{-# INLINE catchAll_ #-}
-- | Error types that can be built from an arbitrary exception.
class Show e => AnyError e where
  -- | Convert a caught exception into the error type.
  fromSomeException :: E.SomeException -> e
-- | Run the action and return its outcome as a value inside 'ExceptT':
-- both an 'ExceptT' error and an exception caught by 'UE.catch'
-- (converted with 'fromSomeException') become 'Left'.
tryAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
tryAllErrors action = ExceptT $ Right <$> UE.catch (runExceptT action) toLeft
  where
    toLeft ex = pure (Left (fromSomeException ex))
{-# INLINE tryAllErrors #-}
-- | Unwrap the 'ExceptT' action in the base monad; an exception caught by
-- 'UE.catch' is converted with 'fromSomeException' and returned as 'Left'.
tryAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
tryAllErrors' action = UE.catch (runExceptT action) toLeft
  where
    toLeft ex = pure (Left (fromSomeException ex))
{-# INLINE tryAllErrors' #-}
-- | Run the action and pass any error reported by 'tryAllErrors' to the handler.
catchAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
catchAllErrors action handler = do
  result <- tryAllErrors action
  case result of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllErrors #-}
-- | Run the action in the base monad and pass any error reported by
-- 'tryAllErrors'' to the handler.
catchAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
catchAllErrors' action handler = do
  result <- tryAllErrors' action
  case result of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllErrors' #-}
-- | Convert an exception caught by 'UE.catch' into an 'ExceptT' error using @err@.
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (SomeException -> e) -> ExceptT e m a
catchThrow action err = ExceptT $ UE.catch (runExceptT action) (\ex -> pure (Left (err ex)))
{-# INLINE catchThrow #-}
-- | Run @final@ after the action whether or not the action failed (as seen by
-- 'tryAllErrors'), then return the action's outcome. An error raised by @final@
-- itself takes the place of that outcome.
allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> ExceptT e m a
allFinally action final = do
  result <- tryAllErrors action
  final >> except result
{-# INLINE allFinally #-}
-- spec: spec/modules/Simplex/Messaging/Util.md#isOwnException
-- | True for 'StackOverflow', 'HeapOverflow' and 'AllocationLimitExceeded';
-- False for every other exception.
isOwnException :: E.SomeException -> Bool
isOwnException e
  | Just StackOverflow <- E.fromException e = True
  | Just HeapOverflow <- E.fromException e = True
  | Just AllocationLimitExceeded <- E.fromException e = True
  | otherwise = False
{-# INLINE isOwnException #-}
-- spec: spec/modules/Simplex/Messaging/Util.md#isAsyncCancellation
-- | True for an asynchronous exception that 'isOwnException' does not accept.
isAsyncCancellation :: E.SomeException -> Bool
isAsyncCancellation e
  | Just (_ :: SomeAsyncException) <- E.fromException e = not (isOwnException e)
  | otherwise = False
{-# INLINE isAsyncCancellation #-}
-- spec: spec/modules/Simplex/Messaging/Util.md#catchOwn
-- Catches all exceptions EXCEPT async cancellations (name is misleading):
-- an exception accepted by 'isAsyncCancellation' is re-thrown, anything else
-- is passed to the handler.
catchOwn' :: IO a -> (E.SomeException -> IO a) -> IO a
catchOwn' action handleInternal = E.catch action handler
  where
    handler e
      | isAsyncCancellation e = E.throwIO e
      | otherwise = handleInternal e
{-# INLINE catchOwn' #-}
-- spec: spec/modules/Simplex/Messaging/Util.md#catchOwn
-- | 'catchOwn'' for any 'MonadUnliftIO' monad: exceptions accepted by
-- 'isAsyncCancellation' are re-thrown, everything else goes to the handler.
catchOwn :: MonadUnliftIO m => m a -> (E.SomeException -> m a) -> m a
catchOwn action handleInternal =
  withRunInIO $ \runInIO ->
    let handler e
          | isAsyncCancellation e = E.throwIO e
          | otherwise = runInIO (handleInternal e)
     in runInIO action `E.catch` handler
{-# INLINE catchOwn #-}
-- | Like 'tryAllErrors', but exceptions are intercepted with 'catchOwn'.
tryAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
tryAllOwnErrors action = ExceptT $ Right <$> catchOwn (runExceptT action) toLeft
  where
    toLeft ex = pure (Left (fromSomeException ex))
{-# INLINE tryAllOwnErrors #-}
-- | Like 'tryAllErrors'', but exceptions are intercepted with 'catchOwn'.
tryAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
tryAllOwnErrors' action = catchOwn (runExceptT action) toLeft
  where
    toLeft ex = pure (Left (fromSomeException ex))
{-# INLINE tryAllOwnErrors' #-}
-- | Run the action and pass any error reported by 'tryAllOwnErrors' to the handler.
catchAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
catchAllOwnErrors action handler = do
  result <- tryAllOwnErrors action
  case result of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllOwnErrors #-}
-- | Run the action in the base monad and pass any error reported by
-- 'tryAllOwnErrors'' to the handler.
catchAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
catchAllOwnErrors' action handler = do
  result <- tryAllOwnErrors' action
  case result of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllOwnErrors' #-}
-- | Keep a 'Right' value as 'Just'; discard a 'Left'.
eitherToMaybe :: Either a b -> Maybe b
eitherToMaybe (Right b) = Just b
eitherToMaybe (Left _) = Nothing
{-# INLINE eitherToMaybe #-}
-- | Return the first element of the list, or the given error when the list is empty.
listToEither :: e -> [a] -> Either e a
listToEither e xs = maybe (Left e) Right (listToMaybe xs)
-- | Run a query returning a list of rows and convert the first row with @f@;
-- an empty result yields @Left e@.
firstRow :: (a -> b) -> e -> IO [a] -> IO (Either e b)
firstRow f e = fmap pick
  where
    pick (row : _) = Right (f row)
    pick [] = Left e
-- | Convert the first row of a query result with @f@; an empty result yields 'Nothing'.
maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b)
maybeFirstRow f = fmap pick
  where
    pick (row : _) = Just (f row)
    pick [] = Nothing
-- | Convert the first row of a query result with @f@; an empty result yields the default.
maybeFirstRow' :: Functor f => b -> (a -> b) -> f [a] -> f b
maybeFirstRow' def f = fmap pick
  where
    pick (row : _) = f row
    pick [] = def
-- | Like 'firstRow', but the conversion of the first row may itself fail.
firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b)
firstRow' f e = fmap pick
  where
    pick (row : _) = f row
    pick [] = Left e
-- | 'groupBy' using equality of the keys computed by @f@.
groupOn :: Eq k => (a -> k) -> [a] -> [[a]]
groupOn f = groupBy (eqOn f)
-- | Like 'groupOn', but each group is returned as a 'NonEmpty' list.
groupOn' :: Eq k => (a -> k) -> [a] -> [NonEmpty a]
groupOn' f = L.groupBy (eqOn f)
-- Used by groupOn, which is equivalent to groupBy ((==) `on` f),
-- but `on` is redefined here to avoid duplicate computation for most values:
-- the key of the first argument is computed once and shared by every comparison
-- made with the partially applied function.
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
-- the on2 in that package is specialized here to only use `==` as the function,
-- so `eqOn f` is equivalent to `(==) `on` f`
eqOn :: Eq k => (a -> k) -> a -> a -> Bool
eqOn f x = (fx ==) . f
  where
    fx = f x
{-# INLINE eqOn #-}
-- | Sort the list by key, then group it by key with 'groupOn'.
groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]]
groupAllOn f xs = groupOn f (sortOn f xs)
-- | Split a list into consecutive chunks of at most @n@ elements; only the last
-- chunk may be shorter, and an empty list yields no chunks.
-- @n@ is expected to be > 0. A non-positive @n@ is treated as 1: taking zero
-- elements per step would never consume the input, so the recursion would not
-- terminate.
toChunks :: Int -> [a] -> [NonEmpty a]
toChunks n = go
  where
    size = max 1 n
    go xs = case splitAt size xs of
      (y : ys, rest) -> (y :| ys) : go rest
      -- only reachable for empty input, because size >= 1
      ([], _) -> []
-- | Decode UTF-8 without failing: invalid input is replaced with @?@.
safeDecodeUtf8 :: ByteString -> Text
safeDecodeUtf8 = decodeUtf8With (\_ _ -> Just '?')
{-# INLINE safeDecodeUtf8 #-}
-- | Run the action under 'timeout'; if no result is produced in time, fail with @e@.
-- An error raised by the action itself is propagated unchanged.
-- NOTE(review): @ms@ is handed to 'timeout' as is - confirm the unit callers pass
-- matches the one 'timeout' expects.
timeoutThrow :: MonadUnliftIO m => e -> Int -> ExceptT e m a -> ExceptT e m a
timeoutThrow e ms action = do
  result <- ExceptT $ sequence <$> timeout ms (runExceptT action)
  case result of
    Nothing -> throwE e
    Just a -> pure a
-- | 'threadDelay' taking an 'Int64': the delay is served in steps of at most
-- @maxBound :: Int@, and a non-positive delay returns immediately.
threadDelay' :: Int64 -> IO ()
threadDelay' remaining =
  when (remaining > 0) $ do
    let step = min remaining maxStep
    threadDelay (fromIntegral step)
    threadDelay' (remaining - step)
  where
    maxStep = fromIntegral (maxBound :: Int) :: Int64
-- | Convert a time difference to whole microseconds, truncating the fraction.
diffToMicroseconds :: NominalDiffTime -> Int64
diffToMicroseconds = truncate . (* 1000000)
{-# INLINE diffToMicroseconds #-}
-- | Convert a time difference to whole milliseconds, truncating the fraction.
diffToMilliseconds :: NominalDiffTime -> Int64
diffToMilliseconds = truncate . (* 1000)
{-# INLINE diffToMilliseconds #-}
-- | Attach a label to the calling thread using 'labelThread'.
labelMyThread :: MonadIO m => String -> m ()
labelMyThread label = liftIO $ do
  tid <- myThreadId
  labelThread tid label
-- | Update an 'IORef' with 'atomicModifyIORef'', returning no result.
atomicModifyIORef'_ :: IORef a -> (a -> a) -> IO ()
atomicModifyIORef'_ ref update =
  atomicModifyIORef' ref $ \old -> (update old, ())
-- | Encode a value as JSON and return it as 'Text', decoding the encoder's
-- output with 'safeDecodeUtf8'.
encodeJSON :: ToJSON a => a -> Text
encodeJSON a = safeDecodeUtf8 (LB.toStrict (J.encode a))
{-# INLINE encodeJSON #-}
-- | Parse JSON held in 'Text' (encoded to UTF-8 first); 'Nothing' when decoding fails.
decodeJSON :: FromJSON a => Text -> Maybe a
decodeJSON t = J.decodeStrict (encodeUtf8 t)
{-# INLINE decodeJSON #-}
-- | Run an action for every key/value pair of the map, discarding the results.
-- The actions are chained by a right fold over the map.
traverseWithKey_ :: Monad m => (k -> v -> m ()) -> Map k v -> m ()
traverseWithKey_ f = M.foldrWithKey step (pure ())
  where
    step k v rest = f k v >> rest
{-# INLINE traverseWithKey_ #-}