refactor store: m (Either e a) => ExceptT e m a

This commit is contained in:
Evgeny Poberezkin
2021-01-09 11:18:52 +00:00
parent 75f58f8ba4
commit 0d0a12f778
6 changed files with 175 additions and 145 deletions

View File

@@ -5,6 +5,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Agent (runSMPAgent) where
@@ -30,11 +31,6 @@ import qualified UnliftIO.Exception as E
import UnliftIO.IO
import UnliftIO.STM
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
withRunInIO inner = ExceptT . E.try $
withRunInIO $ \run ->
inner (run . (either E.throwIO pure <=< runExceptT))
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> m ()
runSMPAgent cfg@AgentConfig {tcpPort} = do
env <- newEnv cfg
@@ -70,6 +66,15 @@ client c@AgentClient {rcvQ, sndQ} = forever $ do
Left e -> atomically $ writeTBQueue sndQ (corrId, cAlias, ERR e)
Right _ -> return ()
withStore ::
(MonadUnliftIO m, MonadError ErrorType m) =>
(forall n. (MonadUnliftIO n, MonadError StoreError n) => n a) ->
m a
withStore action =
runExceptT action >>= \case
Left _ -> throwError INTERNAL
Right c -> return c
processCommand :: forall m. (MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m) => AgentClient -> ATransmission 'Client -> ACommand 'Client -> m ()
processCommand AgentClient {respQ, servers, commands} t = \case
NEW smpServer _ -> do
@@ -95,7 +100,7 @@ processCommand AgentClient {respQ, servers, commands} t = \case
newSMPServer s host port = do
cfg <- asks $ smpConfig . config
store <- asks db
_serverId <- addServer store s `E.catch` replyError INTERNAL
_serverId <- withStore (addServer store s) `E.catch` replyError INTERNAL
srv <- newServerClient cfg respQ host port `E.catch` replyError (BROKER smpErrTCPConnection)
atomically . modifyTVar servers $ M.insert (host, port) srv
return srv

View File

@@ -1,5 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@@ -8,6 +10,7 @@
module Simplex.Messaging.Agent.Store where
import Control.Exception
import Data.Int (Int64)
import Data.Kind
import Data.Time.Clock (UTCTime)
@@ -94,20 +97,20 @@ data DeliveryStatus
type SMPServerId = Int64
class Monad m => MonadAgentStore s m where
addServer :: s -> SMPServer -> m (Either StoreError SMPServerId)
createRcvConn :: s -> ConnAlias -> ReceiveQueue -> m (Either StoreError ())
createSndConn :: s -> ConnAlias -> SendQueue -> m (Either StoreError ())
getConn :: s -> ConnAlias -> m (Either StoreError SomeConn)
deleteConn :: s -> ConnAlias -> m (Either StoreError ())
addSndQueue :: s -> ConnAlias -> SendQueue -> m (Either StoreError ())
addRcvQueue :: s -> ConnAlias -> ReceiveQueue -> m (Either StoreError ())
removeSndAuth :: s -> ConnAlias -> m (Either StoreError ())
updateQueueStatus :: s -> ConnAlias -> QueueDirection -> QueueStatus -> m (Either StoreError ())
createMsg :: s -> ConnAlias -> QueueDirection -> AMessage -> m (Either StoreError MessageDelivery)
getLastMsg :: s -> ConnAlias -> QueueDirection -> m (Either StoreError MessageDelivery)
getMsg :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m (Either StoreError MessageDelivery)
updateMsgStatus :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m (Either StoreError ())
deleteMsg :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m (Either StoreError ())
addServer :: s -> SMPServer -> m SMPServerId
createRcvConn :: s -> ConnAlias -> ReceiveQueue -> m ()
createSndConn :: s -> ConnAlias -> SendQueue -> m ()
getConn :: s -> ConnAlias -> m SomeConn
deleteConn :: s -> ConnAlias -> m ()
addSndQueue :: s -> ConnAlias -> SendQueue -> m ()
addRcvQueue :: s -> ConnAlias -> ReceiveQueue -> m ()
removeSndAuth :: s -> ConnAlias -> m ()
updateQueueStatus :: s -> ConnAlias -> QueueDirection -> QueueStatus -> m ()
createMsg :: s -> ConnAlias -> QueueDirection -> AMessage -> m MessageDelivery
getLastMsg :: s -> ConnAlias -> QueueDirection -> m MessageDelivery
getMsg :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m MessageDelivery
updateMsgStatus :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m ()
deleteMsg :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m ()
data StoreError
= SEInternal
@@ -115,4 +118,4 @@ data StoreError
| SEBadConn
| SEBadConnType ConnType
| SEBadQueueStatus
deriving (Eq, Show)
deriving (Eq, Show, Exception)

View File

@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
@@ -10,10 +11,12 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
module Simplex.Messaging.Agent.Store.SQLite where
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Data.Int (Int64)
import qualified Data.Text as T
@@ -108,37 +111,38 @@ instance ToRow SMPServer where
instance FromRow SMPServer where
fromRow = SMPServer <$> field <*> field <*> field
upsertServer :: MonadUnliftIO m => SQLiteStore -> SMPServer -> m (Either StoreError SMPServerId)
upsertServer SQLiteStore {conn} srv@SMPServer {host, port} = liftIO $ do
DB.execute
conn
[s|
upsertServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServer -> m SMPServerId
upsertServer SQLiteStore {conn} srv@SMPServer {host, port} = do
r <- liftIO $ do
DB.execute
conn
[s|
INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?)
ON CONFLICT (host, port) DO UPDATE SET
host=excluded.host,
port=excluded.port,
key_hash=excluded.key_hash;
|]
srv
r <-
srv
DB.queryNamed
conn
"SELECT server_id FROM servers WHERE host = :host AND port = :port"
[":host" := host, ":port" := port]
return $ case r of
[Only serverId] -> Right serverId
_ -> Left SEInternal
case r of
[Only serverId] -> return serverId
_ -> throwError SEInternal
getServer :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> m (Either StoreError SMPServer)
getServer SQLiteStore {conn} serverId = liftIO $ do
getServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServerId -> m SMPServer
getServer SQLiteStore {conn} serverId = do
r <-
DB.queryNamed
conn
"SELECT host, port, key_hash FROM servers WHERE server_id = :server_id"
[":server_id" := serverId]
return $ case r of
[smpServer] -> Right smpServer
_ -> Left SENotFound
liftIO $
DB.queryNamed
conn
"SELECT host, port, key_hash FROM servers WHERE server_id = :server_id"
[":server_id" := serverId]
case r of
[smpServer] -> return smpServer
_ -> throwError SENotFound
instance ToField AckMode where toField (AckMode mode) = toField $ show mode
@@ -156,38 +160,40 @@ instance FromRow ReceiveQueue where
fromRow = ReceiveQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field <*> field <*> field
-- TODO refactor into a single query with join
getRcvQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m (Either StoreError ReceiveQueue)
getRcvQueue st@SQLiteStore {conn} queueRowId = liftIO $ do
getRcvQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m ReceiveQueue
getRcvQueue st@SQLiteStore {conn} queueRowId = do
r <-
DB.queryNamed
conn
[s|
liftIO $
DB.queryNamed
conn
[s|
SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode
FROM receive_queues
WHERE receive_queue_id = :rowId;
|]
[":rowId" := queueRowId]
[":rowId" := queueRowId]
case r of
[Only serverId :. rcvQueue] ->
(\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$$> getServer st serverId
_ -> return (Left SENotFound)
(\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId
_ -> throwError SENotFound
-- TODO refactor into a single query with join
getSndQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m (Either StoreError SendQueue)
getSndQueue st@SQLiteStore {conn} queueRowId = liftIO $ do
getSndQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m SendQueue
getSndQueue st@SQLiteStore {conn} queueRowId = do
r <-
DB.queryNamed
conn
[s|
liftIO $
DB.queryNamed
conn
[s|
SELECT server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode
FROM send_queues
WHERE send_queue_id = :rowId;
|]
[":rowId" := queueRowId]
[":rowId" := queueRowId]
case r of
[Only serverId :. sndQueue] ->
(\srv -> (sndQueue {server = srv} :: SendQueue)) <$$> getServer st serverId
_ -> return (Left SENotFound)
(\srv -> (sndQueue {server = srv} :: SendQueue)) <$> getServer st serverId
_ -> throwError SENotFound
insertRcvQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> ReceiveQueue -> m QueueRowId
insertRcvQueue store serverId rcvQueue =
@@ -262,16 +268,17 @@ updateSndConnectionWithRcvQueue store connAlias rcvQueueId =
|]
(Only rcvQueueId :. Only connAlias)
getConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> m (Either StoreError (Maybe QueueRowId, Maybe QueueRowId))
getConnection SQLiteStore {conn} connAlias = liftIO $ do
getConnection :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> ConnAlias -> m (Maybe QueueRowId, Maybe QueueRowId)
getConnection SQLiteStore {conn} connAlias = do
r <-
DB.queryNamed
conn
"SELECT receive_queue_id, send_queue_id FROM connections WHERE conn_alias = :conn_alias"
[":conn_alias" := connAlias]
return $ case r of
[queueIds] -> Right queueIds
_ -> Left SEInternal
liftIO $
DB.queryNamed
conn
"SELECT receive_queue_id, send_queue_id FROM connections WHERE conn_alias = :conn_alias"
[":conn_alias" := connAlias]
case r of
[queueIds] -> return queueIds
_ -> throwError SEInternal
deleteRcvQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m ()
deleteRcvQueue store rcvQueueId = do
@@ -297,23 +304,23 @@ deleteConnection store connAlias = do
"DELETE FROM connections WHERE conn_alias = ?"
(Only connAlias)
instance MonadUnliftIO m => MonadAgentStore SQLiteStore m where
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
addServer store smpServer = upsertServer store smpServer
createRcvConn :: SQLiteStore -> ConnAlias -> ReceiveQueue -> m (Either StoreError ())
createRcvConn :: SQLiteStore -> ConnAlias -> ReceiveQueue -> m ()
createRcvConn st connAlias rcvQueue =
-- TODO test for duplicate connAlias
upsertServer st (server (rcvQueue :: ReceiveQueue))
>>= either (return . Left) (fmap Right . addConnection)
>>= addConnection
where
addConnection serverId =
-- TODO test for duplicate connAlias
insertRcvQueue st serverId rcvQueue
>>= insertRcvConnection st connAlias
createSndConn :: SQLiteStore -> ConnAlias -> SendQueue -> m (Either StoreError ())
createSndConn :: SQLiteStore -> ConnAlias -> SendQueue -> m ()
createSndConn st connAlias sndQueue =
upsertServer st (server (sndQueue :: SendQueue))
>>= either (return . Left) (fmap Right . addConnection)
>>= addConnection
where
addConnection serverId =
-- TODO test for duplicate connAlias
@@ -321,53 +328,46 @@ instance MonadUnliftIO m => MonadAgentStore SQLiteStore m where
>>= insertSndConnection st connAlias
-- TODO refactor ito a single query with join, and parse as `Only connAlias :. rcvQueue :. sndQueue`
getConn :: SQLiteStore -> ConnAlias -> m (Either StoreError SomeConn)
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
getConn st connAlias =
getConnection st connAlias >>= \case
Left e -> return $ Left e
Right (Just rcvQId, Just sndQId) -> do
(Just rcvQId, Just sndQId) -> do
rcvQ <- getRcvQueue st rcvQId
sndQ <- getSndQueue st sndQId
return $ SomeConn SCDuplex <$> (DuplexConnection connAlias <$> rcvQ <*> sndQ)
Right (Just rcvQId, _) ->
fmap (SomeConn SCReceive . ReceiveConnection connAlias) <$> getRcvQueue st rcvQId
Right (_, Just sndQId) ->
fmap (SomeConn SCSend . SendConnection connAlias) <$> getSndQueue st sndQId
Right (_, _) -> return $ Left SEBadConn
return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
(Just rcvQId, _) ->
SomeConn SCReceive . ReceiveConnection connAlias <$> getRcvQueue st rcvQId
(_, Just sndQId) ->
SomeConn SCSend . SendConnection connAlias <$> getSndQueue st sndQId
(_, _) -> throwError SEBadConn
-- TODO make transactional
addSndQueue :: SQLiteStore -> ConnAlias -> SendQueue -> m (Either StoreError ())
addSndQueue :: SQLiteStore -> ConnAlias -> SendQueue -> m ()
addSndQueue st connAlias sndQueue =
getConn st connAlias
>>= either (return . Left) checkUpdateConn
where
checkUpdateConn :: SomeConn -> m (Either StoreError ())
checkUpdateConn = \case
SomeConn SCDuplex _ -> return $ Left (SEBadConnType CDuplex)
SomeConn SCSend _ -> return $ Left (SEBadConnType CSend)
>>= \case
SomeConn SCDuplex _ -> throwError (SEBadConnType CDuplex)
SomeConn SCSend _ -> throwError (SEBadConnType CSend)
SomeConn SCReceive _ ->
upsertServer st (server (sndQueue :: SendQueue))
>>= either (return . Left) (fmap Right . updateConn)
>>= updateConn
where
updateConn :: SMPServerId -> m ()
updateConn servId =
insertSndQueue st servId sndQueue
>>= updateRcvConnectionWithSndQueue st connAlias
-- TODO make transactional
addRcvQueue :: SQLiteStore -> ConnAlias -> ReceiveQueue -> m (Either StoreError ())
addRcvQueue :: SQLiteStore -> ConnAlias -> ReceiveQueue -> m ()
addRcvQueue st connAlias rcvQueue =
getConn st connAlias
>>= either (return . Left) checkUpdateConn
where
checkUpdateConn :: SomeConn -> m (Either StoreError ())
checkUpdateConn = \case
SomeConn SCDuplex _ -> return $ Left (SEBadConnType CDuplex)
SomeConn SCReceive _ -> return $ Left (SEBadConnType CReceive)
>>= \case
SomeConn SCDuplex _ -> throwError (SEBadConnType CDuplex)
SomeConn SCReceive _ -> throwError (SEBadConnType CReceive)
SomeConn SCSend _ ->
upsertServer st (server (rcvQueue :: ReceiveQueue))
>>= either (return . Left) (fmap Right . updateConn)
>>= updateConn
where
updateConn :: SMPServerId -> m ()
updateConn servId =
insertRcvQueue st servId rcvQueue
@@ -380,18 +380,17 @@ instance MonadUnliftIO m => MonadAgentStore SQLiteStore m where
-- * Enables cascade deletes
-- ? See https://sqlite.org/foreignkeys.html#fk_deferred
-- - Keep as is and just wrap in transaction?
deleteConn :: SQLiteStore -> ConnAlias -> m (Either StoreError ())
deleteConn :: SQLiteStore -> ConnAlias -> m ()
deleteConn st connAlias =
getConnection st connAlias >>= \case
Left e -> return $ Left e
Right (Just rcvQId, Just sndQId) -> do
(Just rcvQId, Just sndQId) -> do
deleteRcvQueue st rcvQId
deleteSndQueue st sndQId
Right <$> deleteConnection st connAlias
Right (Just rcvQId, _) -> do
deleteConnection st connAlias
(Just rcvQId, _) -> do
deleteRcvQueue st rcvQId
Right <$> deleteConnection st connAlias
Right (_, Just sndQId) -> do
deleteConnection st connAlias
(_, Just sndQId) -> do
deleteSndQueue st sndQId
Right <$> deleteConnection st connAlias
Right (_, _) -> return $ Left SEBadConn
deleteConnection st connAlias
(_, _) -> throwError SEBadConn

View File

@@ -13,7 +13,6 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import GHC.IO.Exception (IOErrorType (..))
import Network.Socket
import Simplex.Messaging.Util
import System.IO
import System.IO.Error
import UnliftIO.Concurrent

View File

@@ -1,8 +1,15 @@
module Simplex.Messaging.Util where
import Control.Monad (void)
import Control.Monad.Except
import Control.Monad.IO.Unlift
import UnliftIO.Async
import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
withRunInIO inner = ExceptT . E.try $
withRunInIO $ \run ->
inner (run . (either E.throwIO pure <=< runExceptT))
raceAny_ :: MonadUnliftIO m => [m a] -> m ()
raceAny_ = r []