diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 57d37d17b..a90ae48a3 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -121,8 +121,9 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) import Data.Either (rights) -import Data.Foldable (foldl') +import Data.Foldable (foldl', toList) import Data.Functor (($>)) +import Data.Functor.Identity import Data.List (find) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L @@ -875,28 +876,29 @@ getNotificationMessage' c nonce encNtfInfo = do -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId -sendMessage' c connId msgFlags msg = oneResult $ sendMessagesB c [Right (connId, msgFlags, msg)] +sendMessage' c connId msgFlags msg = liftEither . runIdentity =<< sendMessagesB c (Identity (Right (connId, msgFlags, msg))) -- | Send multiple messages to different connections (SEND command) in Reader monad sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] -sendMessages' c msgReqs = sendMessagesB c $ Right <$> msgReqs +sendMessages' c = sendMessagesB c . map Right -sendMessagesB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType MsgReq] -> m [Either AgentErrorType AgentMsgId] +sendMessagesB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType AgentMsgId)) sendMessagesB c reqs = withConnLocks c connIds "sendMessages" $ do - reqs' <- withStoreBatch c (\db -> map (mapE $ \req@(connId, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) - reqs'' <- mapME prepareConn reqs' + reqs' <- withStoreBatch c (\db -> fmap (mapE $ \req@(connId, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) + let reqs'' = fmap (>>= prepareConn) reqs' enqueueMessagesB c reqs'' where - prepareConn :: (MsgReq, SomeConn) -> m (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) - prepareConn (req, SomeConn _ conn) = case conn of - DuplexConnection cData _ sqs -> enqueueMsgs cData sqs req - SndConnection cData sq -> enqueueMsgs cData [sq] req - _ -> pure . Left $ CONN SIMPLEX - enqueueMsgs :: ConnData -> NonEmpty SndQueue -> MsgReq -> m (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) - enqueueMsgs cData sqs (_, msgFlags, msg) - | ratchetSyncSendProhibited cData = pure . Left $ CMD PROHIBITED - | otherwise = pure $ Right (cData, sqs, msgFlags, A_MSG msg) - connIds = map (either (const []) $ \(connId, _, _) -> connId) reqs + prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) + prepareConn ((_, msgFlags, msg), SomeConn _ conn) = case conn of + DuplexConnection cData _ sqs -> prepareMsgs cData sqs + SndConnection cData sq -> prepareMsgs cData [sq] + _ -> Left $ CONN SIMPLEX + where + prepareMsgs :: ConnData -> NonEmpty SndQueue -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) + prepareMsgs cData sqs + | ratchetSyncSendProhibited cData = Left $ CMD PROHIBITED + | otherwise = Right (cData, sqs, msgFlags, A_MSG msg) + connIds = map (\(connId, _, _) -> connId) $ rights $ toList reqs -- / async command processing v v v @@ -1076,27 +1078,28 @@ enqueueMessages c cData sqs msgFlags aMessage = do enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId enqueueMessages' c cData sqs msgFlags aMessage = - oneResult $ enqueueMessagesB c [Right (cData, sqs, msgFlags, aMessage)] + liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, msgFlags, aMessage))) -enqueueMessagesB :: AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)] -> m [Either AgentErrorType AgentMsgId] +enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType AgentMsgId)) enqueueMessagesB c reqs = do reqs' <- enqueueMessageB c reqs - enqueueSavedMessageB c $ mapMaybe snd $ rights reqs' + enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs' pure $ fst <$$> reqs' isActiveSndQ :: SndQueue -> Bool isActiveSndQ SndQueue {status} = status == Secured || status == Active enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId -enqueueMessage c cData sq msgFlags aMessage = fst <$> oneResult (enqueueMessageB c [Right (cData, [sq], msgFlags, aMessage)]) +enqueueMessage c cData sq msgFlags aMessage = + liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], msgFlags, aMessage))) -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall m. AgentMonad' m => AgentClient -> [Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)] -> m [Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId))] +enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId)))) enqueueMessageB c reqs = do void . forME reqs $ \(cData, sq :| _, _, _) -> runExceptT $ resumeMsgDelivery c cData sq aVRange <- asks $ maxVersion . smpAgentVRange . config - reqMids <- withStoreBatch c $ \db -> map (mapE $ storeSentMsg db aVRange) reqs + reqMids <- withStoreBatch c $ \db -> fmap (mapE $ storeSentMsg db aVRange) reqs forME reqMids $ \((cData, sq :| sqs, _, _), mId) -> do let InternalId msgId = mId queuePendingMsgs c sq [mId] @@ -1104,7 +1107,7 @@ enqueueMessageB c reqs = do pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId)) where storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId)) - storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = bimap storeError (req,) <$$> runExceptT $ do + storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash @@ -1117,12 +1120,12 @@ enqueueMessageB c reqs = do msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId - pure internalId + pure (req, internalId) enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () -enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c [(cData, [sq], msgId)] +enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId) -enqueueSavedMessageB :: AgentMonad' m => AgentClient -> [(ConnData, [SndQueue], AgentMsgId)] -> m () +enqueueSavedMessageB :: (AgentMonad' m, Foldable t) => AgentClient -> t (ConnData, [SndQueue], AgentMsgId) -> m () enqueueSavedMessageB c reqs = do -- saving to the database moved to the start to avoid race conditions when delivery is read from queue before it is saved void $ withStoreBatch' c $ \db -> concatMap (storeDeliveries db) reqs @@ -1137,12 +1140,6 @@ enqueueSavedMessageB c reqs = do let mId = InternalId msgId in map (\sq -> createSndMsgDelivery db connId sq mId) sqs -oneResult :: AgentMonad m => m [Either AgentErrorType a] -> m a -oneResult action = action >>= \case - [Right res] -> pure res - [Left err] -> throwError err - _ -> throwError $ INTERNAL "non-singleton result" - resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do let qKey = (server, sndId) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index d5135410a..0ac65d466 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1303,7 +1303,7 @@ withStoreCtx_ ctx_ c action = do handleInternal :: String -> E.SomeException -> IO (Either StoreError a) handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr -withStoreBatch :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO (Either AgentErrorType a)]) -> m [Either AgentErrorType a] +withStoreBatch :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> m (t (Either AgentErrorType a)) withStoreBatch c actions = do st <- asks store liftIO $ agentOperationBracket c AODatabase (\_ -> pure ()) $ @@ -1312,8 +1312,8 @@ withStoreBatch c actions = do handleInternal :: E.SomeException -> IO (Either AgentErrorType a) handleInternal = pure . Left . INTERNAL . show -withStoreBatch' :: AgentMonad' m => AgentClient -> (DB.Connection -> [IO a]) -> m [Either AgentErrorType a] -withStoreBatch' c actions = withStoreBatch c $ map (Right <$>) . actions +withStoreBatch' :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO a)) -> m (t (Either AgentErrorType a)) +withStoreBatch' c actions = withStoreBatch c $ \db -> fmap Right <$> actions db storeError :: StoreError -> AgentErrorType storeError = \case diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 0e6f929fb..34c8cf152 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -85,11 +85,11 @@ unlessM b = ifM b $ pure () ($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b) f $>>= g = f >>= fmap join . mapM g -mapME :: Monad m => (a -> m (Either e b)) -> [Either e a] -> m [Either e b] +mapME :: (Monad m, Traversable t) => (a -> m (Either e b)) -> t (Either e a) -> m (t (Either e b)) mapME f = mapM (mapE f) {-# INLINE mapME #-} -mapME_ :: Monad m => (a -> m (Either e b)) -> [Either e a] -> m () +mapME_ :: (Monad m, Traversable t) => (a -> m (Either e b)) -> t (Either e a) -> m () mapME_ f = mapM_ (mapE f) {-# INLINE mapME_ #-} @@ -97,11 +97,11 @@ mapE :: Monad m => (a -> m (Either e b)) -> Either e a -> m (Either e b) mapE = either (pure . Left) {-# INLINE mapE #-} -forME :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m [Either e b] +forME :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) -> m (t (Either e b)) forME = flip mapME {-# INLINE forME #-} -forME_ :: Monad m => [Either e a] -> (a -> m (Either e b)) -> m () +forME_ :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) -> m () forME_ f = void . forME f {-# INLINE forME_ #-}