diff --git a/package.yaml b/package.yaml index f2fd486ec..722d647dc 100644 --- a/package.yaml +++ b/package.yaml @@ -24,11 +24,13 @@ dependencies: - cryptonite == 0.26.* - directory == 1.3.* - filepath == 1.4.* + - generic-random == 1.3.* - iso8601-time == 0.1.* - memory == 0.15.* - mtl - network == 3.1.* - network-transport == 0.5.* + - QuickCheck == 2.13.* - simple-logger == 0.1.* - sqlite-simple == 0.4.* - stm @@ -82,6 +84,7 @@ tests: - hspec-core == 2.7.* - HUnit == 1.6.* - random == 1.1.* + - QuickCheck == 2.13.* ghc-options: # - -haddock diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 6f952fd35..59ca26015 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -114,10 +114,15 @@ withStore :: withStore action = do runExceptT (action `E.catch` handleInternal) >>= \case Right c -> return c - Left _ -> throwError STORE + Left e -> throwError $ storeError e where handleInternal :: (MonadError StoreError m') => SomeException -> m' a - handleInternal _ = throwError SEInternal + handleInternal e = throwError . SEInternal $ bshow e + storeError :: StoreError -> AgentErrorType + storeError = \case + SEConnNotFound -> CONN UNKNOWN + SEConnDuplicate -> CONN DUPLICATE + e -> INTERNAL $ show e processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m () processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = @@ -156,9 +161,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st cAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> subscribe rq SomeConn _ (RcvConnection _ rq) -> subscribe rq - -- TODO possibly there should be a separate error type trying - -- TODO to send the message to the connection without RcvQueue - _ -> throwError PROHIBITED + _ -> throwError $ CONN SIMPLEX where subscribe rq = subscribeQueue c rq cAlias >> respond' cAlias OK @@ -171,9 +174,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq SomeConn _ (SndConnection _ sq) -> sendMsg sq - -- TODO possibly there should be a separate error type trying - -- TODO to send the message to the connection without SndQueue - _ -> throwError PROHIBITED -- NOT_READY ? + _ -> throwError $ CONN SIMPLEX where sendMsg sq = do senderTs <- liftIO getCurrentTime @@ -186,7 +187,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> suspend rq SomeConn _ (RcvConnection _ rq) -> suspend rq - _ -> throwError PROHIBITED + _ -> throwError $ CONN SIMPLEX where suspend rq = suspendQueue c rq >> respond OK @@ -195,13 +196,13 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> delete rq SomeConn _ (RcvConnection _ rq) -> delete rq - _ -> throwError PROHIBITED + _ -> delConn where + delConn = withStore (deleteConn st connAlias) >> respond OK delete rq = do deleteQueue c rq removeSubscription c connAlias - withStore (deleteConn st connAlias) - respond OK + delConn sendReplyQInfo :: SMPServer -> SndQueue -> m () sendReplyQInfo srv sq = do @@ -242,17 +243,14 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do -- TODO update sender key in the store? secureQueue c rq senderKey withStore $ setRcvQueueStatus st rq Secured - sendAck c rq - s -> - -- TODO maybe send notification to the user - liftIO . putStrLn $ "unexpected SMP confirmation, queue status " <> show s + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED SMPMessage {agentMessage, senderMsgId, senderTimestamp} -> case agentMessage of HELLO _verifyKey _ -> do logServer "<--" c srv rId "MSG " - -- TODO send status update to the user? - withStore $ setRcvQueueStatus st rq Active - sendAck c rq + case status of + Active -> notify connAlias . ERR $ AGENT A_PROHIBITED + _ -> withStore $ setRcvQueueStatus st rq Active REPLY qInfo -> do logServer "<--" c srv rId "MSG " -- TODO move senderKey inside SndQueue @@ -260,29 +258,33 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do withStore $ upgradeRcvConnToDuplex st connAlias sq connectToSendQueue c st sq senderKey verifyKey notify connAlias CON - sendAck c rq A_MSG body -> do - logServer "<--" c srv rId "MSG " -- TODO check message status - recipientTs <- liftIO getCurrentTime - let m_sender = (senderMsgId, senderTimestamp) - let m_broker = (srvMsgId, srvTs) - recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker - notify connAlias $ - MSG - { m_status = MsgOk, - m_recipient = (unId recipientId, recipientTs), - m_sender, - m_broker, - m_body = body - } - sendAck c rq + logServer "<--" c srv rId "MSG " + case status of + Active -> do + recipientTs <- liftIO getCurrentTime + let m_sender = (senderMsgId, senderTimestamp) + let m_broker = (srvMsgId, srvTs) + recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker + notify connAlias $ + MSG + { m_status = MsgOk, + m_recipient = (unId recipientId, recipientTs), + m_sender, + m_broker, + m_body = body + } + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + sendAck c rq return () SMP.END -> do removeSubscription c connAlias logServer "<--" c srv rId "END" notify connAlias END - _ -> logServer "<--" c srv rId $ "unexpected:" <> bshow cmd + _ -> do + logServer "<--" c srv rId $ "unexpected: " <> bshow cmd + notify connAlias . ERR $ BROKER UNEXPECTED where notify :: ConnAlias -> ACommand 'Agent -> m () notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg) @@ -295,7 +297,7 @@ connectToSendQueue c st sq senderKey verifyKey = do withStore $ setSndQueueStatus st sq Active decryptMessage :: (MonadUnliftIO m, MonadError AgentErrorType m) => DecryptionKey -> ByteString -> m ByteString -decryptMessage decryptKey msg = liftError CRYPTO $ C.decrypt decryptKey msg +decryptMessage decryptKey msg = liftError cryptoError $ C.decrypt decryptKey msg newSendQueue :: (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index a8dba40fa..8c431f888 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -24,6 +24,7 @@ module Simplex.Messaging.Agent.Client deleteQueue, logServer, removeSubscription, + cryptoError, ) where @@ -47,7 +48,7 @@ import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey) -import Simplex.Messaging.Util (bshow, liftError) +import Simplex.Messaging.Util (bshow, liftEitherError, liftError) import UnliftIO.Concurrent import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E @@ -86,15 +87,17 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = newSMPClient = do smp <- connectClient logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv - -- TODO how can agent know client lost the connection? atomically . modifyTVar smpClients $ M.insert srv smp return smp connectClient :: m SMPClient connectClient = do cfg <- asks $ smpCfg . config - liftIO (getSMPClient srv cfg msgQ clientDisconnected) - `E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection) + liftEitherError smpClientError (getSMPClient srv cfg msgQ clientDisconnected) + `E.catch` internalError + where + internalError :: IOException -> m SMPClient + internalError = throwError . INTERNAL . show clientDisconnected :: IO () clientDisconnected = do @@ -125,12 +128,6 @@ withSMP c srv action = runAction :: SMPClient -> m a runAction smp = liftError smpClientError $ action smp - smpClientError :: SMPClientError -> AgentErrorType - smpClientError = \case - SMPServerError e -> SMP e - -- TODO handle other errors - _ -> INTERNAL - logServerError :: AgentErrorType -> m a logServerError e = do logServer "<--" c srv "" $ bshow e @@ -143,6 +140,16 @@ withLogSMP c srv qId cmdStr action = do logServer "<--" c srv qId "OK" return res +smpClientError :: SMPClientError -> AgentErrorType +smpClientError = \case + SMPServerError e -> SMP e + SMPResponseError e -> BROKER $ RESPONSE e + SMPUnexpectedResponse -> BROKER UNEXPECTED + SMPResponseTimeout -> BROKER TIMEOUT + SMPNetworkError -> BROKER NETWORK + SMPTransportError e -> BROKER $ TRANSPORT e + e -> INTERNAL $ show e + newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo) newReceiveQueue c srv connAlias = do size <- asks $ rsaKeySize . config @@ -214,26 +221,26 @@ sendConfirmation c SndQueue {server, sndId, encryptKey} senderKey = do mkConfirmation = do let msg = serializeSMPMessage $ SMPConfirmation senderKey paddedSize <- asks paddedMsgSize - liftError CRYPTO $ C.encrypt encryptKey paddedSize msg + liftError cryptoError $ C.encrypt encryptKey paddedSize msg sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m () sendHello c SndQueue {server, sndId, sndPrivateKey, encryptKey} verifyKey = do msg <- mkHello $ AckMode On withLogSMP c server sndId "SEND (retrying)" $ - send 20 msg + send 8 100000 msg where mkHello :: AckMode -> m ByteString mkHello ackMode = do senderTs <- liftIO getCurrentTime mkAgentMessage encryptKey senderTs $ HELLO verifyKey ackMode - send :: Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () - send 0 _ _ = throwE SMPResponseTimeout -- TODO different error - send retry msg smp = + send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () + send 0 _ _ _ = throwE $ SMPServerError AUTH + send retry delay msg smp = sendSMPMessage smp (Just sndPrivateKey) sndId msg `catchE` \case SMPServerError AUTH -> do - threadDelay 100000 - send (retry - 1) msg smp + threadDelay delay + send (retry - 1) (delay * 3 `div` 2) msg smp e -> throwE e secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SenderPublicKey -> m () @@ -273,4 +280,12 @@ mkAgentMessage encKey senderTs agentMessage = do agentMessage } paddedSize <- asks paddedMsgSize - liftError CRYPTO $ C.encrypt encKey paddedSize msg + liftError cryptoError $ C.encrypt encKey paddedSize msg + +cryptoError :: C.CryptoError -> AgentErrorType +cryptoError = \case + C.CryptoLargeMsgError -> CMD LARGE + C.RSADecryptError _ -> AGENT A_ENCRYPTION + C.CryptoHeaderError _ -> AGENT A_ENCRYPTION + C.AESDecryptError -> AGENT A_ENCRYPTION + e -> INTERNAL $ show e diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 6708bb13d..7521b9fae 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -10,6 +10,7 @@ module Simplex.Messaging.Agent.Store where import Control.Exception (Exception) +import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.Kind (Type) import Data.Time (UTCTime) @@ -219,13 +220,11 @@ type InternalTs = UTCTime -- * Store errors --- TODO revise data StoreError - = SEInternal - | SENotFound - | SEBadConn + = SEInternal ByteString + | SEConnNotFound + | SEConnDuplicate | SEBadConnType ConnType - | SEBadQueueStatus - | SEBadQueueDirection + | SEBadQueueStatus -- not used, planned to check strictly | SENotImplemented -- TODO remove deriving (Eq, Show, Exception) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 0b83bbd77..1173e8215 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -95,7 +95,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto (Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) (Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ) (Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ) - _ -> throwError SEBadConn + _ -> throwError SEConnNotFound getAllConnAliases :: SQLiteStore -> m [ConnAlias] getAllConnAliases SQLiteStore {dbConn} = @@ -109,7 +109,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto retrieveRcvQueue dbConn host port rcvId case rcvQueue of Just rcvQ -> return rcvQ - _ -> throwError SENotFound + _ -> throwError SEConnNotFound deleteConn :: SQLiteStore -> ConnAlias -> m () deleteConn SQLiteStore {dbConn} connAlias = @@ -415,7 +415,7 @@ updateRcvConnWithSndQueue dbConn connAlias sndQueue = return $ Right () (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO () updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do @@ -443,7 +443,7 @@ updateSndConnWithRcvQueue dbConn connAlias rcvQueue = return $ Right () (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO () updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do @@ -508,7 +508,7 @@ insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId return $ Right internalId (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId) retrieveLastInternalIdsRcv_ dbConn connAlias = do @@ -599,7 +599,7 @@ insertSndMsg dbConn connAlias msgBody internalTs = updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId return $ Right internalId (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId) retrieveLastInternalIdsSnd_ dbConn connAlias = do diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 0c9c8bfad..02706da30 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -26,8 +27,9 @@ import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () +import GHC.Generics (Generic) +import Generic.Random (genericArbitraryU) import Network.Socket -import Numeric.Natural import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol @@ -37,12 +39,12 @@ import Simplex.Messaging.Protocol MsgBody, MsgId, SenderPublicKey, - errMessageBody, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport import Simplex.Messaging.Util import System.IO +import Test.QuickCheck (Arbitrary (..)) import Text.Read import UnliftIO.Exception @@ -125,7 +127,7 @@ data AMessage where deriving (Show) parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage -parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage +parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE where smpMessageP :: Parser SMPMessage smpMessageP = @@ -186,7 +188,7 @@ smpServerP = SMPServer <$> server <*> port <*> kHash else pure Nothing parseAgentMessage :: ByteString -> Either AgentErrorType AMessage -parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage +parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE serializeAgentMessage :: AMessage -> ByteString serializeAgentMessage = \case @@ -245,50 +247,53 @@ data MsgStatus = MsgOk | MsgError MsgErrorType data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash deriving (Eq, Show) +-- | error type used in errors sent to agent clients data AgentErrorType - = UNKNOWN - | PROHIBITED - | SYNTAX Int - | BROKER Natural - | SMP ErrorType - | CRYPTO C.CryptoError - | SIZE - | STORE - | INTERNAL -- etc. TODO SYNTAX Natural - deriving (Eq, Show, Exception) + = CMD CommandErrorType -- command errors + | CONN ConnectionErrorType -- connection state errors + | SMP ErrorType -- SMP protocol errors forwarded to agent clients + | BROKER BrokerErrorType -- SMP server errors + | AGENT SMPAgentError -- errors of other agents + | INTERNAL String -- agent implementation errors + deriving (Eq, Generic, Read, Show, Exception) -data AckStatus = AckOk | AckError AckErrorType - deriving (Show) +data CommandErrorType + = PROHIBITED -- command is prohibited + | SYNTAX -- command syntax is invalid + | NO_CONN -- connection alias is required with this command + | SIZE -- message size is not correct (no terminating space) + | LARGE -- message does not fit SMP block + deriving (Eq, Generic, Read, Show, Exception) -data AckErrorType = AckUnknown | AckProhibited | AckSyntax Int -- etc. - deriving (Show) +data ConnectionErrorType + = UNKNOWN -- connection alias not in database + | DUPLICATE -- connection alias already exists + | SIMPLEX -- connection is simplex, but operation requires another queue + deriving (Eq, Generic, Read, Show, Exception) -errBadEncoding :: Int -errBadEncoding = 10 +data BrokerErrorType + = RESPONSE ErrorType -- invalid server response (failed to parse) + | UNEXPECTED -- unexpected response + | NETWORK -- network error + | TRANSPORT TransportError -- handshake or other transport error + | TIMEOUT -- command response timeout + deriving (Eq, Generic, Read, Show, Exception) -errBadCommand :: Int -errBadCommand = 11 +data SMPAgentError + = A_MESSAGE -- possibly should include bytestring that failed to parse + | A_PROHIBITED -- possibly should include the prohibited SMP/agent message + | A_ENCRYPTION -- cannot RSA/AES-decrypt or parse decrypted header + deriving (Eq, Generic, Read, Show, Exception) -errBadInvitation :: Int -errBadInvitation = 12 +instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU -errNoConnAlias :: Int -errNoConnAlias = 13 +instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU -errBadMessage :: Int -errBadMessage = 14 +instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU -errBadServer :: Int -errBadServer = 15 +instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU -smpErrTCPConnection :: Natural -smpErrTCPConnection = 1 - -smpErrCorrelationId :: Natural -smpErrCorrelationId = 2 - -smpUnexpectedResponse :: Natural -smpUnexpectedResponse = 3 +instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU commandP :: Parser ACmd commandP = @@ -319,9 +324,6 @@ commandP = m_sender <- "S=" *> partyMeta A.decimal m_body <- A.takeByteString return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body} - -- TODO other error types - agentError = ACmd SAgent . ERR <$> ("SMP " *> smpErrorType) - smpErrorType = "AUTH" $> SMP SMP.AUTH replyMode = " NO_REPLY" $> ReplyOff <|> A.space *> (ReplyVia <$> smpServerP) @@ -332,9 +334,10 @@ commandP = "ID " *> (MsgBadId <$> A.decimal) <|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash + agentError = ACmd SAgent . ERR <$> agentErrorTypeP parseCommand :: ByteString -> Either AgentErrorType ACmd -parseCommand = parse commandP $ SYNTAX errBadCommand +parseCommand = parse commandP $ CMD SYNTAX serializeCommand :: ACommand p -> ByteString serializeCommand = \case @@ -358,7 +361,7 @@ serializeCommand = \case OFF -> "OFF" DEL -> "DEL" CON -> "CON" - ERR e -> "ERR " <> bshow e + ERR e -> "ERR " <> serializeAgentError e OK -> "OK" where replyMode :: ReplyMode -> ByteString @@ -378,7 +381,21 @@ serializeCommand = \case MsgBadId aMsgId -> "ID " <> bshow aMsgId MsgBadHash -> "HASH" --- TODO - save function as in the server Transmission - re-use? +agentErrorTypeP :: Parser AgentErrorType +agentErrorTypeP = + "SMP " *> (SMP <$> SMP.errorTypeP) + <|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP) + <|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP) + <|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString) + <|> parseRead2 + +serializeAgentError :: AgentErrorType -> ByteString +serializeAgentError = \case + SMP e -> "SMP " <> SMP.serializeErrorType e + BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e + BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e + e -> bshow e + serializeMsg :: ByteString -> ByteString serializeMsg body = bshow (B.length body) <> "\n" <> body @@ -408,7 +425,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody fromParty :: ACmd -> Either AgentErrorType (ACommand p) fromParty (ACmd (p :: p1) cmd) = case testEquality party p of Just Refl -> Right cmd - _ -> Left PROHIBITED + _ -> Left $ CMD PROHIBITED tConnAlias :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p) tConnAlias (_, connAlias, _) cmd = case cmd of @@ -419,7 +436,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody ERR _ -> Right cmd -- other responses must have connAlias _ - | B.null connAlias -> Left $ SYNTAX errNoConnAlias + | B.null connAlias -> Left $ CMD NO_CONN | otherwise -> Right cmd cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) @@ -437,5 +454,5 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody Just size -> liftIO $ do body <- B.hGet h size s <- getLn h - return $ if B.null s then Right body else Left SIZE - Nothing -> return . Left $ SYNTAX errMessageBody + return $ if B.null s then Right body else Left $ CMD SIZE + Nothing -> return . Left $ CMD SYNTAX diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index ad7020e14..1efd16443 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -34,12 +34,10 @@ import Control.Exception import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Except -import qualified Crypto.PubKey.RSA.Types as RSA import Data.ByteString.Char8 (ByteString) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe -import GHC.IO.Exception (IOErrorType (..)) import Network.Socket (ServiceName) import Numeric.Natural import Simplex.Messaging.Agent.Transmission (SMPServer (..)) @@ -48,13 +46,13 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Transport import Simplex.Messaging.Util (bshow, liftEitherError, raceAny_) import System.IO -import System.IO.Error import System.Timeout data SMPClient = SMPClient { action :: Async (), connected :: TVar Bool, smpServer :: SMPServer, + tcpTimeout :: Int, clientCorrId :: TVar Natural, sentCommands :: TVar (Map CorrId Request), sndQ :: TBQueue SignedRawTransmission, @@ -78,7 +76,7 @@ smpDefaultConfig = SMPClientConfig { qSize = 16, defaultPort = "5223", - tcpTimeout = 2_000_000, + tcpTimeout = 4_000_000, smpPing = 30_000_000, blockSize = 8_192, -- 16_384, smpCommandSize = 256 @@ -86,28 +84,29 @@ smpDefaultConfig = data Request = Request { queueId :: QueueId, - responseVar :: TMVar (Either SMPClientError Cmd) + responseVar :: TMVar Response } -getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient +type Response = Either SMPClientError Cmd + +getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient) getSMPClient smpServer@SMPServer {host, port, keyHash} SMPClientConfig {qSize, defaultPort, tcpTimeout, smpPing} msgQ disconnected = do c <- atomically mkSMPClient - started <- newEmptyTMVarIO + err <- newEmptyTMVarIO action <- async $ - runTCPClient host (fromMaybe defaultPort port) (client c started) - `finally` atomically (putTMVar started False) - tcpTimeout `timeout` atomically (takeTMVar started) >>= \case - Just True -> return c {action} - _ -> throwIO err -- TODO report handshake error too, not only connection timeout + runTCPClient host (fromMaybe defaultPort port) (client c err) + `finally` atomically (putTMVar err $ Just SMPNetworkError) + ok <- tcpTimeout `timeout` atomically (takeTMVar err) + pure $ case ok of + Just Nothing -> Right c {action} + Just (Just e) -> Left e + Nothing -> Left SMPNetworkError where - err :: IOException - err = mkIOError TimeExpired "connection timeout" Nothing Nothing - mkSMPClient :: STM SMPClient mkSMPClient = do connected <- newTVar False @@ -120,6 +119,7 @@ getSMPClient { action = undefined, connected, smpServer, + tcpTimeout, clientCorrId, sentCommands, sndQ, @@ -127,18 +127,17 @@ getSMPClient msgQ } - client :: SMPClient -> TMVar Bool -> Handle -> IO () - client c started h = + client :: SMPClient -> TMVar (Maybe SMPClientError) -> Handle -> IO () + client c err h = runExceptT (clientHandshake h keyHash) >>= \case - Right th -> clientTransport c started th - -- TODO report error instead of True/False - Left _ -> atomically $ putTMVar started False + Right th -> clientTransport c err th + Left e -> atomically . putTMVar err . Just $ SMPTransportError e - clientTransport :: SMPClient -> TMVar Bool -> THandle -> IO () - clientTransport c started th = do + clientTransport :: SMPClient -> TMVar (Maybe SMPClientError) -> THandle -> IO () + clientTransport c err th = do atomically $ do writeTVar (connected c) True - putTMVar started True + putTMVar err Nothing raceAny_ [send c th, process c, receive c th, ping c] `finally` disconnected @@ -171,7 +170,7 @@ getSMPClient Left e -> Left $ SMPResponseError e Right (Cmd _ (ERR e)) -> Left $ SMPServerError e Right r -> Right r - else Left SMPQueueIdError + else Left SMPUnexpectedResponse closeSMPClient :: SMPClient -> IO () closeSMPClient = uninterruptibleCancel . action @@ -179,11 +178,11 @@ closeSMPClient = uninterruptibleCancel . action data SMPClientError = SMPServerError ErrorType | SMPResponseError ErrorType - | SMPQueueIdError | SMPUnexpectedResponse | SMPResponseTimeout - | SMPCryptoError RSA.Error - | SMPClientError + | SMPNetworkError + | SMPTransportError TransportError + | SMPSignatureError C.CryptoError deriving (Eq, Show, Exception) createSMPQueue :: @@ -235,7 +234,7 @@ okSMPCommand cmd c pKey qId = _ -> throwE SMPUnexpectedResponse sendSMPCommand :: SMPClient -> Maybe C.SafePrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd -sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do +sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, tcpTimeout} pKey qId cmd = do corrId <- lift_ getNextCorrId t <- signTransmission $ serializeTransmission (corrId, qId, cmd) ExceptT $ sendRecv corrId t @@ -253,14 +252,16 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do signTransmission t = case pKey of Nothing -> return ("", t) Just pk -> do - sig <- liftEitherError SMPCryptoError $ C.sign pk t + sig <- liftEitherError SMPSignatureError $ C.sign pk t return (sig, t) -- two separate "atomically" needed to avoid blocking - sendRecv :: CorrId -> SignedRawTransmission -> IO (Either SMPClientError Cmd) - sendRecv corrId t = atomically (send corrId t) >>= atomically . takeTMVar + sendRecv :: CorrId -> SignedRawTransmission -> IO Response + sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar + where + withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a - send :: CorrId -> SignedRawTransmission -> STM (TMVar (Either SMPClientError Cmd)) + send :: CorrId -> SignedRawTransmission -> STM (TMVar Response) send corrId t = do r <- newEmptyTMVar modifyTVar sentCommands . M.insert corrId $ Request qId r diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 5b8b85b53..766c46632 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -68,7 +68,7 @@ import Data.ASN1.Encoding import Data.ASN1.Types import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Bifunctor (first) +import Data.Bifunctor (bimap, first) import qualified Data.ByteArray as BA import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) @@ -85,7 +85,7 @@ import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.ToField (ToField (..)) import Network.Transport.Internal (decodeWord32, encodeWord32) import Simplex.Messaging.Parsers (base64P, base64StringP, parseAll) -import Simplex.Messaging.Util (liftEitherError, (<$$>)) +import Simplex.Messaging.Util (liftEitherError) newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show) @@ -148,10 +148,12 @@ instance IsString Signature where newtype Verified = Verified ByteString deriving (Show) data CryptoError - = CryptoRSAError R.Error - | CryptoCipherError CE.CryptoError + = RSAEncryptError R.Error + | RSADecryptError R.Error + | RSASignError R.Error + | AESCipherError CE.CryptoError | CryptoIVError - | CryptoDecryptError + | AESDecryptError | CryptoLargeMsgError | CryptoHeaderError String deriving (Eq, Show, Exception) @@ -276,7 +278,7 @@ encryptAES aesKey ivBytes paddedSize msg = do decryptAES :: Key -> IV -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString decryptAES aesKey ivBytes msg authTag = do aead <- initAEAD @AES256 aesKey ivBytes - maybeError CryptoDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag + maybeError AESDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag initAEAD :: forall c. AES.BlockCipher c => Key -> IV -> ExceptT CryptoError IO (AES.AEAD c) initAEAD (Key aesKey) (IV ivBytes) = do @@ -307,26 +309,26 @@ bsToAuthTag :: ByteString -> AES.AuthTag bsToAuthTag = AES.AuthTag . BA.pack . map c2w . B.unpack cryptoFailable :: CE.CryptoFailable a -> ExceptT CryptoError IO a -cryptoFailable = liftEither . first CryptoCipherError . CE.eitherCryptoError +cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError oaepParams :: OAEP.OAEPParams SHA256 ByteString ByteString oaepParams = OAEP.defaultOAEPParams SHA256 encryptOAEP :: PublicKey -> ByteString -> ExceptT CryptoError IO ByteString encryptOAEP (PublicKey k) aesKey = - liftEitherError CryptoRSAError $ + liftEitherError RSAEncryptError $ OAEP.encrypt oaepParams k aesKey decryptOAEP :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString decryptOAEP pk encKey = - liftEitherError CryptoRSAError $ + liftEitherError RSADecryptError $ OAEP.decryptSafer oaepParams (rsaPrivateKey pk) encKey pssParams :: PSS.PSSParams SHA256 ByteString ByteString pssParams = PSS.defaultPSSParams SHA256 -sign :: PrivateKey k => k -> ByteString -> IO (Either R.Error Signature) -sign pk msg = Signature <$$> PSS.signSafer pssParams (rsaPrivateKey pk) msg +sign :: PrivateKey k => k -> ByteString -> IO (Either CryptoError Signature) +sign pk msg = bimap RSASignError Signature <$> PSS.signSafer pssParams (rsaPrivateKey pk) msg verify :: PublicKey -> Signature -> ByteString -> Bool verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 74a802435..4f96bae0b 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE OverloadedStrings #-} + module Simplex.Messaging.Parsers where import Data.Attoparsec.ByteString.Char8 (Parser) @@ -9,6 +11,7 @@ import qualified Data.ByteString.Char8 as B import Data.Char (isAlphaNum) import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 (parseISO8601) +import Text.Read (readMaybe) base64P :: Parser ByteString base64P = either fail pure . decode =<< base64StringP @@ -27,3 +30,15 @@ parse parser err = first (const err) . parseAll parser parseAll :: Parser a -> (ByteString -> Either String a) parseAll parser = A.parseOnly (parser <* A.endOfInput) + +parseRead :: Read a => Parser ByteString -> Parser a +parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack) + +parseRead1 :: Read a => Parser a +parseRead1 = parseRead $ A.takeTill (== ' ') + +parseRead2 :: Read a => Parser a +parseRead2 = parseRead $ do + w1 <- A.takeTill (== ' ') <* A.char ' ' + w2 <- A.takeTill (== ' ') + pure $ w1 <> " " <> w2 diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 2a4ea433e..d7f139c18 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} @@ -24,10 +25,13 @@ import Data.Kind import Data.String import Data.Time.Clock import Data.Time.ISO8601 +import GHC.Generics (Generic) +import Generic.Random (genericArbitraryU) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers import Simplex.Messaging.Transport import Simplex.Messaging.Util +import Test.QuickCheck (Arbitrary (..)) data Party = Broker | Recipient | Sender deriving (Show) @@ -106,25 +110,26 @@ type MsgId = Encoded type MsgBody = ByteString -data ErrorType = PROHIBITED | SYNTAX Int | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) +data ErrorType + = BLOCK + | CMD CommandError + | AUTH + | NO_MSG + | INTERNAL + | DUPLICATE_ -- TODO remove, not part of SMP protocol + deriving (Eq, Generic, Read, Show) -errBadTransmission :: Int -errBadTransmission = 1 +data CommandError + = PROHIBITED + | SYNTAX + | NO_AUTH + | HAS_AUTH + | NO_QUEUE + deriving (Eq, Generic, Read, Show) -errBadSMPCommand :: Int -errBadSMPCommand = 2 +instance Arbitrary ErrorType where arbitrary = genericArbitraryU -errNoCredentials :: Int -errNoCredentials = 3 - -errHasCredentials :: Int -errHasCredentials = 4 - -errNoQueueId :: Int -errNoQueueId = 5 - -errMessageBody :: Int -errMessageBody = 6 +instance Arbitrary CommandError where arbitrary = genericArbitraryU transmissionP :: Parser RawTransmission transmissionP = do @@ -164,16 +169,11 @@ commandP = ts <- tsISO8601P <* A.space size <- A.decimal <* A.space Cmd SBroker . MSG msgId ts <$> A.take size <* A.space - serverError = Cmd SBroker . ERR <$> errorType - errorType = - "PROHIBITED" $> PROHIBITED - <|> "SYNTAX " *> (SYNTAX <$> A.decimal) - <|> "AUTH" $> AUTH - <|> "INTERNAL" $> INTERNAL + serverError = Cmd SBroker . ERR <$> errorTypeP -- TODO ignore the end of block, no need to parse it parseCommand :: ByteString -> Either ErrorType Cmd -parseCommand = parse (commandP <* " " <* A.takeByteString) $ SYNTAX errBadSMPCommand +parseCommand = parse (commandP <* " " <* A.takeByteString) $ CMD SYNTAX serializeCommand :: Cmd -> ByteString serializeCommand = \case @@ -185,11 +185,17 @@ serializeCommand = \case Cmd SBroker (MSG msgId ts msgBody) -> B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeMsg msgBody] Cmd SBroker (IDS rId sId) -> B.unwords ["IDS", encode rId, encode sId] - Cmd SBroker (ERR err) -> "ERR " <> bshow err + Cmd SBroker (ERR err) -> "ERR " <> serializeErrorType err Cmd SBroker resp -> bshow resp where serializeMsg msgBody = bshow (B.length msgBody) <> " " <> msgBody <> " " +errorTypeP :: Parser ErrorType +errorTypeP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1 + +serializeErrorType :: ErrorType -> ByteString +serializeErrorType = bshow + tPut :: THandle -> SignedRawTransmission -> IO (Either TransportError ()) tPut th (C.Signature sig, t) = tPutEncrypted th $ encode sig <> " " <> t <> " " @@ -200,16 +206,16 @@ serializeTransmission (CorrId corrId, queueId, command) = fromClient :: Cmd -> Either ErrorType Cmd fromClient = \case - Cmd SBroker _ -> Left PROHIBITED + Cmd SBroker _ -> Left $ CMD PROHIBITED cmd -> Right cmd fromServer :: Cmd -> Either ErrorType Cmd fromServer = \case cmd@(Cmd SBroker _) -> Right cmd - _ -> Left PROHIBITED + _ -> Left $ CMD PROHIBITED tGetParse :: THandle -> IO (Either TransportError RawTransmission) -tGetParse th = (>>= parse transmissionP TransportParsingError) <$> tGetEncrypted th +tGetParse th = (>>= parse transmissionP TEBadBlock) <$> tGetEncrypted th -- | get client and server transmissions -- `fromParty` is used to limit allowed senders - `fromClient` or `fromServer` should be used @@ -224,7 +230,7 @@ tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate Left _ -> tError "" tError :: ByteString -> m SignedTransmissionOrError - tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left $ SYNTAX errBadTransmission)) + tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left BLOCK)) tParseValidate :: RawTransmission -> m SignedTransmissionOrError tParseValidate t@(sig, corrId, queueId, command) = do @@ -233,32 +239,32 @@ tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate tCredentials :: RawTransmission -> Cmd -> Either ErrorType Cmd tCredentials (signature, _, queueId, _) cmd = case cmd of - -- IDS response should not have queue ID + -- IDS response must not have queue ID Cmd SBroker (IDS _ _) -> Right cmd -- ERR response does not always have queue ID Cmd SBroker (ERR _) -> Right cmd - -- PONG response should not have queue ID + -- PONG response must not have queue ID Cmd SBroker PONG | B.null queueId -> Right cmd - | otherwise -> Left $ SYNTAX errHasCredentials + | otherwise -> Left $ CMD HAS_AUTH -- other responses must have queue ID Cmd SBroker _ - | B.null queueId -> Left $ SYNTAX errNoQueueId + | B.null queueId -> Left $ CMD NO_QUEUE | otherwise -> Right cmd -- NEW must NOT have signature or queue ID Cmd SRecipient (NEW _) - | B.null signature -> Left $ SYNTAX errNoCredentials - | not (B.null queueId) -> Left $ SYNTAX errHasCredentials + | B.null signature -> Left $ CMD NO_AUTH + | not (B.null queueId) -> Left $ CMD HAS_AUTH | otherwise -> Right cmd -- SEND must have queue ID, signature is not always required Cmd SSender (SEND _) - | B.null queueId -> Left $ SYNTAX errNoQueueId + | B.null queueId -> Left $ CMD NO_QUEUE | otherwise -> Right cmd -- PING must not have queue ID or signature Cmd SSender PING | B.null queueId && B.null signature -> Right cmd - | otherwise -> Left $ SYNTAX errHasCredentials + | otherwise -> Left $ CMD HAS_AUTH -- other client commands must have both signature and queue ID Cmd SRecipient _ - | B.null signature || B.null queueId -> Left $ SYNTAX errNoCredentials + | B.null signature || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index c78eac3c1..3b6d6d8a3 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -165,7 +165,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = addQueueRetry n = do ids <- getIds atomically (addQueue st rKey ids) >>= \case - Left DUPLICATE -> addQueueRetry $ n - 1 + Left DUPLICATE_ -> addQueueRetry $ n - 1 Left e -> return $ Left e Right _ -> return $ Right ids @@ -200,7 +200,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s)) >>= \case Just (Just s) -> deliverMessage tryDelPeekMsg queueId s - _ -> return $ err PROHIBITED + _ -> return $ err NO_MSG withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a) withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index dca596e82..86caff78f 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -32,7 +32,7 @@ instance MonadQueueStore QueueStore STM where addQueue store rKey ids@(rId, sId) = do cs@QueueStoreData {queues, senders} <- readTVar store if M.member rId queues || M.member sId senders - then return $ Left DUPLICATE + then return $ Left DUPLICATE_ else do writeTVar store $ cs diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 4d5d63e32..d0468886e 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -1,6 +1,7 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BlockArguments #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -11,6 +12,7 @@ module Simplex.Messaging.Transport where +import Control.Applicative ((<|>)) import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Trans.Except (throwE) @@ -21,18 +23,22 @@ import Data.Bifunctor (first) import Data.ByteArray (xor) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Functor (($>)) import Data.Set (Set) import qualified Data.Set as S import Data.Word (Word32) +import GHC.Generics (Generic) import GHC.IO.Exception (IOErrorType (..)) import GHC.IO.Handle.Internals (ioe_EOF) +import Generic.Random (genericArbitraryU) import Network.Socket import Network.Transport.Internal (encodeWord32) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Parsers (parse, parseAll) +import Simplex.Messaging.Parsers (parse, parseAll, parseRead1) import Simplex.Messaging.Util (bshow, liftError) import System.IO import System.IO.Error +import Test.QuickCheck (Arbitrary (..)) import UnliftIO.Concurrent import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E @@ -50,7 +56,10 @@ runTCPServer started port server = do atomically . modifyTVar clients $ S.insert tid where closeServer :: TVar (Set ThreadId) -> Socket -> IO () - closeServer clients sock = readTVarIO clients >>= mapM_ killThread >> close sock + closeServer clients sock = do + readTVarIO clients >>= mapM_ killThread + close sock + void . atomically $ tryPutTMVar started False startTCPServer :: TMVar Bool -> ServiceName -> IO Socket startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted @@ -76,9 +85,7 @@ runTCPClient host port client = do client h `E.finally` IO.hClose h startTCPClient :: HostName -> ServiceName -> IO Handle -startTCPClient host port = - withSocketsDo $ - resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion +startTCPClient host port = withSocketsDo $ resolve >>= tryOpen err where err :: IOException err = mkIOError NoSuchThing "no address" Nothing Nothing @@ -88,9 +95,10 @@ startTCPClient host port = let hints = defaultHints {addrSocketType = Stream} in getAddrInfo (Just hints) (Just host) (Just port) - tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle) - tryOpen (Left _) addr = E.try $ open addr - tryOpen h _ = return h + tryOpen :: IOException -> [AddrInfo] -> IO Handle + tryOpen e [] = E.throwIO e + tryOpen _ (addr : as) = + E.try (open addr) >>= either (`tryOpen` as) pure open :: AddrInfo -> IO Handle open addr = do @@ -153,21 +161,51 @@ data HandshakeKeys = HandshakeKeys } data TransportError - = TransportCryptoError C.CryptoError - | TransportParsingError - | TransportHandshakeError String - deriving (Eq, Show, Exception) + = TEBadBlock + | TEEncrypt + | TEDecrypt + | TEHandshake HandshakeError + deriving (Eq, Generic, Read, Show, Exception) + +data HandshakeError + = ENCRYPT + | DECRYPT + | VERSION + | RSA_KEY + | AES_KEYS + | BAD_HASH + | MAJOR_VERSION + | TERMINATED + deriving (Eq, Generic, Read, Show, Exception) + +instance Arbitrary TransportError where arbitrary = genericArbitraryU + +instance Arbitrary HandshakeError where arbitrary = genericArbitraryU + +transportErrorP :: Parser TransportError +transportErrorP = + "BLOCK" $> TEBadBlock + <|> "AES_ENCRYPT" $> TEEncrypt + <|> "AES_DECRYPT" $> TEDecrypt + <|> TEHandshake <$> parseRead1 + +serializeTransportError :: TransportError -> ByteString +serializeTransportError = \case + TEEncrypt -> "AES_ENCRYPT" + TEDecrypt -> "AES_DECRYPT" + TEBadBlock -> "BLOCK" + TEHandshake e -> bshow e tPutEncrypted :: THandle -> ByteString -> IO (Either TransportError ()) tPutEncrypted THandle {handle = h, sndKey, blockSize} block = encryptBlock sndKey (blockSize - C.authTagSize) block >>= \case - Left e -> return . Left $ TransportCryptoError e + Left _ -> pure $ Left TEEncrypt Right (authTag, msg) -> Right <$> B.hPut h (C.authTagToBS authTag <> msg) tGetEncrypted :: THandle -> IO (Either TransportError ByteString) tGetEncrypted THandle {handle = h, rcvKey, blockSize} = B.hGet h blockSize >>= decryptBlock rcvKey >>= \case - Left e -> pure . Left $ TransportCryptoError e + Left _ -> pure $ Left TEDecrypt Right "" -> ioe_EOF Right msg -> pure $ Right msg @@ -207,11 +245,11 @@ serverHandshake h (k, pk) = do receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString receiveEncryptedKeys_4 = liftIO (B.hGet h $ C.publicKeySize k) >>= \case - "" -> throwE $ TransportHandshakeError "EOF" + "" -> throwE $ TEHandshake TERMINATED ks -> pure ks decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO HandshakeKeys decryptParseKeys_5 encKeys = - liftError TransportCryptoError (C.decryptOAEP pk encKeys) + liftError (const $ TEHandshake DECRYPT) (C.decryptOAEP pk encKeys) >>= liftEither . parseHandshakeKeys sendWelcome_6 :: THandle -> ExceptT TransportError IO () sendWelcome_6 th = ExceptT . tPutEncrypted th $ serializeSMPVersion currentSMPVersion <> " " @@ -233,11 +271,11 @@ clientHandshake h keyHash = do maybe (pure ()) (validateKeyHash_2 s) keyHash liftEither $ parseKey s parseKey :: ByteString -> Either TransportError C.PublicKey - parseKey = first TransportHandshakeError . parseAll C.pubKeyP + parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.pubKeyP validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO () validateKeyHash_2 k kHash | C.getKeyHash k == kHash = pure () - | otherwise = throwE $ TransportHandshakeError "wrong key hash" + | otherwise = throwE $ TEHandshake BAD_HASH generateKeys_3 :: IO HandshakeKeys generateKeys_3 = HandshakeKeys <$> generateKey <*> generateKey generateKey :: IO SessionKey @@ -247,16 +285,16 @@ clientHandshake h keyHash = do pure SessionKey {aesKey, baseIV, counter = undefined} sendEncryptedKeys_4 :: C.PublicKey -> HandshakeKeys -> ExceptT TransportError IO () sendEncryptedKeys_4 k keys = - liftError TransportCryptoError (C.encryptOAEP k $ serializeHandshakeKeys keys) + liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeHandshakeKeys keys) >>= liftIO . B.hPut h getWelcome_6 :: THandle -> ExceptT TransportError IO SMPVersion getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th parseSMPVersion :: ByteString -> Either TransportError SMPVersion - parseSMPVersion = first TransportHandshakeError . A.parseOnly (smpVersionP <* A.space) + parseSMPVersion = first (const $ TEHandshake VERSION) . A.parseOnly (smpVersionP <* A.space) checkVersion :: SMPVersion -> ExceptT TransportError IO () checkVersion smpVersion = when (major smpVersion > major currentSMPVersion) . throwE $ - TransportHandshakeError "SMP server version" + TEHandshake MAJOR_VERSION serializeHandshakeKeys :: HandshakeKeys -> ByteString serializeHandshakeKeys HandshakeKeys {sndKey, rcvKey} = @@ -275,7 +313,7 @@ handshakeKeysP = HandshakeKeys <$> keyP <*> keyP pure SessionKey {aesKey, baseIV, counter = undefined} parseHandshakeKeys :: ByteString -> Either TransportError HandshakeKeys -parseHandshakeKeys = parse handshakeKeysP $ TransportHandshakeError "parsing keys" +parseHandshakeKeys = parse handshakeKeysP $ TEHandshake AES_KEYS transportHandle :: Handle -> SessionKey -> SessionKey -> IO THandle transportHandle h sk rk = do diff --git a/src/Simplex/Messaging/errors.md b/src/Simplex/Messaging/errors.md new file mode 100644 index 000000000..6e83e6d61 --- /dev/null +++ b/src/Simplex/Messaging/errors.md @@ -0,0 +1,97 @@ +# Errors + +## Problems + +- using numbers and strings to indicate errors (in protocol and in code) - ErrorType, AgentErrorType, TransportError +- re-using the same type in multiple contexts (with some constructors not applicable to all contexts) - ErrorType + +## Error types + +### ErrorType (Protocol.hs) + +- BLOCK - incorrect block format or encoding +- CMD error - command is unknown or has invalid syntax, where `error` can be: + - PROHIBITED - server response sent from client or vice versa + - SYNTAX - error parsing command + - NO_AUTH - transmission has no required credentials (signature or queue ID) + - HAS_AUTH - transmission has not allowed credentials + - NO_QUEUE - transmission has not queue ID +- AUTH - command is not authorised (queue does not exist or signature verification failed). +- NO_MSG - acknowledging (ACK) the message without message +- INTERNAL - internal server error. +- DUPLICATE_ - it is used internally to signal that the queue ID is already used. This is NOT used in the protocol, instead INTERNAL is sent to the client. It has to be removed. + +### AgentErrorType (Agent/Transmission.hs) + +Some of these errors are not correctly serialized/parsed - see line 322 in Agent/Transmission.hs + +- CMD e - command or response error + - PROHIBITED - server response sent as client command (and vice versa) + - SYNTAX - command is unknown or has invalid syntax. + - NO_CONN - connection is required in the command (and absent) + - SIZE - incorrect message size of messages (when parsing SEND and MSG) + - LARGE -- message does not fit SMP block +- CONN e - connection errors + - UNKNOWN - connection alias not in database + - DUPLICATE - connection alias already exists + - SIMPLEX - connection is simplex, but operation requires another queue +- SMP ErrorType - forwarding SMP errors (SMPServerError) to the agent client +- BROKER e - SMP server errors + - RESPONSE ErrorType - invalid SMP server response + - UNEXPECTED - unexpected response + - NETWORK - network TCP connection error + - TRANSPORT TransportError -- handshake or other transport error + - TIMEOUT - command response timeout +- AGENT e - errors of other agents + - A_MESSAGE - SMP message failed to parse + - A_PROHIBITED - SMP message is prohibited with the current queue status + - A_ENCRYPTION - cannot RSA/AES-decrypt or parse decrypted header +- INTERNAL ByteString - agent implementation or dependency error + +### SMPClientError (Client.hs) + +- SMPServerError ErrorType - this is correctly parsed server ERR response. This error is forwarded to the agent client as `ERR SMP err` +- SMPResponseError ErrorType - this is invalid server response that failed to parse - forwarded to the client as `ERR BROKER RESPONSE`. +- SMPUnexpectedResponse - different response from what is expected to a given command, e.g. server should respond `IDS` or `ERR` to `NEW` command, other responses would result in this error - forwarded to the client as `ERR BROKER UNEXPECTED`. +- SMPResponseTimeout - used for TCP connection and command response timeouts -> `ERR BROKER TIMEOUT`. +- SMPNetworkError - fails to establish TCP connection -> `ERR BROKER NETWORK` +- SMPTransportError e - fails connection handshake or some other transport error -> `ERR BROKER TRANSPORT e` +- SMPSignatureError C.CryptoError - error when cryptographically "signing" the command. + +### StoreError (Agent/Store.hs) + +- SEInternal ByteString - signals exceptions in store actions. +- SEConnNotFound - connection alias not found (or both queues absent). +- SEConnDuplicate - connection alias already used. +- SEBadConnType ConnType - wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. `updateRcvConnWithSndQueue` and `updateSndConnWithRcvQueue` do not allow duplex connections - they would also return this error. +- SEBadQueueStatus - the intention was to pass current expected queue status in methods, as we always know what it should be at any stage of the protocol, and in case it does not match use this error. **Currently not used**. +- SENotImplemented - used in `getMsg` that is not implemented/used. + +### CryptoError (Crypto.hs) + +- RSAEncryptError R.Error - RSA encryption error +- RSADecryptError R.Error - RSA decryption error +- RSASignError R.Error - RSA signature error +- AESCipherError CE.CryptoError - AES initialization error +- CryptoIVError - IV generation error +- AESDecryptError - AES decryption error +- CryptoLargeMsgError - message does not fit in SMP block +- CryptoHeaderError String - failure parsing RSA-encrypted message header + +### TransportError (Transport.hs) + + - TEBadBlock - error parsing block + - TEEncrypt - block encryption error + - TEDecrypt - block decryption error + - TEHandshake HandshakeError + +### HandshakeError (Transport.hs) + + - ENCRYPT - encryption error + - DECRYPT - decryption error + - VERSION - error parsing protocol version + - RSA_KEY - error parsing RSA key + - AES_KEYS - error parsing AES keys + - BAD_HASH - not matching RSA key hash + - MAJOR_VERSION - lower agent version than protocol version + - TERMINATED - transport terminated diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index ea33e6f2f..ac8686439 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -132,7 +132,7 @@ samplePublicKey = "256,ppr3DCweAD3RTVFhU2j0u+DnYdqJl1qCdKLHIKsPl1xBzfmnzK0o9GEDl syntaxTests :: Spec syntaxTests = do - it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR SYNTAX 11") + it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR CMD SYNTAX") describe "NEW" do describe "valid" do -- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided) @@ -143,9 +143,9 @@ syntaxTests = do it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> teshKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False describe "invalid" do -- TODO: add tests with defined connection alias - it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR SYNTAX 11") - it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR SYNTAX 11") - it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR SYNTAX 11") + it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR CMD SYNTAX") + it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR CMD SYNTAX") + it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR CMD SYNTAX") describe "JOIN" do describe "valid" do @@ -155,4 +155,4 @@ syntaxTests = do ("311", "", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "", "ERR SMP AUTH") describe "invalid" do -- TODO: JOIN is not merged yet - to be added - it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR SYNTAX 11") + it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX") diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index ef9f59e14..8f33d5965 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -66,14 +66,14 @@ storeTests = withStore do describe "setRcvQueueStatus" testSetRcvQueueStatus describe "setSndQueueStatus" testSetSndQueueStatus describe "DuplexConnection" testSetQueueStatusDuplex - xdescribe "RcvQueue doesn't exist" testSetRcvQueueStatusNoQueue - xdescribe "SndQueue doesn't exist" testSetSndQueueStatusNoQueue + xdescribe "RcvQueue does not exist" testSetRcvQueueStatusNoQueue + xdescribe "SndQueue does not exist" testSetSndQueueStatusNoQueue describe "createRcvMsg" do describe "RcvQueue exists" testCreateRcvMsg - describe "RcvQueue doesn't exist" testCreateRcvMsgNoQueue + describe "RcvQueue does not exist" testCreateRcvMsgNoQueue describe "createSndMsg" do describe "SndQueue exists" testCreateSndMsg - describe "SndQueue doesn't exist" testCreateSndMsgNoQueue + describe "SndQueue does not exist" testCreateSndMsgNoQueue testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = do @@ -175,7 +175,7 @@ testDeleteRcvConn = do `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testDeleteSndConn :: SpecWith SQLiteStore testDeleteSndConn = do @@ -188,7 +188,7 @@ testDeleteSndConn = do `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testDeleteDuplexConn :: SpecWith SQLiteStore testDeleteDuplexConn = do @@ -203,7 +203,7 @@ testDeleteDuplexConn = do `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = do @@ -298,15 +298,15 @@ testSetQueueStatusDuplex = do testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore testSetRcvQueueStatusNoQueue = do - it "should throw error on attempt to update status of nonexistent RcvQueue" $ \store -> do + it "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do setRcvQueueStatus store rcvQueue1 Confirmed - `throwsError` SEInternal + `throwsError` SEInternal "" testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore testSetSndQueueStatusNoQueue = do - it "should throw error on attempt to update status of nonexistent SndQueue" $ \store -> do + it "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do setSndQueueStatus store sndQueue1 Confirmed - `throwsError` SEInternal + `throwsError` SEInternal "" testCreateRcvMsg :: SpecWith SQLiteStore testCreateRcvMsg = do @@ -323,7 +323,7 @@ testCreateRcvMsgNoQueue = do it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEBadConn + `throwsError` SEConnNotFound createSndConn store sndQueue1 `returnsResult` () createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) @@ -344,7 +344,7 @@ testCreateSndMsgNoQueue = do it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEBadConn + `throwsError` SEConnNotFound createRcvConn store rcvQueue1 `returnsResult` () createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts diff --git a/tests/ProtocolErrorTests.hs b/tests/ProtocolErrorTests.hs new file mode 100644 index 000000000..82e4afbf7 --- /dev/null +++ b/tests/ProtocolErrorTests.hs @@ -0,0 +1,18 @@ +module ProtocolErrorTests where + +import Simplex.Messaging.Agent.Transmission (AgentErrorType, agentErrorTypeP, serializeAgentError) +import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Protocol (ErrorType, errorTypeP, serializeErrorType) +import Test.Hspec +import Test.Hspec.QuickCheck (modifyMaxSuccess) +import Test.QuickCheck + +protocolErrorTests :: Spec +protocolErrorTests = modifyMaxSuccess (const 1000) $ do + describe "errors parsing / serializing" $ do + it "should parse SMP protocol errors" . property $ \err -> + parseAll errorTypeP (serializeErrorType err) + == Right (err :: ErrorType) + it "should parse SMP agent errors" . property $ \err -> + parseAll agentErrorTypeP (serializeAgentError err) + == Right (err :: AgentErrorType) diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 5f7175871..9bf0bf187 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -6,23 +6,19 @@ module SMPAgentClient where -import Control.Monad import Control.Monad.IO.Unlift import Crypto.Random import Network.Socket (HostName, ServiceName) -import SMPClient (testPort, withSmpServer, withSmpServerThreadOn) +import SMPClient (serverBracket, testPort, withSmpServer, withSmpServerThreadOn) import Simplex.Messaging.Agent (runSMPAgentBlocking) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig) import Simplex.Messaging.Transport -import System.Timeout (timeout) import Test.Hspec import UnliftIO.Concurrent import UnliftIO.Directory -import qualified UnliftIO.Exception as E import UnliftIO.IO -import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar) agentTestHost :: HostName agentTestHost = "localhost" @@ -125,12 +121,10 @@ cfg = } withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> (ThreadId -> m a) -> m a -withSmpAgentThreadOn (port', db') f = do - started <- newEmptyTMVarIO - E.bracket - (forkIOWithUnmask ($ runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'})) - (liftIO . killThread >=> const (removeFile db')) - \x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x +withSmpAgentThreadOn (port', db') = + serverBracket + (\started -> runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'}) + (removeFile db') withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> m a -> m a withSmpAgentOn (port', db') = withSmpAgentThreadOn (port', db') . const diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 0e24a53f7..c6e6af28e 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -18,11 +18,11 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Server (runSMPServerBlocking) import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Transport -import System.Timeout (timeout) import Test.Hspec import UnliftIO.Concurrent import qualified UnliftIO.Exception as E -import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar) +import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar) +import UnliftIO.Timeout (timeout) testHost :: HostName testHost = "localhost" @@ -83,12 +83,23 @@ cfg = } withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a -withSmpServerThreadOn port f = do +withSmpServerThreadOn port = + serverBracket + (\started -> runSMPServerBlocking started cfg {tcpPort = port}) + (pure ()) + +serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a +serverBracket process afterProcess f = do started <- newEmptyTMVarIO E.bracket - (forkIOWithUnmask ($ runSMPServerBlocking started cfg {tcpPort = port})) - (liftIO . killThread) - \x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x + (forkIOWithUnmask ($ process started)) + (\t -> killThread t >> afterProcess >> waitFor started "stop") + (\t -> waitFor started "start" >> f t) + where + waitFor started s = + 5_000_000 `timeout` atomically (takeTMVar started) >>= \case + Nothing -> error $ "server did not " <> s + _ -> pure () withSmpServerOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> m a -> m a withSmpServerOn port = withSmpServerThreadOn port . const diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 24c6af342..50d3759cf 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -72,7 +72,7 @@ testCreateSecure = (ok4, OK) #== "replies OK when message acknowledged if no more messages" Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, "ACK") - (err6, ERR PROHIBITED) #== "replies ERR when message acknowledged without messages" + (err6, ERR NO_MSG) #== "replies ERR when message acknowledged without messages" (sPub, sKey) <- C.generateKeyPair rsaKeySize Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5 hello ") @@ -252,7 +252,7 @@ testSwitchSub = (msg3, "test3") #== "delivered to the 2nd TCP connection" Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, "ACK") - (err, ERR PROHIBITED) #== "rejects ACK from the 1st TCP connection" + (err, ERR NO_MSG) #== "rejects ACK from the 1st TCP connection" Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, "ACK") (ok3, OK) #== "accepts ACK from the 2nd TCP connection" @@ -263,18 +263,18 @@ testSwitchSub = syntaxTests :: Spec syntaxTests = do - it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR SYNTAX 2") + it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX") describe "NEW" do - it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR SYNTAX 2") - it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR SYNTAX 2") - it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR SYNTAX 3") - it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR SYNTAX 4") + it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR CMD SYNTAX") + it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR CMD SYNTAX") + it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR CMD NO_AUTH") + it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH") describe "KEY" do it "valid syntax" $ ("1234", "bcda", "12345678", "KEY 3,4567,4567") >#> ("", "bcda", "12345678", "ERR AUTH") - it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") - it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") - it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR SYNTAX 3") - it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR SYNTAX 3") + it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") + it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR CMD NO_AUTH") noParamsSyntaxTest "SUB" noParamsSyntaxTest "ACK" noParamsSyntaxTest "OFF" @@ -282,19 +282,19 @@ syntaxTests = do describe "SEND" do it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5 hello ") >#> ("", "cdab", "12345678", "ERR AUTH") it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11 hello there ") >#> ("", "dabc", "12345678", "ERR AUTH") - it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR SYNTAX 2") - it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR SYNTAX 5") - it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") - it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") - it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR SYNTAX 2") + it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") + it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR CMD NO_QUEUE") + it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") + it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") + it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") describe "PING" do it "valid syntax" $ ("", "abcd", "", "PING") >#> ("", "abcd", "", "PONG") describe "broker response not allowed" do - it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR PROHIBITED") + it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR CMD PROHIBITED") where noParamsSyntaxTest :: ByteString -> Spec noParamsSyntaxTest cmd = describe (B.unpack cmd) do it "valid syntax" $ ("1234", "abcd", "12345678", cmd) >#> ("", "abcd", "12345678", "ERR AUTH") - it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR SYNTAX 2") - it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR SYNTAX 3") - it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR SYNTAX 3") + it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR CMD NO_AUTH") diff --git a/tests/Test.hs b/tests/Test.hs index 06f495d99..988b734f6 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,5 +1,6 @@ import AgentTests import MarkdownTests +import ProtocolErrorTests import ServerTests import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive) import Test.Hspec @@ -9,6 +10,7 @@ main = do createDirectoryIfMissing False "tests/tmp" hspec $ do describe "SimpleX markdown" markdownTests + describe "Protocol errors" protocolErrorTests describe "SMP server" serverTests describe "SMP client agent" agentTests removeDirectoryRecursive "tests/tmp"