mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-31 07:36:00 +00:00
432 lines
14 KiB
Haskell
432 lines
14 KiB
Haskell
{-# LANGUAGE BangPatterns #-}
|
|
{-# LANGUAGE MonadComprehensions #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
-- spec: spec/modules/Simplex/Messaging/Util.md
|
|
module Simplex.Messaging.Util
|
|
( AnyError (..),
|
|
(<$?>),
|
|
($>>),
|
|
(<$$),
|
|
(<$$>),
|
|
raceAny_,
|
|
bshow,
|
|
tshow,
|
|
maybeWord,
|
|
liftError,
|
|
liftError',
|
|
liftEitherWith,
|
|
ifM,
|
|
whenM,
|
|
unlessM,
|
|
anyM,
|
|
($>>=),
|
|
mapME,
|
|
bindRight,
|
|
forME,
|
|
mapAccumLM,
|
|
packZipWith,
|
|
tryWriteTBQueue,
|
|
catchAll,
|
|
catchAll_,
|
|
tryAllErrors,
|
|
tryAllErrors',
|
|
catchAllErrors,
|
|
catchAllErrors',
|
|
catchThrow,
|
|
allFinally,
|
|
isOwnException,
|
|
isAsyncCancellation,
|
|
catchOwn',
|
|
catchOwn,
|
|
tryAllOwnErrors,
|
|
tryAllOwnErrors',
|
|
catchAllOwnErrors,
|
|
catchAllOwnErrors',
|
|
eitherToMaybe,
|
|
listToEither,
|
|
firstRow,
|
|
maybeFirstRow,
|
|
maybeFirstRow',
|
|
firstRow',
|
|
groupOn,
|
|
groupOn',
|
|
eqOn,
|
|
groupAllOn,
|
|
toChunks,
|
|
safeDecodeUtf8,
|
|
timeoutThrow,
|
|
threadDelay',
|
|
diffToMicroseconds,
|
|
diffToMilliseconds,
|
|
labelMyThread,
|
|
atomicModifyIORef'_,
|
|
encodeJSON,
|
|
decodeJSON,
|
|
traverseWithKey_,
|
|
) where
|
|
|
|
import Control.Exception (AllocationLimitExceeded (..), AsyncException (..))
|
|
import qualified Control.Exception as E
|
|
import Control.Monad
|
|
import Control.Monad.Except
|
|
import Control.Monad.IO.Unlift
|
|
import Control.Monad.Trans.Except
|
|
import Control.Monad.Trans.State.Strict (StateT (..))
|
|
import Data.Aeson (FromJSON, ToJSON)
|
|
import qualified Data.Aeson as J
|
|
import Data.Bifunctor (first, second)
|
|
import Data.ByteString.Char8 (ByteString)
|
|
import qualified Data.ByteString.Char8 as B
|
|
import Data.ByteString.Internal (toForeignPtr, unsafeCreate)
|
|
import qualified Data.ByteString.Lazy.Char8 as LB
|
|
import Data.IORef
|
|
import Data.Int (Int64)
|
|
import Data.List (groupBy, sortOn)
|
|
import Data.List.NonEmpty (NonEmpty (..))
|
|
import qualified Data.List.NonEmpty as L
|
|
import Data.Map.Strict (Map)
|
|
import qualified Data.Map.Strict as M
|
|
import Data.Maybe (listToMaybe)
|
|
import Data.Text (Text)
|
|
import qualified Data.Text as T
|
|
import Data.Text.Encoding (decodeUtf8With, encodeUtf8)
|
|
import Data.Time (NominalDiffTime)
|
|
import Data.Tuple (swap)
|
|
import Data.Word (Word8)
|
|
import Foreign.ForeignPtr (withForeignPtr)
|
|
import Foreign.Storable (peekByteOff, pokeByteOff)
|
|
import GHC.Conc (labelThread, myThreadId, threadDelay)
|
|
import UnliftIO hiding (atomicModifyIORef')
|
|
import qualified UnliftIO.Exception as UE
|
|
|
|
-- | Start every action with 'withAsync', then block in 'waitAnyCancel' on
-- the collected handles and discard the result.
raceAny_ :: MonadUnliftIO m => [m a] -> m ()
raceAny_ actions = spawn [] actions
  where
    -- Launch the remaining actions one by one, accumulating their handles.
    spawn running pending = case pending of
      next : rest -> withAsync next $ \h -> spawn (h : running) rest
      [] -> void (waitAnyCancel running)
|
|
|
|
-- Same fixity as '<$>' (infixl 4), so these operators mix with it naturally.
infixl 4 <$$>, <$$, <$?>
|
|
|
|
-- | Apply a function to the values inside two nested functors.
(<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b)
(<$$>) = \fn -> fmap (fmap fn)
{-# INLINE (<$$>) #-}
|
|
|
|
-- | Replace every value inside two nested functors with a constant.
(<$$) :: (Functor f, Functor g) => b -> f (g a) -> f (g b)
(<$$) = \b -> fmap (fmap (const b))
{-# INLINE (<$$) #-}
|
|
|
|
-- | Run the action, then apply a partial conversion to its result;
-- a 'Left' message is reported through 'fail'.
(<$?>) :: MonadFail m => (a -> Either String b) -> m a -> m b
f <$?> m = do
  a <- m
  case f a of
    Left err -> fail err
    Right b -> pure b
{-# INLINE (<$?>) #-}
|
|
|
|
-- | 'show' a value and pack the resulting 'String' into a 'ByteString'.
bshow :: Show a => a -> ByteString
bshow = \x -> B.pack (show x)
{-# INLINE bshow #-}
|
|
|
|
-- | 'show' a value and pack the resulting 'String' into 'Text'.
tshow :: Show a => a -> Text
tshow = \x -> T.pack (show x)
{-# INLINE tshow #-}
|
|
|
|
-- | Render an optional value as a space-prefixed word; 'Nothing' renders as
-- the empty string.
maybeWord :: (a -> ByteString) -> Maybe a -> ByteString
maybeWord f = \mx -> case mx of
  Nothing -> ""
  Just x -> ' ' `B.cons` f x
{-# INLINE maybeWord #-}
|
|
|
|
-- | Lift an @ExceptT e IO@ action into @ExceptT e' m@, mapping its error.
liftError :: MonadIO m => (e -> e') -> ExceptT e IO a -> ExceptT e' m a
liftError f = \action -> liftError' f (runExceptT action)
{-# INLINE liftError #-}
|
|
|
|
-- | Lift an 'IO' action returning 'Either' into @ExceptT e' m@, mapping the
-- 'Left' value with the given function.
liftError' :: MonadIO m => (e -> e') -> IO (Either e a) -> ExceptT e' m a
liftError' f = \io -> ExceptT (first f <$> liftIO io)
{-# INLINE liftError' #-}
|
|
|
|
-- | Lift a pure 'Either' into @ExceptT e' m@, mapping the 'Left' value with
-- the given function.
--
-- No IO is performed, so only 'Monad' is required of the base monad; every
-- 'MonadIO' caller still satisfies this weaker constraint.
liftEitherWith :: Monad m => (e -> e') -> Either e a -> ExceptT e' m a
liftEitherWith f = liftEither . first f
{-# INLINE liftEitherWith #-}
|
|
|
|
-- | Monadic @if@: run the condition, then exactly one of the two branches.
ifM :: Monad m => m Bool -> m a -> m a -> m a
ifM ba t f = do
  b <- ba
  if b then t else f
{-# INLINE ifM #-}
|
|
|
|
-- | Run the action only when the monadic condition yields 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM b a = b >>= \ok -> when ok a
{-# INLINE whenM #-}
|
|
|
|
-- | Run the action only when the monadic condition yields 'False'.
unlessM :: Monad m => m Bool -> m () -> m ()
unlessM b = \a -> b >>= \ok -> unless ok a
{-# INLINE unlessM #-}
|
|
|
|
-- | Run the actions left to right until one yields 'True'; the actions after
-- that one are not run.
anyM :: Monad m => [m Bool] -> m Bool
anyM = foldM step False
  where
    -- Once a 'True' has been seen, skip the remaining actions.
    step True _ = pure True
    step False action = (False ||) <$!> action
{-# INLINE anyM #-}
|
|
|
|
-- Same fixity as '>>=' and '>>' (infixl 1).
infixl 1 $>>, $>>=
|
|
|
|
-- | Bind through two layers: run the outer action, apply the continuation to
-- each value of the inner structure, then flatten the inner layers.
($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b)
f $>>= g = do
  inner <- f
  join <$> mapM g inner
{-# INLINE ($>>=) #-}
|
|
|
|
-- | Like '$>>=', but the second action ignores the value produced by the first.
($>>) :: (Monad m, Monad f, Traversable f) => m (f a) -> m (f b) -> m (f b)
f $>> g = do
  inner <- f
  join <$> mapM (const g) inner
{-# INLINE ($>>) #-}
|
|
|
|
-- | Traverse a structure of 'Either's, applying the action to each 'Right'
-- and passing each 'Left' through unchanged.
mapME :: (Monad m, Traversable t) => (a -> m (Either e b)) -> t (Either e a) -> m (t (Either e b))
mapME f = mapM step
  where
    step r = case r of
      Left e -> pure (Left e)
      Right a -> f a
{-# INLINE mapME #-}
|
|
|
|
-- | Apply the action to a 'Right' value; return a 'Left' unchanged.
bindRight :: Monad m => (a -> m (Either e b)) -> Either e a -> m (Either e b)
bindRight = \f r -> case r of
  Left e -> pure (Left e)
  Right a -> f a
{-# INLINE bindRight #-}
|
|
|
|
-- | 'mapME' with its arguments flipped.
forME :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) -> m (t (Either e b))
forME = \xs f -> mapME f xs
{-# INLINE forME #-}
|
|
|
|
-- | Monadic version of 'Data.Traversable.mapAccumL'.
--
-- Adapted from @GHC.Utils.Monad@ in the @ghc@ package
-- (https://hackage.haskell.org/package/ghc-9.12.1/docs/GHC-Utils-Monad.html#v:mapAccumLM)
-- for backward compatibility with GHC 8.10.7.
mapAccumLM ::
  (Monad m, Traversable t) =>
  -- | combining function
  (acc -> x -> m (acc, y)) ->
  -- | initial state
  acc ->
  -- | inputs
  t x ->
  -- | final state, outputs
  m (acc, t y)
{-# INLINE [1] mapAccumLM #-}
-- INLINE pragma (note carried over from the upstream source): mapAccumLM is
-- called in inner loops, so like 'map' it is inlined to take advantage of
-- knowing 'f'.
mapAccumLM f s = fmap swap . flip runStateT s . traverse lifted
  where
    -- StateT keeps the state in the second tuple component while the
    -- combining function keeps it in the first, hence the swaps.
    lifted x = StateT $ \acc -> swap <$> f acc x
|
|
|
|
-- Rewrite rules: replace the generic 'mapAccumLM' with the hand-written loop
-- for the matching container type.
{-# RULES "mapAccumLM/List" mapAccumLM = mapAccumLM_List #-}
|
|
{-# RULES "mapAccumLM/NonEmpty" mapAccumLM = mapAccumLM_NonEmpty #-}
|
|
|
|
-- | List-specialised 'mapAccumLM': thread the accumulator through the list
-- from left to right, collecting the outputs in order.
mapAccumLM_List ::
  Monad m =>
  (acc -> x -> m (acc, y)) ->
  acc ->
  [x] ->
  m (acc, [y])
{-# INLINE mapAccumLM_List #-}
mapAccumLM_List f = walk
  where
    walk acc [] = pure (acc, [])
    walk acc (x : rest) =
      f acc x >>= \(acc', y) ->
        walk acc' rest >>= \(accFinal, ys) ->
          pure (accFinal, y : ys)
|
|
|
|
-- | 'NonEmpty'-specialised 'mapAccumLM': process the head, then delegate the
-- tail to 'mapAccumLM_List'.
mapAccumLM_NonEmpty ::
  Monad m =>
  (acc -> x -> m (acc, y)) ->
  acc ->
  NonEmpty x ->
  m (acc, NonEmpty y)
{-# INLINE mapAccumLM_NonEmpty #-}
mapAccumLM_NonEmpty f s (x :| xs) = do
  (s1, x') <- f s x
  (s2, xs') <- mapAccumLM_List f s1 xs
  return (s2, x' :| xs')
|
|
|
|
-- | Combine two 'ByteString's byte by byte with the given function.
-- The result is as long as the shorter input.
--
-- Optimized from the bytestring package for GHC 8.10.7 compatibility.
packZipWith :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString -> ByteString
packZipWith f s1 s2 =
  unsafeCreate len $ \out ->
    withForeignPtr fp1 $ \src1 ->
      withForeignPtr fp2 $ \src2 ->
        fill src1 src2 out 0
  where
    (fp1, off1, l1) = toForeignPtr s1
    (fp2, off2, l2) = toForeignPtr s2
    len = min l1 l2
    -- Write output byte i from input bytes at (offset + i), for i in [0, len).
    fill src1 src2 out = loop
      where
        loop :: Int -> IO ()
        loop !i = when (i < len) $ do
          a <- peekByteOff src1 (off1 + i)
          b <- peekByteOff src2 (off2 + i)
          pokeByteOff out i (f a b)
          loop (i + 1)
|
|
|
|
-- | Write to a bounded queue without blocking.
-- Returns 'False', writing nothing, when the queue is full; 'True' otherwise.
tryWriteTBQueue :: TBQueue a -> a -> STM Bool
tryWriteTBQueue q a =
  isFullTBQueue q >>= \full ->
    if full
      then pure False
      else True <$ writeTBQueue q a
{-# INLINE tryWriteTBQueue #-}
|
|
|
|
-- | 'E.catch' specialised to 'E.SomeException'.
catchAll :: IO a -> (E.SomeException -> IO a) -> IO a
catchAll = \action handler -> action `E.catch` handler
{-# INLINE catchAll #-}
|
|
|
|
-- | Run the first action; on any exception run the fallback instead,
-- ignoring the exception value.
catchAll_ :: IO a -> IO a -> IO a
catchAll_ a = \fallback -> a `E.catch` \(_ :: E.SomeException) -> fallback
{-# INLINE catchAll_ #-}
|
|
|
|
-- | Error types that can be built from an arbitrary exception.
class Show e => AnyError e where
  -- | Convert an exception into the error type.
  fromSomeException :: E.SomeException -> e
|
|
|
|
-- | Run the action and return its outcome as a value: both an 'ExceptT' error
-- and an exception caught by 'UE.catch' (converted with 'fromSomeException')
-- are reported as 'Left'. The resulting action itself never fails via 'ExceptT'.
tryAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
tryAllErrors action = ExceptT $ do
  outcome <- runExceptT action `UE.catch` \ex -> pure (Left (fromSomeException ex))
  pure (Right outcome)
{-# INLINE tryAllErrors #-}
|
|
|
|
-- | Unwrap the 'ExceptT' action in the base monad; an exception caught by
-- 'UE.catch' is converted with 'fromSomeException' and returned as 'Left'.
tryAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
tryAllErrors' action =
  UE.catch (runExceptT action) $ \ex -> pure (Left (fromSomeException ex))
{-# INLINE tryAllErrors' #-}
|
|
|
|
-- | Run the action; pass any error reported by 'tryAllErrors' to the handler.
catchAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
catchAllErrors action handler = do
  outcome <- tryAllErrors action
  case outcome of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllErrors #-}
|
|
|
|
-- | Run the action in the base monad; pass any error reported by
-- 'tryAllErrors'' to the handler.
catchAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
catchAllErrors' action handler = do
  outcome <- tryAllErrors' action
  case outcome of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllErrors' #-}
|
|
|
|
-- | Run the action; an exception caught by 'UE.catch' is converted with the
-- given function and rethrown as an 'ExceptT' error.
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (SomeException -> e) -> ExceptT e m a
catchThrow action err =
  ExceptT $ runExceptT action `UE.catch` \ex -> pure (Left (err ex))
{-# INLINE catchThrow #-}
|
|
|
|
-- | Run the action, then the finalizer, then return the action's outcome.
-- The finalizer runs whether the action succeeded or 'tryAllErrors' reported
-- an error; an 'ExceptT' error from the finalizer itself takes precedence.
allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> ExceptT e m a
allFinally action final = do
  outcome <- tryAllErrors action
  _ <- final
  except outcome
{-# INLINE allFinally #-}
|
|
|
|
-- spec: spec/modules/Simplex/Messaging/Util.md#isOwnException
-- | 'True' only for 'StackOverflow', 'HeapOverflow' and
-- 'AllocationLimitExceeded'.
isOwnException :: E.SomeException -> Bool
isOwnException e
  | Just StackOverflow <- E.fromException e = True
  | Just HeapOverflow <- E.fromException e = True
  | Just AllocationLimitExceeded <- E.fromException e = True
  | otherwise = False
{-# INLINE isOwnException #-}
|
|
|
|
-- spec: spec/modules/Simplex/Messaging/Util.md#isAsyncCancellation
-- | 'True' for exceptions wrapped in 'SomeAsyncException', except those
-- that 'isOwnException' recognises.
isAsyncCancellation :: E.SomeException -> Bool
isAsyncCancellation e
  | Just (_ :: SomeAsyncException) <- E.fromException e = not (isOwnException e)
  | otherwise = False
{-# INLINE isAsyncCancellation #-}
|
|
|
|
-- spec: spec/modules/Simplex/Messaging/Util.md#catchOwn
-- | Catches all exceptions EXCEPT async cancellations (the name is
-- misleading): exceptions matching 'isAsyncCancellation' are rethrown,
-- everything else goes to the handler.
catchOwn' :: IO a -> (E.SomeException -> IO a) -> IO a
catchOwn' action handleInternal = action `E.catch` rethrowOrHandle
  where
    rethrowOrHandle e
      | isAsyncCancellation e = E.throwIO e
      | otherwise = handleInternal e
{-# INLINE catchOwn' #-}
|
|
|
|
-- spec: spec/modules/Simplex/Messaging/Util.md#catchOwn
-- | 'MonadUnliftIO' version of 'catchOwn'': exceptions matching
-- 'isAsyncCancellation' are rethrown, everything else goes to the handler.
catchOwn :: MonadUnliftIO m => m a -> (E.SomeException -> m a) -> m a
catchOwn action handleInternal =
  withRunInIO $ \run ->
    let recover e
          | isAsyncCancellation e = E.throwIO e
          | otherwise = run (handleInternal e)
     in run action `E.catch` recover
{-# INLINE catchOwn #-}
|
|
|
|
-- | Like 'tryAllErrors', but exceptions are intercepted with 'catchOwn'.
-- The outcome is returned as a value; the resulting action never fails via
-- 'ExceptT'.
tryAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
tryAllOwnErrors action = ExceptT $ do
  outcome <- runExceptT action `catchOwn` \ex -> pure (Left (fromSomeException ex))
  pure (Right outcome)
{-# INLINE tryAllOwnErrors #-}
|
|
|
|
-- | Like 'tryAllErrors'', but exceptions are intercepted with 'catchOwn'.
tryAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
tryAllOwnErrors' action =
  catchOwn (runExceptT action) $ \ex -> pure (Left (fromSomeException ex))
{-# INLINE tryAllOwnErrors' #-}
|
|
|
|
-- | Run the action; pass any error reported by 'tryAllOwnErrors' to the handler.
catchAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
catchAllOwnErrors action handler = do
  outcome <- tryAllOwnErrors action
  case outcome of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllOwnErrors #-}
|
|
|
|
-- | Run the action in the base monad; pass any error reported by
-- 'tryAllOwnErrors'' to the handler.
catchAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
catchAllOwnErrors' action handler = do
  outcome <- tryAllOwnErrors' action
  case outcome of
    Left e -> handler e
    Right a -> pure a
{-# INLINE catchAllOwnErrors' #-}
|
|
|
|
-- | Keep a 'Right' value, discard a 'Left'.
eitherToMaybe :: Either a b -> Maybe b
eitherToMaybe = \r -> case r of
  Left _ -> Nothing
  Right b -> Just b
{-# INLINE eitherToMaybe #-}
|
|
|
|
-- | First element of the list, or the given error when the list is empty.
listToEither :: e -> [a] -> Either e a
listToEither e xs = maybe (Left e) Right (listToMaybe xs)
|
|
|
|
-- | Run the action and convert the first element of its result with @f@;
-- return the given error when the result list is empty.
firstRow :: (a -> b) -> e -> IO [a] -> IO (Either e b)
firstRow f e a = do
  rows <- a
  pure $ case rows of
    r : _ -> Right (f r)
    [] -> Left e
|
|
|
|
-- | Convert the first element of the result list with @f@, or return
-- 'Nothing' when the list is empty.
maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b)
maybeFirstRow f q = fmap pick q
  where
    pick rows = f <$> listToMaybe rows
|
|
|
|
-- | Convert the first element of the result list with @f@, or return the
-- default when the list is empty.
maybeFirstRow' :: Functor f => b -> (a -> b) -> f [a] -> f b
maybeFirstRow' def f q = fmap pick q
  where
    pick rows = case rows of
      r : _ -> f r
      [] -> def
|
|
|
|
-- | Like 'firstRow', but the conversion itself may fail with an error.
firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b)
firstRow' f e a = do
  rows <- a
  pure $ case rows of
    r : _ -> f r
    [] -> Left e
|
|
|
|
-- | Group adjacent elements that have equal keys.
groupOn :: Eq k => (a -> k) -> [a] -> [[a]]
groupOn f = groupBy (eqOn f)
|
|
|
|
-- | Like 'groupOn', but each group is a 'NonEmpty' list.
groupOn' :: Eq k => (a -> k) -> [a] -> [NonEmpty a]
groupOn' f = L.groupBy (eqOn f)
|
|
|
|
-- | Compare two values by their keys.
--
-- @eqOn f@ is equivalent to @(==) `on` f@, but the key of the first argument
-- is bound once and shared by the returned predicate, which avoids duplicate
-- computation for most values when used with @groupBy@.
-- Source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
-- (the @on2@ helper in that package, specialised here to @==@).
eqOn :: Eq k => (a -> k) -> a -> a -> Bool
eqOn f x = (key ==) . f
  where
    key = f x
{-# INLINE eqOn #-}
|
|
|
|
-- | Sort by key, then group the elements that have equal keys.
groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]]
groupAllOn f xs = groupOn f (sortOn f xs)
|
|
|
|
-- | Split a list into consecutive chunks of at most @n@ elements.
--
-- @n@ is expected to be > 0. A non-positive @n@ is treated as 1: otherwise
-- 'splitAt' would never consume any input and the recursion would not
-- terminate on a non-empty list.
toChunks :: Int -> [a] -> [NonEmpty a]
toChunks n = go
  where
    size = max 1 n
    go xs = case splitAt size xs of
      -- size >= 1, so an empty chunk means the input is exhausted
      ([], _) -> []
      (y : ys, rest) -> (y :| ys) : go rest
|
|
|
|
-- | Decode UTF-8 without failing: invalid input is replaced with @'?'@.
safeDecodeUtf8 :: ByteString -> Text
safeDecodeUtf8 = decodeUtf8With (\_ _ -> Just '?')
{-# INLINE safeDecodeUtf8 #-}
|
|
|
|
-- | Run the action under 'timeout'; throw the given error if it times out.
--
-- NOTE(review): despite its name, @ms@ is passed unchanged to 'timeout',
-- which UnliftIO appears to document as taking microseconds — confirm what
-- callers pass.
timeoutThrow :: MonadUnliftIO m => e -> Int -> ExceptT e m a -> ExceptT e m a
timeoutThrow e ms action = do
  -- 'sequence' turns Maybe (Either e a) into Either e (Maybe a), so an
  -- error from the action itself is rethrown as is.
  result <- ExceptT $ sequence <$> timeout ms (runExceptT action)
  case result of
    Nothing -> throwE e
    Just a -> pure a
|
|
|
|
-- | 'threadDelay' for an 'Int64' number of microseconds: the delay is split
-- into steps of at most @maxBound :: Int@. Non-positive delays return
-- immediately.
threadDelay' :: Int64 -> IO ()
threadDelay' remaining
  | remaining <= 0 = pure ()
  | otherwise = do
      threadDelay (fromIntegral step)
      threadDelay' (remaining - step)
  where
    step = min remaining (fromIntegral (maxBound :: Int))
|
|
|
|
-- | Convert a time difference to whole microseconds, truncating towards zero.
diffToMicroseconds :: NominalDiffTime -> Int64
diffToMicroseconds diff = truncate (diff * microsPerSecond)
  where
    microsPerSecond = 1000000
{-# INLINE diffToMicroseconds #-}
|
|
|
|
-- | Convert a time difference to whole milliseconds, truncating towards zero.
diffToMilliseconds :: NominalDiffTime -> Int64
diffToMilliseconds diff = truncate (diff * millisPerSecond)
  where
    millisPerSecond = 1000
{-# INLINE diffToMilliseconds #-}
|
|
|
|
-- | Attach a label to the calling thread using 'labelThread'.
labelMyThread :: MonadIO m => String -> m ()
labelMyThread label = liftIO $ do
  tid <- myThreadId
  labelThread tid label
|
|
|
|
-- | 'atomicModifyIORef'' for updates that produce no result.
atomicModifyIORef'_ :: IORef a -> (a -> a) -> IO ()
atomicModifyIORef'_ r f = atomicModifyIORef' r update
  where
    update old = (f old, ())
|
|
|
|
-- | Encode a value as JSON and return it as 'Text'
-- (decoded with 'safeDecodeUtf8').
encodeJSON :: ToJSON a => a -> Text
encodeJSON = \v -> safeDecodeUtf8 (LB.toStrict (J.encode v))
{-# INLINE encodeJSON #-}
|
|
|
|
-- | Decode a value from JSON 'Text'; the text is UTF-8 encoded and passed to
-- 'J.decodeStrict'.
decodeJSON :: FromJSON a => Text -> Maybe a
decodeJSON = \t -> J.decodeStrict (encodeUtf8 t)
{-# INLINE decodeJSON #-}
|
|
|
|
-- | Run an action for every key/value pair of the map, in ascending key
-- order, discarding the results.
traverseWithKey_ :: Monad m => (k -> v -> m ()) -> Map k v -> m ()
traverseWithKey_ f = M.foldrWithKey step (pure ())
  where
    -- Sequence this entry's action before the actions of the later entries.
    step k v rest = f k v >> rest
{-# INLINE traverseWithKey_ #-}
|