remove IORefs

This commit is contained in:
IC Rainbow
2023-12-18 17:13:47 +02:00
parent cb89b963bf
commit ea714c731c
3 changed files with 45 additions and 64 deletions
+38 -55
View File
@@ -120,6 +120,7 @@ import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:), (.:.), (.::), (.::.))
import Data.Either (rights)
import Data.Foldable (foldl')
import Data.Functor (($>))
import Data.List (find)
@@ -127,7 +128,7 @@ import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing)
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time.Clock
@@ -167,7 +168,6 @@ import Simplex.RemoteControl.Invitation
import Simplex.RemoteControl.Types
import UnliftIO.Async (async, race_)
import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay)
import UnliftIO.IORef
import UnliftIO.STM
-- import GHC.Conc (unsafeIOToSTM)
@@ -873,38 +873,30 @@ getNotificationMessage' c nonce encNtfInfo = do
Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} -> msgId == msgId' || msgTs > msgTs'
Nothing -> SMP.notification msgFlags
type EIORef a = IORef (Either AgentErrorType a)
-- | Send message to the connection (SEND command) in Reader monad
sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
sendMessage' c connId msgFlags msg =
oneResult $ \r -> sendMessagesB c [(r, (connId, msgFlags, msg))]
sendMessage' c connId msgFlags msg = oneResult $ sendMessagesB c [Right (connId, msgFlags, msg)]
-- | Send multiple messages to different connections (SEND command) in Reader monad
sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId]
sendMessages' c msgReqs = do
rs <- replicateM (length msgReqs) (newIORef $ Left $ INTERNAL "skipped in batch")
sendMessagesB c $ zip rs msgReqs
mapM readIORef rs
sendMessages' c msgReqs = sendMessagesB c $ Right <$> msgReqs
sendMessagesB :: forall m. AgentMonad' m => AgentClient -> [(EIORef AgentMsgId, MsgReq)] -> m ()
sendMessagesB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType MsgReq] -> m [Either AgentErrorType AgentMsgId]
sendMessagesB c reqs = withConnLocks c connIds "sendMessages" $ do
reqs' <- zip reqs <$> withStoreBatch c (\db -> map (getConn db) connIds)
reqs'' <- catMaybes <$> mapM prepareConn reqs'
reqs' <- zipWith (liftA2 (,)) reqs <$> withStoreBatch c (\db -> map (first storeError <$$> getConn db) connIds)
reqs'' <- mapME prepareConn reqs'
enqueueMessagesB c reqs''
where
prepareConn :: ((EIORef AgentMsgId, MsgReq), Either AgentErrorType SomeConn) -> m (Maybe (EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)))
prepareConn (req@(r, _), conn_) = case conn_ of
Left e -> Nothing <$ writeIORef r (Left e)
Right (SomeConn _ conn) -> case conn of
DuplexConnection cData _ sqs -> enqueueMsgs cData sqs req
SndConnection cData sq -> enqueueMsgs cData [sq] req
_ -> Nothing <$ writeIORef r (Left $ CONN SIMPLEX)
enqueueMsgs :: ConnData -> NonEmpty SndQueue -> (EIORef AgentMsgId, MsgReq) -> m (Maybe (EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)))
enqueueMsgs cData sqs (r, (_, msgFlags, msg))
| ratchetSyncSendProhibited cData = Nothing <$ writeIORef r (Left $ CMD PROHIBITED)
| otherwise = pure $ Just (r, (cData, sqs, msgFlags, A_MSG msg))
connIds = map (\(_, (connId, _, _)) -> connId) reqs
prepareConn :: (MsgReq, SomeConn) -> m (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))
prepareConn (req, SomeConn _ conn) = case conn of
DuplexConnection cData _ sqs -> enqueueMsgs cData sqs req
SndConnection cData sq -> enqueueMsgs cData [sq] req
_ -> pure . Left $ CONN SIMPLEX
enqueueMsgs :: ConnData -> NonEmpty SndQueue -> MsgReq -> m (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))
enqueueMsgs cData sqs (_, msgFlags, msg)
| ratchetSyncSendProhibited cData = pure . Left $ CMD PROHIBITED
| otherwise = pure $ Right (cData, sqs, msgFlags, A_MSG msg)
connIds = map (either (const []) $ \(connId, _, _) -> connId) reqs
-- / async command processing v v v
@@ -1084,36 +1076,36 @@ enqueueMessages c cData sqs msgFlags aMessage = do
enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessages' c cData sqs msgFlags aMessage =
oneResult $ \r -> enqueueMessagesB c [(r, (cData, sqs, msgFlags, aMessage))]
oneResult $ enqueueMessagesB c [Right (cData, sqs, msgFlags, aMessage)]
enqueueMessagesB :: AgentMonad' m => AgentClient -> [(EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))] -> m ()
enqueueMessagesB _ [] = pure ()
enqueueMessagesB c reqs = enqueueMessageB c reqs >>= enqueueSavedMessageB c
enqueueMessagesB :: AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)] -> m [Either AgentErrorType AgentMsgId]
enqueueMessagesB c reqs = do
reqs' <- enqueueMessageB c reqs
enqueueSavedMessageB c $ mapMaybe snd $ rights reqs'
pure $ fst <$$> reqs'
isActiveSndQ :: SndQueue -> Bool
isActiveSndQ SndQueue {status} = status == Secured || status == Active
enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessage c cData sq msgFlags aMessage =
oneResult $ \r -> enqueueMessageB c [(r, (cData, [sq], msgFlags, aMessage))]
enqueueMessage c cData sq msgFlags aMessage = fst <$> oneResult (enqueueMessageB c [Right (cData, [sq], msgFlags, aMessage)])
-- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries
enqueueMessageB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))] -> m [Either AgentErrorType (ConnData, [SndQueue], AgentMsgId)]
enqueueMessageB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)] -> m [Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId))]
enqueueMessageB c reqs = do
forME_ reqs $ \(_, (cData, sq :| _, _, _)) ->
void . forME reqs $ \(cData, sq :| _, _, _) ->
runExceptT $ resumeMsgDelivery c cData sq
aVRange <- asks $ smpAgentVRange . config
aVRange <- asks $ maxVersion . smpAgentVRange . config
mIds <- withStoreBatch c $ \db ->
map (mapE $ storeSentMsg db $ maxVersion aVRange) reqs
forME mIds $ \mId -> do
map (mapE (first storeError <$$> storeSentMsg db aVRange)) reqs
forME (zipWith (liftA2 (,)) reqs mIds) $ \((cData, sq :| sqs, _, _), mId) -> do
let InternalId msgId = mId
queuePendingMsgs c sq [mId]
let sqs' = filter isActiveSndQ sqs
pure $ Right (cData, sqs', msgId)
-- catMaybes <$> mapM processResults (zip reqs mIds)
pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId))
where
storeSentMsg :: DB.Connection -> Version -> (EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> IO (Either StoreError InternalId)
storeSentMsg db agentVersion (_, (ConnData {connId}, sq :| _, msgFlags, aMessage)) = runExceptT $ do
storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either StoreError InternalId)
storeSentMsg db agentVersion (ConnData {connId}, sq :| _, msgFlags, aMessage) = runExceptT $ do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId
let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash
@@ -1127,15 +1119,6 @@ enqueueMessageB c reqs = do
liftIO $ createSndMsg db connId msgData
liftIO $ createSndMsgDelivery db connId sq internalId
pure internalId
processResults :: ((EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)), Either AgentErrorType InternalId) -> m (Maybe (ConnData, [SndQueue], AgentMsgId))
processResults ((r, (cData, sq :| sqs, _, _)), mId_) = case mId_ of
Left e -> Nothing <$ writeIORef r (Left e)
Right mId -> do
let InternalId msgId = mId
writeIORef r $ Right msgId
queuePendingMsgs c sq [mId]
let sqs' = filter isActiveSndQ sqs
pure $ if null sqs' then Nothing else Just (cData, sqs', msgId)
enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m ()
enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c [(cData, [sq], msgId)]
@@ -1155,11 +1138,11 @@ enqueueSavedMessageB c reqs = do
let mId = InternalId msgId
in map (\sq -> createSndMsgDelivery db connId sq mId) sqs
oneResult :: AgentMonad m => (EIORef a -> m b) -> m a
oneResult action = do
r <- newIORef $ Left $ INTERNAL "skipped in batch of one"
_ <- action r
readIORef r >>= liftEither
oneResult :: AgentMonad m => m [Either AgentErrorType a] -> m a
oneResult action = action >>= \case
[Right res] -> pure res
[Left err] -> throwError err
_ -> throwError $ INTERNAL "non-singleton result"
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
@@ -1953,7 +1936,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
conn
cData@ConnData {userId, connId, duplexHandshake, connAgentVersion, ratchetSyncState = rss} =
withConnLock c connId "processSMP" $ case cmd of
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
handleNotifyAck $ do
msg' <- decryptSMPMessage v rq msg
handleNotifyAck $ case msg' of
+4 -6
View File
@@ -1303,16 +1303,14 @@ withStoreCtx_ ctx_ c action = do
handleInternal :: String -> E.SomeException -> IO (Either StoreError a)
handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr
withStoreBatch :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO (Either StoreError a)]) -> m [Either AgentErrorType a]
withStoreBatch :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO (Either AgentErrorType a)]) -> m [Either AgentErrorType a]
withStoreBatch c actions = do
st <- asks store
rs <-
liftIO $ agentOperationBracket c AODatabase (\_ -> pure ()) $
liftIO $ agentOperationBracket c AODatabase (\_ -> pure ()) $
withTransaction st $ mapM (`E.catch` handleInternal) . actions
pure $ map (first storeError) rs
where
handleInternal :: E.SomeException -> IO (Either StoreError a)
handleInternal = pure . Left . SEInternal . B.pack . show
handleInternal :: E.SomeException -> IO (Either AgentErrorType a)
handleInternal = pure . Left . INTERNAL . show
withStoreBatch' :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO a]) -> m [Either AgentErrorType a]
withStoreBatch' c actions = withStoreBatch c $ map (Right <$>) . actions
+3 -3
View File
@@ -94,15 +94,15 @@ mapME_ f = mapM_ (mapE f)
{-# INLINE mapME_ #-}
mapE :: Monad m => (a -> m (Either e b)) -> Either e a -> m (Either e b)
mapE f = either (pure . Left) f
mapE = either (pure . Left)
{-# INLINE mapE #-}
forME :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m [Either e b]
forME = flip mapME
{-# INLINE forME #-}
forME_ :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m [Either e b]
forME_ = void . flip mapME_
forME_ :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m ()
forME_ f = void . forME f
{-# INLINE forME_ #-}
catchAll :: IO a -> (E.SomeException -> IO a) -> IO a