parameterize protocol by error type (#644)

This commit is contained in:
Evgeny Poberezkin
2023-02-17 20:46:01 +00:00
committed by GitHub
parent 2ae3100bed
commit 2ddfb044fc
12 changed files with 216 additions and 176 deletions
+39 -22
View File
@@ -4,6 +4,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -19,11 +20,10 @@
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# HLINT ignore "Use newtype instead of data" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
{-# HLINT ignore "Use newtype instead of data" #-}
-- |
-- Module : Simplex.Messaging.ProtocolEncoding
-- Copyright : (c) simplex.chat
@@ -52,6 +52,7 @@ module Simplex.Messaging.Protocol
SParty (..),
PartyI (..),
QueueIdsKeys (..),
ProtocolErrorType (..),
ErrorType (..),
CommandError (..),
Transmission,
@@ -224,7 +225,7 @@ deriving instance Show Cmd
type Transmission c = (CorrId, EntityId, c)
-- | signed parsed transmission, with original raw bytes and parsing error.
type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c))
type SignedTransmission e c = (Maybe C.ASignature, Signed, Transmission (Either e c))
type Signed = ByteString
@@ -874,6 +875,8 @@ type MsgId = ByteString
-- | SMP message body.
type MsgBody = ByteString
data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock
-- | Type for protocol errors.
data ErrorType
= -- | incorrect block format, encoding or signature size
@@ -944,16 +947,16 @@ transmissionP = do
command <- A.takeByteString
pure RawTransmission {signature, signed, sessId, corrId, entityId, command}
class (ProtocolEncoding msg, ProtocolEncoding (ProtoCommand msg), Show msg) => Protocol msg where
class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where
type ProtoCommand msg = cmd | cmd -> msg
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
protocolPing :: ProtoCommand msg
protocolError :: msg -> Maybe ErrorType
protocolError :: msg -> Maybe err
type ProtoServer msg = ProtocolServer (ProtoType msg)
instance Protocol BrokerMsg where
instance Protocol ErrorType BrokerMsg where
type ProtoCommand BrokerMsg = Cmd
type ProtoType BrokerMsg = 'PSMP
protocolClientHandshake = smpClientHandshake
@@ -962,13 +965,14 @@ instance Protocol BrokerMsg where
ERR e -> Just e
_ -> Nothing
class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where
class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where
type Tag msg
encodeProtocol :: Version -> msg -> ByteString
protocolP :: Version -> Tag msg -> Parser msg
checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg
fromProtocolError :: ProtocolErrorType -> err
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
instance PartyI p => ProtocolEncoding (Command p) where
instance PartyI p => ProtocolEncoding ErrorType (Command p) where
type Tag (Command p) = CommandTag p
encodeProtocol v = \case
NEW rKey dhKey auth_ -> case auth_ of
@@ -999,6 +1003,9 @@ instance PartyI p => ProtocolEncoding (Command p) where
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials (sig, _, queueId, _) cmd = case cmd of
-- NEW must have signature but NOT queue ID
NEW {}
@@ -1018,7 +1025,7 @@ instance PartyI p => ProtocolEncoding (Command p) where
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
| otherwise -> Right cmd
instance ProtocolEncoding Cmd where
instance ProtocolEncoding ErrorType Cmd where
type Tag Cmd = CmdTag
encodeProtocol v (Cmd _ c) = encodeProtocol v c
@@ -1048,9 +1055,12 @@ instance ProtocolEncoding Cmd where
PING_ -> pure PING
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
instance ProtocolEncoding BrokerMsg where
instance ProtocolEncoding ErrorType BrokerMsg where
type Tag BrokerMsg = BrokerMsgTag
encodeProtocol v = \case
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
@@ -1085,6 +1095,13 @@ instance ProtocolEncoding BrokerMsg where
ERR_ -> ERR <$> _smpP
PONG_ -> pure PONG
fromProtocolError = \case
PECmdSyntax -> CMD SYNTAX
PECmdUnknown -> CMD UNKNOWN
PESession -> SESSION
PEBlock -> BLOCK
{-# INLINE fromProtocolError #-}
checkCredentials (_, _, queueId, _) cmd = case cmd of
-- IDS response should not have queue ID
IDS _ -> Right cmd
@@ -1103,12 +1120,12 @@ _smpP :: Encoding a => Parser a
_smpP = A.space *> smpP
-- | Parse SMP protocol commands and broker messages
parseProtocol :: ProtocolEncoding msg => Version -> ByteString -> Either ErrorType msg
parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg
parseProtocol v s =
let (tag, params) = B.break (== ' ') s
in case decodeTag tag of
Just cmd -> parse (protocolP v cmd) (CMD SYNTAX) params
Nothing -> Left $ CMD UNKNOWN
Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params
Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
checkParty c = case testEquality (sParty @p) (sParty @p') of
@@ -1203,7 +1220,7 @@ tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t
tEncodeBatch :: Int -> ByteString -> ByteString
tEncodeBatch n s = lenEncode n `B.cons` s
encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString
encodeTransmission :: ProtocolEncoding e c => Version -> ByteString -> Transmission c -> ByteString
encodeTransmission v sessionId (CorrId corrId, queueId, command) =
smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command
@@ -1223,22 +1240,22 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b
eitherList = either (\e -> [Left e])
-- | Receive client and server transmissions (determined by `cmd` type).
tGet :: forall cmd c. (ProtocolEncoding cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission cmd))
tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd))
tGet th@THandle {sessionId, thVersion = v} = L.map (tDecodeParseValidate sessionId v) <$> tGetParse th
tDecodeParseValidate :: forall cmd. ProtocolEncoding cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission cmd
tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission err cmd
tDecodeParseValidate sessionId v = \case
Right RawTransmission {signature, signed, sessId, corrId, entityId, command}
| sessId == sessionId ->
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
| otherwise -> (Nothing, "", (CorrId corrId, "", Left SESSION))
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
Left _ -> tError ""
where
tError :: ByteString -> SignedTransmission cmd
tError corrId = (Nothing, "", (CorrId corrId, "", Left BLOCK))
tError :: ByteString -> SignedTransmission err cmd
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock))
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission cmd
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
tParseValidate signed t@(sig, corrId, entityId, command) =
let cmd = parseProtocol v command >>= checkCredentials t
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
in (sig, signed, (CorrId corrId, entityId, cmd))