mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-10 21:26:57 +00:00
parameterize protocol by error type (#644)
This commit is contained in:
committed by
GitHub
parent
2ae3100bed
commit
2ddfb044fc
@@ -4,6 +4,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -19,11 +20,10 @@
|
||||
{-# LANGUAGE TypeFamilyDependencies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
|
||||
-- |
|
||||
-- Module : Simplex.Messaging.ProtocolEncoding
|
||||
-- Copyright : (c) simplex.chat
|
||||
@@ -52,6 +52,7 @@ module Simplex.Messaging.Protocol
|
||||
SParty (..),
|
||||
PartyI (..),
|
||||
QueueIdsKeys (..),
|
||||
ProtocolErrorType (..),
|
||||
ErrorType (..),
|
||||
CommandError (..),
|
||||
Transmission,
|
||||
@@ -224,7 +225,7 @@ deriving instance Show Cmd
|
||||
type Transmission c = (CorrId, EntityId, c)
|
||||
|
||||
-- | signed parsed transmission, with original raw bytes and parsing error.
|
||||
type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c))
|
||||
type SignedTransmission e c = (Maybe C.ASignature, Signed, Transmission (Either e c))
|
||||
|
||||
type Signed = ByteString
|
||||
|
||||
@@ -874,6 +875,8 @@ type MsgId = ByteString
|
||||
-- | SMP message body.
|
||||
type MsgBody = ByteString
|
||||
|
||||
data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock
|
||||
|
||||
-- | Type for protocol errors.
|
||||
data ErrorType
|
||||
= -- | incorrect block format, encoding or signature size
|
||||
@@ -944,16 +947,16 @@ transmissionP = do
|
||||
command <- A.takeByteString
|
||||
pure RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
|
||||
class (ProtocolEncoding msg, ProtocolEncoding (ProtoCommand msg), Show msg) => Protocol msg where
|
||||
class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where
|
||||
type ProtoCommand msg = cmd | cmd -> msg
|
||||
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
|
||||
protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
protocolPing :: ProtoCommand msg
|
||||
protocolError :: msg -> Maybe ErrorType
|
||||
protocolError :: msg -> Maybe err
|
||||
|
||||
type ProtoServer msg = ProtocolServer (ProtoType msg)
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
instance Protocol ErrorType BrokerMsg where
|
||||
type ProtoCommand BrokerMsg = Cmd
|
||||
type ProtoType BrokerMsg = 'PSMP
|
||||
protocolClientHandshake = smpClientHandshake
|
||||
@@ -962,13 +965,14 @@ instance Protocol BrokerMsg where
|
||||
ERR e -> Just e
|
||||
_ -> Nothing
|
||||
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where
|
||||
type Tag msg
|
||||
encodeProtocol :: Version -> msg -> ByteString
|
||||
protocolP :: Version -> Tag msg -> Parser msg
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg
|
||||
fromProtocolError :: ProtocolErrorType -> err
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
|
||||
|
||||
instance PartyI p => ProtocolEncoding (Command p) where
|
||||
instance PartyI p => ProtocolEncoding ErrorType (Command p) where
|
||||
type Tag (Command p) = CommandTag p
|
||||
encodeProtocol v = \case
|
||||
NEW rKey dhKey auth_ -> case auth_ of
|
||||
@@ -999,6 +1003,9 @@ instance PartyI p => ProtocolEncoding (Command p) where
|
||||
|
||||
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (sig, _, queueId, _) cmd = case cmd of
|
||||
-- NEW must have signature but NOT queue ID
|
||||
NEW {}
|
||||
@@ -1018,7 +1025,7 @@ instance PartyI p => ProtocolEncoding (Command p) where
|
||||
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
instance ProtocolEncoding Cmd where
|
||||
instance ProtocolEncoding ErrorType Cmd where
|
||||
type Tag Cmd = CmdTag
|
||||
encodeProtocol v (Cmd _ c) = encodeProtocol v c
|
||||
|
||||
@@ -1048,9 +1055,12 @@ instance ProtocolEncoding Cmd where
|
||||
PING_ -> pure PING
|
||||
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
|
||||
|
||||
instance ProtocolEncoding BrokerMsg where
|
||||
instance ProtocolEncoding ErrorType BrokerMsg where
|
||||
type Tag BrokerMsg = BrokerMsgTag
|
||||
encodeProtocol v = \case
|
||||
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
|
||||
@@ -1085,6 +1095,13 @@ instance ProtocolEncoding BrokerMsg where
|
||||
ERR_ -> ERR <$> _smpP
|
||||
PONG_ -> pure PONG
|
||||
|
||||
fromProtocolError = \case
|
||||
PECmdSyntax -> CMD SYNTAX
|
||||
PECmdUnknown -> CMD UNKNOWN
|
||||
PESession -> SESSION
|
||||
PEBlock -> BLOCK
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (_, _, queueId, _) cmd = case cmd of
|
||||
-- IDS response should not have queue ID
|
||||
IDS _ -> Right cmd
|
||||
@@ -1103,12 +1120,12 @@ _smpP :: Encoding a => Parser a
|
||||
_smpP = A.space *> smpP
|
||||
|
||||
-- | Parse SMP protocol commands and broker messages
|
||||
parseProtocol :: ProtocolEncoding msg => Version -> ByteString -> Either ErrorType msg
|
||||
parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg
|
||||
parseProtocol v s =
|
||||
let (tag, params) = B.break (== ' ') s
|
||||
in case decodeTag tag of
|
||||
Just cmd -> parse (protocolP v cmd) (CMD SYNTAX) params
|
||||
Nothing -> Left $ CMD UNKNOWN
|
||||
Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params
|
||||
Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown
|
||||
|
||||
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
|
||||
checkParty c = case testEquality (sParty @p) (sParty @p') of
|
||||
@@ -1203,7 +1220,7 @@ tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t
|
||||
tEncodeBatch :: Int -> ByteString -> ByteString
|
||||
tEncodeBatch n s = lenEncode n `B.cons` s
|
||||
|
||||
encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission :: ProtocolEncoding e c => Version -> ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission v sessionId (CorrId corrId, queueId, command) =
|
||||
smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command
|
||||
|
||||
@@ -1223,22 +1240,22 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b
|
||||
eitherList = either (\e -> [Left e])
|
||||
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet :: forall cmd c. (ProtocolEncoding cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission cmd))
|
||||
tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd))
|
||||
tGet th@THandle {sessionId, thVersion = v} = L.map (tDecodeParseValidate sessionId v) <$> tGetParse th
|
||||
|
||||
tDecodeParseValidate :: forall cmd. ProtocolEncoding cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission cmd
|
||||
tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission err cmd
|
||||
tDecodeParseValidate sessionId v = \case
|
||||
Right RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
| sessId == sessionId ->
|
||||
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
|
||||
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
|
||||
| otherwise -> (Nothing, "", (CorrId corrId, "", Left SESSION))
|
||||
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
|
||||
Left _ -> tError ""
|
||||
where
|
||||
tError :: ByteString -> SignedTransmission cmd
|
||||
tError corrId = (Nothing, "", (CorrId corrId, "", Left BLOCK))
|
||||
tError :: ByteString -> SignedTransmission err cmd
|
||||
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock))
|
||||
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission cmd
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
|
||||
tParseValidate signed t@(sig, corrId, entityId, command) =
|
||||
let cmd = parseProtocol v command >>= checkCredentials t
|
||||
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
|
||||
in (sig, signed, (CorrId corrId, entityId, cmd))
|
||||
|
||||
Reference in New Issue
Block a user