From ea714c731cc731962778400bb546a48098b8944d Mon Sep 17 00:00:00 2001 From: IC Rainbow Date: Mon, 18 Dec 2023 17:13:47 +0200 Subject: [PATCH] remove IORefs --- src/Simplex/Messaging/Agent.hs | 93 +++++++++++---------------- src/Simplex/Messaging/Agent/Client.hs | 10 ++- src/Simplex/Messaging/Util.hs | 6 +- 3 files changed, 45 insertions(+), 64 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index a61ae26d5..8c2043a81 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -120,6 +120,7 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) +import Data.Either (rights) import Data.Foldable (foldl') import Data.Functor (($>)) import Data.List (find) @@ -127,7 +128,7 @@ import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) import Data.Text (Text) import qualified Data.Text as T import Data.Time.Clock @@ -167,7 +168,6 @@ import Simplex.RemoteControl.Invitation import Simplex.RemoteControl.Types import UnliftIO.Async (async, race_) import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay) -import UnliftIO.IORef import UnliftIO.STM -- import GHC.Conc (unsafeIOToSTM) @@ -873,38 +873,30 @@ getNotificationMessage' c nonce encNtfInfo = do Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} -> msgId == msgId' || msgTs > msgTs' Nothing -> SMP.notification msgFlags -type EIORef a = IORef (Either AgentErrorType a) - -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId -sendMessage' c connId msgFlags msg = - oneResult $ \r -> sendMessagesB c [(r, (connId, msgFlags, msg))] +sendMessage' c connId msgFlags msg = oneResult $ sendMessagesB c [Right (connId, msgFlags, msg)] -- | Send multiple messages to different connections (SEND command) in Reader monad sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] -sendMessages' c msgReqs = do - rs <- replicateM (length msgReqs) (newIORef $ Left $ INTERNAL "skipped in batch") - sendMessagesB c $ zip rs msgReqs - mapM readIORef rs +sendMessages' c msgReqs = sendMessagesB c $ Right <$> msgReqs -sendMessagesB :: forall m. AgentMonad' m => AgentClient -> [(EIORef AgentMsgId, MsgReq)] -> m () +sendMessagesB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType MsgReq] -> m [Either AgentErrorType AgentMsgId] sendMessagesB c reqs = withConnLocks c connIds "sendMessages" $ do - reqs' <- zip reqs <$> withStoreBatch c (\db -> map (getConn db) connIds) - reqs'' <- catMaybes <$> mapM prepareConn reqs' + reqs' <- zipWith (liftA2 (,)) reqs <$> withStoreBatch c (\db -> map (first storeError <$$> getConn db) connIds) + reqs'' <- mapME prepareConn reqs' enqueueMessagesB c reqs'' where - prepareConn :: ((EIORef AgentMsgId, MsgReq), Either AgentErrorType SomeConn) -> m (Maybe (EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))) - prepareConn (req@(r, _), conn_) = case conn_ of - Left e -> Nothing <$ writeIORef r (Left e) - Right (SomeConn _ conn) -> case conn of - DuplexConnection cData _ sqs -> enqueueMsgs cData sqs req - SndConnection cData sq -> enqueueMsgs cData [sq] req - _ -> Nothing <$ writeIORef r (Left $ CONN SIMPLEX) - enqueueMsgs :: ConnData -> NonEmpty SndQueue -> (EIORef AgentMsgId, MsgReq) -> m (Maybe (EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))) - enqueueMsgs cData sqs (r, (_, msgFlags, msg)) - | ratchetSyncSendProhibited cData = Nothing <$ writeIORef r (Left $ CMD PROHIBITED) - | otherwise = pure $ Just (r, (cData, sqs, msgFlags, A_MSG msg)) - connIds = map (\(_, (connId, _, _)) -> connId) reqs + prepareConn :: (MsgReq, SomeConn) -> m (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) + prepareConn (req, SomeConn _ conn) = case conn of + DuplexConnection cData _ sqs -> enqueueMsgs cData sqs req + SndConnection cData sq -> enqueueMsgs cData [sq] req + _ -> pure . Left $ CONN SIMPLEX + enqueueMsgs :: ConnData -> NonEmpty SndQueue -> MsgReq -> m (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) + enqueueMsgs cData sqs (_, msgFlags, msg) + | ratchetSyncSendProhibited cData = pure . Left $ CMD PROHIBITED + | otherwise = pure $ Right (cData, sqs, msgFlags, A_MSG msg) + connIds = map (either (const []) $ \(connId, _, _) -> connId) reqs -- / async command processing v v v @@ -1084,36 +1076,36 @@ enqueueMessages c cData sqs msgFlags aMessage = do enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId enqueueMessages' c cData sqs msgFlags aMessage = - oneResult $ \r -> enqueueMessagesB c [(r, (cData, sqs, msgFlags, aMessage))] + oneResult $ enqueueMessagesB c [Right (cData, sqs, msgFlags, aMessage)] -enqueueMessagesB :: AgentMonad' m => AgentClient -> [(EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))] -> m () -enqueueMessagesB _ [] = pure () -enqueueMessagesB c reqs = enqueueMessageB c reqs >>= enqueueSavedMessageB c +enqueueMessagesB :: AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)] -> m [Either AgentErrorType AgentMsgId] +enqueueMessagesB c reqs = do + reqs' <- enqueueMessageB c reqs + enqueueSavedMessageB c $ mapMaybe snd $ rights reqs' + pure $ fst <$$> reqs' isActiveSndQ :: SndQueue -> Bool isActiveSndQ SndQueue {status} = status == Secured || status == Active enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId -enqueueMessage c cData sq msgFlags aMessage = - oneResult $ \r -> enqueueMessageB c [(r, (cData, [sq], msgFlags, aMessage))] +enqueueMessage c cData sq msgFlags aMessage = fst <$> oneResult (enqueueMessageB c [Right (cData, [sq], msgFlags, aMessage)]) -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage))] -> m [Either AgentErrorType (ConnData, [SndQueue], AgentMsgId)] +enqueueMessageB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)] -> m [Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId))] enqueueMessageB c reqs = do - forME_ reqs $ \(_, (cData, sq :| _, _, _)) -> + void . forME reqs $ \(cData, sq :| _, _, _) -> runExceptT $ resumeMsgDelivery c cData sq - aVRange <- asks $ smpAgentVRange . config + aVRange <- asks $ maxVersion . smpAgentVRange . config mIds <- withStoreBatch c $ \db -> - map (mapE $ storeSentMsg db $ maxVersion aVRange) reqs - forME mIds $ \mId -> do + map (mapE (first storeError <$$> storeSentMsg db aVRange)) reqs + forME (zipWith (liftA2 (,)) reqs mIds) $ \((cData, sq :| sqs, _, _), mId) -> do let InternalId msgId = mId queuePendingMsgs c sq [mId] let sqs' = filter isActiveSndQ sqs - pure $ Right (cData, sqs', msgId) - -- catMaybes <$> mapM processResults (zip reqs mIds) + pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> Version -> (EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> IO (Either StoreError InternalId) - storeSentMsg db agentVersion (_, (ConnData {connId}, sq :| _, msgFlags, aMessage)) = runExceptT $ do + storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either StoreError InternalId) + storeSentMsg db agentVersion (ConnData {connId}, sq :| _, msgFlags, aMessage) = runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash @@ -1127,15 +1119,6 @@ enqueueMessageB c reqs = do liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId pure internalId - processResults :: ((EIORef AgentMsgId, (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)), Either AgentErrorType InternalId) -> m (Maybe (ConnData, [SndQueue], AgentMsgId)) - processResults ((r, (cData, sq :| sqs, _, _)), mId_) = case mId_ of - Left e -> Nothing <$ writeIORef r (Left e) - Right mId -> do - let InternalId msgId = mId - writeIORef r $ Right msgId - queuePendingMsgs c sq [mId] - let sqs' = filter isActiveSndQ sqs - pure $ if null sqs' then Nothing else Just (cData, sqs', msgId) enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c [(cData, [sq], msgId)] @@ -1155,11 +1138,11 @@ enqueueSavedMessageB c reqs = do let mId = InternalId msgId in map (\sq -> createSndMsgDelivery db connId sq mId) sqs -oneResult :: AgentMonad m => (EIORef a -> m b) -> m a -oneResult action = do - r <- newIORef $ Left $ INTERNAL "skipped in batch of one" - _ <- action r - readIORef r >>= liftEither +oneResult :: AgentMonad m => m [Either AgentErrorType a] -> m a +oneResult action = action >>= \case + [Right res] -> pure res + [Left err] -> throwError err + _ -> throwError $ INTERNAL "non-singleton result" resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do @@ -1953,7 +1936,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s conn cData@ConnData {userId, connId, duplexHandshake, connAgentVersion, ratchetSyncState = rss} = withConnLock c connId "processSMP" $ case cmd of - SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> + SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do msg' <- decryptSMPMessage v rq msg handleNotifyAck $ case msg' of diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 987bb52c3..d5135410a 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1303,16 +1303,14 @@ withStoreCtx_ ctx_ c action = do handleInternal :: String -> E.SomeException -> IO (Either StoreError a) handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr -withStoreBatch :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO (Either StoreError a)]) -> m [Either AgentErrorType a] +withStoreBatch :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO (Either AgentErrorType a)]) -> m [Either AgentErrorType a] withStoreBatch c actions = do st <- asks store - rs <- - liftIO $ agentOperationBracket c AODatabase (\_ -> pure ()) $ + liftIO $ agentOperationBracket c AODatabase (\_ -> pure ()) $ withTransaction st $ mapM (`E.catch` handleInternal) . actions - pure $ map (first storeError) rs where - handleInternal :: E.SomeException -> IO (Either StoreError a) - handleInternal = pure . Left . SEInternal . B.pack . show + handleInternal :: E.SomeException -> IO (Either AgentErrorType a) + handleInternal = pure . Left . INTERNAL . show withStoreBatch' :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO a]) -> m [Either AgentErrorType a] withStoreBatch' c actions = withStoreBatch c $ map (Right <$>) . actions diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 830f16c50..0e6f929fb 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -94,15 +94,15 @@ mapME_ f = mapM_ (mapE f) {-# INLINE mapME_ #-} mapE :: Monad m => (a -> m (Either e b)) -> Either e a -> m (Either e b) -mapE f = either (pure . Left) f +mapE = either (pure . Left) {-# INLINE mapE #-} forME :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m [Either e b] forME = flip mapME {-# INLINE forME #-} -forME_ :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m [Either e b] -forME_ = void . flip mapME_ +forME_ :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m () +forME_ f = void . forME f {-# INLINE forME_ #-} catchAll :: IO a -> (E.SomeException -> IO a) -> IO a