From d260a464d621f2db544d436a60753fdd170d41b3 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sun, 3 Jan 2021 10:42:41 +0000 Subject: [PATCH] add error handling, function to process SMP responses --- apps/smp-agent/Main.hs | 4 +- src/Simplex/Messaging/Agent.hs | 105 +++++++++++++------- src/Simplex/Messaging/Agent/Env/SQLite.hs | 3 +- src/Simplex/Messaging/Agent/ServerClient.hs | 31 +++--- src/Simplex/Messaging/Agent/Store.hs | 4 +- src/Simplex/Messaging/Agent/Transmission.hs | 33 +++++- src/Simplex/Messaging/Transport.hs | 4 +- 7 files changed, 119 insertions(+), 65 deletions(-) diff --git a/apps/smp-agent/Main.hs b/apps/smp-agent/Main.hs index e5f33a219..15c076dc4 100644 --- a/apps/smp-agent/Main.hs +++ b/apps/smp-agent/Main.hs @@ -13,10 +13,10 @@ cfg = tbqSize = 16, connIdBytes = 12, dbFile = "smp-agent.db", + smpTcpPort = "5223", smpConfig = ServerClientConfig - { tcpPort = "5223", - tbqSize = 16, + { tbqSize = 16, corrIdBytes = 4 } } diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index dd0096147..d903f1946 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -9,13 +9,13 @@ module Simplex.Messaging.Agent (runSMPAgent) where +import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random -import qualified Data.ByteString.Char8 as B -import Data.Int import qualified Data.Map as M -import Data.Maybe (fromMaybe) +import Data.Maybe +import Network.Socket import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.ServerClient (ServerClient (..), newServerClient) import Simplex.Messaging.Agent.Store @@ -25,16 +25,22 @@ import Simplex.Messaging.Server.Transmission (Cmd (..), CorrId (..), SParty (..) import qualified Simplex.Messaging.Server.Transmission as SMP import Simplex.Messaging.Transport import UnliftIO.Async -import UnliftIO.Exception +import UnliftIO.Exception (Exception, SomeException) +import qualified UnliftIO.Exception as E import UnliftIO.IO import UnliftIO.STM +instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where + withRunInIO inner = ExceptT . E.try $ + withRunInIO $ \run -> + inner (run . (either E.throwIO pure <=< runExceptT)) + runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> m () runSMPAgent cfg@AgentConfig {tcpPort} = do env <- newEnv cfg runReaderT smpAgent env where - smpAgent :: (MonadUnliftIO m, MonadReader Env m) => m () + smpAgent :: (MonadUnliftIO m', MonadReader Env m') => m' () smpAgent = runTCPServer tcpPort $ \h -> do putLn h "Welcome to SMP Agent v0.1" q <- asks $ tbqSize . config @@ -57,38 +63,45 @@ receive h AgentClient {rcvQ, sndQ} = send :: MonadUnliftIO m => Handle -> AgentClient -> m () send h AgentClient {sndQ} = forever $ atomically (readTBQueue sndQ) >>= tPut h -client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () -client AgentClient {rcvQ, sndQ, respQ, servers, commands} = forever $ do +client :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () +client c@AgentClient {rcvQ, sndQ} = forever $ do t@(corrId, cAlias, cmd) <- atomically $ readTBQueue rcvQ - processCommand t cmd >>= \case + runExceptT (processCommand c t cmd) >>= \case Left e -> atomically $ writeTBQueue sndQ (corrId, cAlias, ERR e) Right _ -> return () + +processCommand :: forall m. (MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m) => AgentClient -> ATransmission 'Client -> ACommand 'Client -> m () +processCommand AgentClient {respQ, servers, commands} t = \case + NEW smpServer _ -> do + srv <- getSMPServer smpServer + smpT <- mkSmpNEW + atomically $ writeTBQueue (smpSndQ srv) smpT + return () + _ -> throwError PROHIBITED where - handler :: SomeException -> m (Either StoreError Int64) - handler e = do - liftIO (print e) - return $ Right 1 + replyError :: ErrorType -> SomeException -> m a + replyError err e = do + liftIO . putStrLn $ "Exception: " ++ show e -- TODO remove + throwError err - processCommand :: ATransmission 'Client -> ACommand 'Client -> m (Either ErrorType ()) - processCommand t = \case - NEW server@SMPServer {host, port, keyHash} (AckMode mode) -> do - cfg <- asks $ smpConfig . config - maybeServer <- atomically $ M.lookup (host, fromMaybe "5223" port) <$> readTVar servers - srv <- case maybeServer of - Nothing -> do - conn <- asks db - _serverId <- addServer conn server `catch` handler - newServerClient cfg respQ host port - Just s -> return s - _t <- mkSmpNEW t - atomically $ writeTBQueue (smpSndQ srv) _t - liftIO $ putStrLn "sending NEW to server" - liftIO $ print t - return $ Right () - _ -> return $ Left PROHIBITED + getSMPServer :: SMPServer -> m ServerClient + getSMPServer s@SMPServer {host, port} = do + defPort <- asks $ smpTcpPort . config + let p = fromMaybe defPort port + atomically (M.lookup (host, p) <$> readTVar servers) + >>= maybe (newSMPServer s host p) return - mkSmpNEW :: ATransmission 'Client -> m SMP.Transmission - mkSmpNEW t = do + newSMPServer :: SMPServer -> HostName -> ServiceName -> m ServerClient + newSMPServer s host port = do + cfg <- asks $ smpConfig . config + store <- asks db + _serverId <- addServer store s `E.catch` replyError INTERNAL + srv <- newServerClient cfg respQ host port `E.catch` replyError (BROKER smpErrTCPConnection) + atomically . modifyTVar servers $ M.insert (host, port) srv + return srv + + mkSmpNEW :: m SMP.Transmission + mkSmpNEW = do g <- asks idsDrg smpCorrId <- atomically $ CorrId <$> randomBytes 4 g recipientKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair @@ -100,16 +113,32 @@ client AgentClient {rcvQ, sndQ, respQ, servers, commands} = forever $ do toSMP, state = NEWRequestState {recipientKey, recipientPrivateKey} } - atomically . modifyTVar commands $ M.insert smpCorrId req + atomically . modifyTVar commands $ M.insert smpCorrId req -- TODO check ID collision return toSMP -processSmp :: MonadUnliftIO m => AgentClient -> m () +processSmp :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () processSmp AgentClient {respQ, sndQ, commands} = forever $ do (_, (smpCorrId, qId, cmdOrErr)) <- atomically $ readTBQueue respQ - liftIO $ putStrLn "received from server" + liftIO $ putStrLn "received from server" -- TODO remove liftIO $ print (smpCorrId, qId, cmdOrErr) req <- atomically $ M.lookup smpCorrId <$> readTVar commands - atomically $ case req of -- TODO empty correlation ID is ok - it can be a message - Nothing -> writeTBQueue sndQ ("", "", ERR $ SMP smpErrCorrelationId) - Just Request {fromClient = (corrId, cAlias, cmd), toSMP, state} -> do - writeTBQueue sndQ (corrId, cAlias, ERR UNKNOWN) + case req of -- TODO empty correlation ID is ok - it can be a message + Nothing -> atomically $ writeTBQueue sndQ ("", "", ERR $ BROKER smpErrCorrelationId) + Just r -> processResponse r cmdOrErr + where + processResponse :: Request -> Either SMP.ErrorType SMP.Cmd -> m () + processResponse Request {fromClient = (corrId, cAlias, cmd), toSMP = (_, (_, _, smpCmd)), state} cmdOrErr = do + case cmdOrErr of + Left e -> respond $ ERR (SMP e) + Right resp -> case resp of + Cmd SBroker (SMP.IDS recipientId senderId) -> case smpCmd of + Cmd SRecipient (SMP.NEW _) -> case (cmd, state) of + (NEW _ _, NEWRequestState {recipientKey, recipientPrivateKey}) -> do + -- TODO all good - process response + respond $ ERR UNKNOWN + _ -> respond $ ERR INTERNAL + _ -> respond $ ERR (BROKER smpUnexpectedResponse) + _ -> respond $ ERR UNSUPPORTED + where + respond :: ACommand 'Agent -> m () + respond c = atomically $ writeTBQueue sndQ (corrId, cAlias, c) diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index ffc922cdc..ad64aa181 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -14,18 +14,19 @@ import Network.Socket (HostName, ServiceName) import Numeric.Natural import Simplex.Messaging.Agent.ServerClient import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Agent.Store.SQLite import Simplex.Messaging.Agent.Store.SQLite.Schema import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Server.Transmission (PublicKey) import qualified Simplex.Messaging.Server.Transmission as SMP import UnliftIO.STM -import Simplex.Messaging.Agent.Store.SQLite data AgentConfig = AgentConfig { tcpPort :: ServiceName, tbqSize :: Natural, connIdBytes :: Int, dbFile :: String, + smpTcpPort :: ServiceName, smpConfig :: ServerClientConfig } diff --git a/src/Simplex/Messaging/Agent/ServerClient.hs b/src/Simplex/Messaging/Agent/ServerClient.hs index 13178776f..bf358a5a2 100644 --- a/src/Simplex/Messaging/Agent/ServerClient.hs +++ b/src/Simplex/Messaging/Agent/ServerClient.hs @@ -5,10 +5,8 @@ module Simplex.Messaging.Agent.ServerClient where import Control.Monad import Control.Monad.IO.Unlift -import Data.Maybe import Network.Socket (HostName, ServiceName) import Numeric.Natural -import Simplex.Messaging.Agent.Store import qualified Simplex.Messaging.Server.Transmission as SMP import Simplex.Messaging.Transport import UnliftIO.Async @@ -16,8 +14,7 @@ import UnliftIO.IO import UnliftIO.STM data ServerClientConfig = ServerClientConfig - { tcpPort :: ServiceName, - tbqSize :: Natural, + { tbqSize :: Natural, corrIdBytes :: Natural } @@ -33,26 +30,26 @@ newServerClient :: ServerClientConfig -> TBQueue SMP.TransmissionOrError -> HostName -> - Maybe ServiceName -> + ServiceName -> m ServerClient newServerClient cfg smpRcvQ host port = do smpSndQ <- atomically . newTBQueue $ tbqSize cfg let c = ServerClient {smpSndQ, smpRcvQ} - _srvA <- async $ runClient (fromMaybe (tcpPort cfg) port) c + _srvA <- async $ runTCPClient host p (client c) + -- TODO because exception can be thrown inside async it is not caught by newSMPServer + -- there possibly needs to be another channel to communicate with ServerClient if it fails + -- alternatively, there may be just timeout on sent commands - + -- in this case late responses should be just ignored rather than result in smpErrCorrelationId return c where - runClient :: ServiceName -> ServerClient -> m () - runClient p c = do - liftIO $ print (host, p) - runTCPClient host p $ \h -> do - liftIO $ putStrLn "SMP connected" - _line <- getLn h -- "Welcome to SMP" - liftIO $ print _line - -- TODO test connection failure - race_ (send h c) (receive h) + client :: ServerClient -> Handle -> m () + client c h = do + _line <- getLn h -- "Welcome to SMP" + -- TODO test connection failure + send c h `race_` receive h - send :: Handle -> ServerClient -> m () - send h ServerClient {smpSndQ} = forever $ atomically (readTBQueue smpSndQ) >>= SMP.tPut h + send :: ServerClient -> Handle -> m () + send ServerClient {smpSndQ} h = forever $ atomically (readTBQueue smpSndQ) >>= SMP.tPut h receive :: Handle -> m () receive h = forever $ SMP.tGet SMP.fromServer h >>= atomically . writeTBQueue smpRcvQ diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 2eca846c3..67bfb5fa8 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -8,11 +8,11 @@ module Simplex.Messaging.Agent.Store where +import Data.Int (Int64) import Data.Kind import Data.Time.Clock (UTCTime) import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Server.Transmission (Encoded, PublicKey, QueueId) -import Data.Int (Int64) data ReceiveQueue = ReceiveQueue { server :: SMPServer, @@ -71,7 +71,7 @@ data DeliveryStatus type SMPServerId = Int64 -class MonadAgentStore s m where +class Monad m => MonadAgentStore s m where addServer :: s -> SMPServer -> m (Either StoreError SMPServerId) createRcvConn :: s -> Maybe ConnAlias -> ReceiveQueue -> m (Either StoreError (Connection CReceive)) createSndConn :: s -> Maybe ConnAlias -> SendQueue -> m (Either StoreError (Connection CSend)) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 84dcb226a..22176de75 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} @@ -23,10 +24,20 @@ import Data.Type.Equality import Data.Typeable () import Network.Socket import Numeric.Natural -import Simplex.Messaging.Server.Transmission (CorrId (..), Encoded, MsgBody, PublicKey, QueueId, errBadParameters, errMessageBody) +import Simplex.Messaging.Server.Transmission + ( CorrId (..), + Encoded, + MsgBody, + PublicKey, + QueueId, + errBadParameters, + errMessageBody, + ) +import qualified Simplex.Messaging.Server.Transmission as SMP import Simplex.Messaging.Transport import System.IO import Text.Read +import UnliftIO.Exception type ARawTransmission = (ByteString, ByteString, ByteString) @@ -123,8 +134,16 @@ data MsgStatus = MsgOk | MsgError MsgErrorType data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash deriving (Show) -data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SMP Natural | SIZE -- etc. TODO SYNTAX Natural - deriving (Show) +data ErrorType + = UNKNOWN + | UNSUPPORTED -- TODO remove once all commands implemented + | PROHIBITED + | SYNTAX Int + | BROKER Natural + | SMP SMP.ErrorType + | SIZE + | INTERNAL -- etc. TODO SYNTAX Natural + deriving (Show, Exception) data AckStatus = AckOk | AckError AckErrorType deriving (Show) @@ -138,8 +157,14 @@ errBadInvitation = 10 errNoConnAlias :: Int errNoConnAlias = 11 +smpErrTCPConnection :: Natural +smpErrTCPConnection = 1 + smpErrCorrelationId :: Natural -smpErrCorrelationId = 1 +smpErrCorrelationId = 2 + +smpUnexpectedResponse :: Natural +smpUnexpectedResponse = 3 parseCommand :: ByteString -> Either ErrorType ACmd parseCommand command = case B.words command of diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 5f7ffde0f..b4ef2ce09 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -69,7 +69,9 @@ startTCPClient host port = getSocketHandle sock runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a -runTCPClient host port = E.bracket (startTCPClient host port) IO.hClose +runTCPClient host port client = do + h <- startTCPClient host port + client h `E.finally` IO.hClose h getSocketHandle :: MonadIO m => Socket -> m Handle getSocketHandle conn = liftIO $ do