mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-25 03:45:23 +00:00
binary SMP protocol encoding, split Command type to two types (#245)
* binary SMP protocol encoding (server tests fail) * use 1 byte for bytestring length when encoding/decoding * Encoding class, binary tags * update server tests * negotiate SMP version in client/server handshake * add version columns to queues and connections * split parsing SMP client commands and server responses to different functions * check uniqueness of protocol tags * split client commands and server responses/messages to separate types * update types in SMP client * remove pattern synonyms for SMP errors * simplify getHandshake * update SMP protocol encoding in protocol spec * encode time as a number of seconds (64-bit integer) since epoch
This commit is contained in:
committed by
GitHub
parent
5e3f66a4cb
commit
5e29e3698e
@@ -74,6 +74,7 @@ import Data.Maybe (isJust)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Data.Time.Clock
|
||||
import Data.Time.Clock.System (systemToUTCTime)
|
||||
import Database.SQLite.Simple (SQLError)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -83,10 +84,11 @@ import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
|
||||
import Simplex.Messaging.Client (SMPServerTransmission)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (MsgBody)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), currentSMPVersionStr, loadTLSServerParams, runTransportServer)
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), loadTLSServerParams, runTransportServer, simplexMQVersion)
|
||||
import Simplex.Messaging.Util (bshow, tryError, unlessM)
|
||||
import System.Random (randomR)
|
||||
import UnliftIO.Async (async, race_)
|
||||
@@ -114,7 +116,7 @@ runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertifica
|
||||
-- tlsServerParams not in env to avoid breaking functional api w/t key and certificate generation
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile agentCertificateFile agentPrivateKeyFile
|
||||
runTransportServer started tcpPort tlsServerParams $ \(h :: c) -> do
|
||||
liftIO . putLn h $ "Welcome to SMP agent v" <> currentSMPVersionStr
|
||||
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
c <- getAgentClient
|
||||
logConnection c True
|
||||
race_ (connectClient h c) (runAgentClient c)
|
||||
@@ -526,7 +528,7 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do
|
||||
-- TODO deduplicate with previously received
|
||||
msgBody <- agentCbDecrypt rcvDhSecret (C.cbNonce srvMsgId) msgBody'
|
||||
encMessage@SMP.EncMessage {emHeader = SMP.PubHeader v e2ePubKey} <-
|
||||
liftEither $ parse SMP.encMessageP (AGENT A_MESSAGE) msgBody
|
||||
liftEither $ parse smpP (AGENT A_MESSAGE) msgBody
|
||||
case e2eShared of
|
||||
Nothing -> do
|
||||
let e2eDhSecret = C.dh' e2ePubKey e2ePrivKey
|
||||
@@ -551,7 +553,7 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do
|
||||
-- note that there is no ACK sent here, it is sent with agent's user ACK command
|
||||
-- TODO add hash to other messages
|
||||
let msgHash = C.sha256Hash msg
|
||||
agentClientMsg prevMsgHash sndMsgId (srvMsgId, srvTs) body msgHash
|
||||
agentClientMsg prevMsgHash sndMsgId (srvMsgId, systemToUTCTime srvTs) body msgHash
|
||||
_ -> prohibited >> ack
|
||||
SMP.END -> do
|
||||
removeSubscription c connId
|
||||
@@ -577,7 +579,7 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do
|
||||
decryptAgentMessage e2eDhSecret SMP.EncMessage {emNonce, emBody} = do
|
||||
msg <- agentCbDecrypt e2eDhSecret emNonce emBody
|
||||
agentMessage <-
|
||||
liftEither $ clientToAgentMsg =<< parse SMP.clientMessageP (AGENT A_MESSAGE) msg
|
||||
liftEither $ clientToAgentMsg =<< parse smpP (AGENT A_MESSAGE) msg
|
||||
pure (msg, agentMessage)
|
||||
|
||||
smpConfirmation :: SMPConfirmation -> m ()
|
||||
|
||||
@@ -57,9 +57,11 @@ import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util (bshow, liftEitherError, liftError)
|
||||
import Simplex.Messaging.Version
|
||||
import UnliftIO.Exception (IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
@@ -371,8 +373,8 @@ agentCbEncrypt SndQueue {e2ePubKey, e2eDhSecret} msg = do
|
||||
liftEither . first cryptoError $
|
||||
C.cbEncrypt e2eDhSecret emNonce msg SMP.e2eEncMessageLength
|
||||
-- TODO per-queue client version
|
||||
let emHeader = SMP.PubHeader SMP.clientVersion e2ePubKey
|
||||
pure $ SMP.serializeEncMessage SMP.EncMessage {emHeader, emNonce, emBody}
|
||||
let emHeader = SMP.PubHeader (maxVersion SMP.smpClientVersion) e2ePubKey
|
||||
pure $ smpEncode SMP.EncMessage {emHeader, emNonce, emBody}
|
||||
|
||||
agentCbEncryptOnce :: AgentMonad m => C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncryptOnce dhRcvPubKey msg = do
|
||||
@@ -383,8 +385,8 @@ agentCbEncryptOnce dhRcvPubKey msg = do
|
||||
liftEither . first cryptoError $
|
||||
C.cbEncrypt e2eDhSecret emNonce msg SMP.e2eEncMessageLength
|
||||
-- TODO per-queue client version
|
||||
let emHeader = SMP.PubHeader SMP.clientVersion dhSndPubKey
|
||||
pure $ SMP.serializeEncMessage SMP.EncMessage {emHeader, emNonce, emBody}
|
||||
let emHeader = SMP.PubHeader (maxVersion SMP.smpClientVersion) dhSndPubKey
|
||||
pure $ smpEncode SMP.EncMessage {emHeader, emNonce, emBody}
|
||||
|
||||
agentCbDecrypt :: AgentMonad m => C.DhSecretX25519 -> C.CbNonce -> ByteString -> m ByteString
|
||||
agentCbDecrypt dhSecret nonce msg =
|
||||
|
||||
@@ -84,6 +84,7 @@ module Simplex.Messaging.Agent.Protocol
|
||||
serializeConnReq,
|
||||
serializeConnReq',
|
||||
serializeAgentError,
|
||||
serializeSmpErrorType,
|
||||
commandP,
|
||||
smpServerP,
|
||||
smpQueueUriP,
|
||||
@@ -92,6 +93,7 @@ module Simplex.Messaging.Agent.Protocol
|
||||
connReqP',
|
||||
msgIntegrityP,
|
||||
agentErrorTypeP,
|
||||
smpErrorTypeP,
|
||||
serializeQueueStatus,
|
||||
queueStatusT,
|
||||
|
||||
@@ -128,6 +130,7 @@ import Generic.Random (genericArbitraryU)
|
||||
import Network.HTTP.Types (parseSimpleQuery, renderSimpleQuery)
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol
|
||||
( ClientMessage (..),
|
||||
@@ -136,7 +139,6 @@ import Simplex.Messaging.Protocol
|
||||
MsgId,
|
||||
PrivHeader (..),
|
||||
SndPublicVerifyKey,
|
||||
serializeClientMessage,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTransportError, transportErrorP)
|
||||
@@ -317,7 +319,7 @@ data AMessage
|
||||
deriving (Show)
|
||||
|
||||
serializeAgentMessage :: AgentMessage -> ByteString
|
||||
serializeAgentMessage = serializeClientMessage . agentToClientMsg
|
||||
serializeAgentMessage = smpEncode . agentToClientMsg
|
||||
|
||||
agentToClientMsg :: AgentMessage -> ClientMessage
|
||||
agentToClientMsg = \case
|
||||
@@ -754,8 +756,8 @@ serializeMsgIntegrity = \case
|
||||
-- | SMP agent protocol error parser.
|
||||
agentErrorTypeP :: Parser AgentErrorType
|
||||
agentErrorTypeP =
|
||||
"SMP " *> (SMP <$> SMP.errorTypeP)
|
||||
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP)
|
||||
"SMP " *> (SMP <$> smpErrorTypeP)
|
||||
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> smpErrorTypeP)
|
||||
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
|
||||
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
|
||||
<|> parseRead2
|
||||
@@ -763,11 +765,19 @@ agentErrorTypeP =
|
||||
-- | Serialize SMP agent protocol error.
|
||||
serializeAgentError :: AgentErrorType -> ByteString
|
||||
serializeAgentError = \case
|
||||
SMP e -> "SMP " <> SMP.serializeErrorType e
|
||||
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e
|
||||
SMP e -> "SMP " <> serializeSmpErrorType e
|
||||
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> serializeSmpErrorType e
|
||||
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
|
||||
e -> bshow e
|
||||
|
||||
-- | SMP error parser.
|
||||
smpErrorTypeP :: Parser ErrorType
|
||||
smpErrorTypeP = "CMD " *> (SMP.CMD <$> parseRead1) <|> parseRead1
|
||||
|
||||
-- | Serialize SMP error.
|
||||
serializeSmpErrorType :: ErrorType -> ByteString
|
||||
serializeSmpErrorType = bshow
|
||||
|
||||
serializeBinary :: ByteString -> ByteString
|
||||
serializeBinary body = bshow (B.length body) <> "\n" <> body
|
||||
|
||||
|
||||
@@ -84,12 +84,12 @@ data SMPClient = SMPClient
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TVar (Map CorrId Request),
|
||||
sndQ :: TBQueue SentRawTransmission,
|
||||
rcvQ :: TBQueue (SignedTransmission (Command 'Broker)),
|
||||
rcvQ :: TBQueue (SignedTransmission BrokerMsg),
|
||||
msgQ :: TBQueue SMPServerTransmission
|
||||
}
|
||||
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type SMPServerTransmission = (SMPServer, RecipientId, Command 'Broker)
|
||||
type SMPServerTransmission = (SMPServer, RecipientId, BrokerMsg)
|
||||
|
||||
-- | SMP client configuration.
|
||||
data SMPClientConfig = SMPClientConfig
|
||||
@@ -118,7 +118,7 @@ data Request = Request
|
||||
responseVar :: TMVar Response
|
||||
}
|
||||
|
||||
type Response = Either SMPClientError (Command 'Broker)
|
||||
type Response = Either SMPClientError BrokerMsg
|
||||
|
||||
-- | Connects to 'SMPServer' using passed client configuration
|
||||
-- and queue for messages and notifications.
|
||||
@@ -185,12 +185,12 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing} msgQ dis
|
||||
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
|
||||
receive :: Transport c => SMPClient -> THandle c -> IO ()
|
||||
receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ
|
||||
receive SMPClient {rcvQ} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
ping :: SMPClient -> IO ()
|
||||
ping c = forever $ do
|
||||
threadDelay smpPing
|
||||
runExceptT $ sendSMPCommand c Nothing "" (ClientCmd SSender PING)
|
||||
runExceptT $ sendSMPCommand c Nothing "" PING
|
||||
|
||||
process :: SMPClient -> IO ()
|
||||
process SMPClient {rcvQ, sentCommands} = forever $ do
|
||||
@@ -211,7 +211,7 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing} msgQ dis
|
||||
Right r -> Right r
|
||||
else Left SMPUnexpectedResponse
|
||||
|
||||
sendMsg :: QueueId -> Either ErrorType (Command 'Broker) -> IO ()
|
||||
sendMsg :: QueueId -> Either ErrorType BrokerMsg -> IO ()
|
||||
sendMsg qId = \case
|
||||
Right cmd -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd)
|
||||
-- TODO send everything else to errQ and log in agent
|
||||
@@ -257,7 +257,7 @@ createSMPQueue ::
|
||||
RcvPublicDhKey ->
|
||||
ExceptT SMPClientError IO QueueIdsKeys
|
||||
createSMPQueue c rpKey rKey dhKey =
|
||||
sendSMPCommand c (Just rpKey) "" (ClientCmd SRecipient $ NEW rKey dhKey) >>= \case
|
||||
sendSMPCommand c (Just rpKey) "" (NEW rKey dhKey) >>= \case
|
||||
IDS qik -> pure qik
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
|
||||
@@ -266,7 +266,7 @@ createSMPQueue c rpKey rKey dhKey =
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue
|
||||
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueue c@SMPClient {smpServer, msgQ} rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId (ClientCmd SRecipient SUB) >>= \case
|
||||
sendSMPCommand c (Just rpKey) rId SUB >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} ->
|
||||
lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd)
|
||||
@@ -276,20 +276,20 @@ subscribeSMPQueue c@SMPClient {smpServer, msgQ} rpKey rId =
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueueNotifications = okSMPCommand $ ClientCmd SNotifier NSUB
|
||||
subscribeSMPQueueNotifications = okSMPCommand NSUB
|
||||
|
||||
-- | Secure the SMP queue by adding a sender public key.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO ()
|
||||
secureSMPQueue c rpKey rId senderKey = okSMPCommand (ClientCmd SRecipient $ KEY senderKey) c rpKey rId
|
||||
secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId
|
||||
|
||||
-- | Enable notifications for the queue for push notifications server.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command
|
||||
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> ExceptT SMPClientError IO NotifierId
|
||||
enableSMPQueueNotifications c rpKey rId notifierKey =
|
||||
sendSMPCommand c (Just rpKey) rId (ClientCmd SRecipient $ NKEY notifierKey) >>= \case
|
||||
sendSMPCommand c (Just rpKey) rId (NKEY notifierKey) >>= \case
|
||||
NID nId -> pure nId
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
|
||||
@@ -298,7 +298,7 @@ enableSMPQueueNotifications c rpKey rId notifierKey =
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message
|
||||
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgBody -> ExceptT SMPClientError IO ()
|
||||
sendSMPMessage c spKey sId msg =
|
||||
sendSMPCommand c spKey sId (ClientCmd SSender $ SEND msg) >>= \case
|
||||
sendSMPCommand c spKey sId (SEND msg) >>= \case
|
||||
OK -> pure ()
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
|
||||
@@ -307,7 +307,7 @@ sendSMPMessage c spKey sId msg =
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery
|
||||
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
ackSMPMessage c@SMPClient {smpServer, msgQ} rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId (ClientCmd SRecipient ACK) >>= \case
|
||||
sendSMPCommand c (Just rpKey) rId ACK >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} ->
|
||||
lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd)
|
||||
@@ -318,26 +318,26 @@ ackSMPMessage c@SMPClient {smpServer, msgQ} rpKey rId =
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
suspendSMPQueue = okSMPCommand $ ClientCmd SRecipient OFF
|
||||
suspendSMPQueue = okSMPCommand OFF
|
||||
|
||||
-- | Irreversibly delete SMP queue and all messages in it.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
deleteSMPQueue = okSMPCommand $ ClientCmd SRecipient DEL
|
||||
deleteSMPQueue = okSMPCommand DEL
|
||||
|
||||
okSMPCommand :: ClientCmd -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand cmd c pKey qId =
|
||||
sendSMPCommand c (Just pKey) qId cmd >>= \case
|
||||
OK -> return ()
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
|
||||
-- | Send any SMP command ('ClientCmd' type).
|
||||
-- | Send SMP command
|
||||
-- TODO sign all requests (SEND of SMP confirmation would be signed with the same key that is passed to the recipient)
|
||||
sendSMPCommand :: SMPClient -> Maybe C.APrivateSignKey -> QueueId -> ClientCmd -> ExceptT SMPClientError IO (Command 'Broker)
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
|
||||
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeout} pKey qId cmd = do
|
||||
corrId <- lift_ getNextCorrId
|
||||
t <- signTransmission $ serializeTransmission sessionId (corrId, qId, cmd)
|
||||
t <- signTransmission $ encodeTransmission sessionId (corrId, qId, cmd)
|
||||
ExceptT $ sendRecv corrId t
|
||||
where
|
||||
lift_ :: STM a -> ExceptT SMPClientError IO a
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
@@ -66,9 +67,6 @@ module Simplex.Messaging.Crypto
|
||||
serializePubKeyUri',
|
||||
strPubKeyP,
|
||||
strPubKeyUriP,
|
||||
encodeLenKey',
|
||||
encodeLenKey,
|
||||
binaryLenKeyP,
|
||||
encodePubKey,
|
||||
encodePubKey',
|
||||
binaryPubKeyP,
|
||||
@@ -116,7 +114,6 @@ module Simplex.Messaging.Crypto
|
||||
cbDecrypt,
|
||||
cbNonce,
|
||||
randomCbNonce,
|
||||
cbNonceP,
|
||||
|
||||
-- * SHA256 hash
|
||||
sha256Hash,
|
||||
@@ -168,7 +165,8 @@ import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import GHC.TypeLits (ErrorMessage (..), TypeError)
|
||||
import Network.Transport.Internal (decodeWord16, encodeWord16)
|
||||
import Simplex.Messaging.Parsers (base64P, base64UriP, blobFieldParser, parseAll, parseString, word16P)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (base64P, base64UriP, blobFieldParser, parseAll, parseString)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
type E2EEncryptionVersion = Word16
|
||||
@@ -428,37 +426,45 @@ dhSecret' (ADhSecret a s) = case testEquality a $ sAlgorithm @a of
|
||||
Just Refl -> Right s
|
||||
_ -> Left "bad DH secret algorithm"
|
||||
|
||||
-- | Class for all key types
|
||||
-- | Class for public key types
|
||||
class CryptoPublicKey k where
|
||||
toPubKey :: (forall a. AlgorithmI a => PublicKey a -> b) -> k -> b
|
||||
pubKey :: APublicKey -> Either String k
|
||||
|
||||
-- | X509 encoding of any public key.
|
||||
instance CryptoPublicKey APublicKey where
|
||||
toPubKey f (APublicKey _ k) = f k
|
||||
pubKey = Right
|
||||
|
||||
-- | X509 encoding of signature public key.
|
||||
instance CryptoPublicKey APublicVerifyKey where
|
||||
toPubKey f (APublicVerifyKey _ k) = f k
|
||||
pubKey (APublicKey a k) = case signatureAlgorithm a of
|
||||
Just Dict -> Right $ APublicVerifyKey a k
|
||||
_ -> Left "key does not support signature algorithms"
|
||||
|
||||
-- | X509 encoding of DH public key.
|
||||
instance CryptoPublicKey APublicDhKey where
|
||||
toPubKey f (APublicDhKey _ k) = f k
|
||||
pubKey (APublicKey a k) = case dhAlgorithm a of
|
||||
Just Dict -> Right $ APublicDhKey a k
|
||||
_ -> Left "key does not support DH algorithms"
|
||||
|
||||
-- | X509 encoding of 'PublicKey'.
|
||||
instance AlgorithmI a => CryptoPublicKey (PublicKey a) where
|
||||
toPubKey = id
|
||||
pubKey (APublicKey a k) = case testEquality a $ sAlgorithm @a of
|
||||
Just Refl -> Right k
|
||||
_ -> Left "bad key algorithm"
|
||||
|
||||
instance Encoding APublicVerifyKey where
|
||||
smpEncode k = smpEncode $ encodePubKey k
|
||||
smpP = parseAll binaryPubKeyP <$?> smpP
|
||||
|
||||
instance Encoding APublicDhKey where
|
||||
smpEncode k = smpEncode $ encodePubKey k
|
||||
smpP = parseAll binaryPubKeyP <$?> smpP
|
||||
|
||||
instance AlgorithmI a => Encoding (PublicKey a) where
|
||||
smpEncode k = smpEncode $ encodePubKey' k
|
||||
smpP = parseAll binaryPubKeyP <$?> smpP
|
||||
|
||||
-- | base64 X509 key encoding with algorithm prefix
|
||||
serializePubKey :: CryptoPublicKey k => k -> ByteString
|
||||
serializePubKey = toPubKey serializePubKey'
|
||||
@@ -499,24 +505,6 @@ strPublicKeyP_ b64P = do
|
||||
Just Refl -> pure k
|
||||
_ -> fail $ "public key algorithm " <> show a <> " does not match prefix"
|
||||
|
||||
encodeLenKey :: CryptoPublicKey k => k -> ByteString
|
||||
encodeLenKey = toPubKey encodeLenKey'
|
||||
{-# INLINE encodeLenKey #-}
|
||||
|
||||
-- | binary X509 key encoding with 2-bytes length prefix
|
||||
encodeLenKey' :: PublicKey a -> ByteString
|
||||
encodeLenKey' k =
|
||||
let s = encodePubKey' k
|
||||
len = fromIntegral $ B.length s
|
||||
in encodeWord16 len <> s
|
||||
{-# INLINE encodeLenKey' #-}
|
||||
|
||||
-- | binary X509 key parser with 2-bytes length prefix
|
||||
binaryLenKeyP :: CryptoPublicKey k => Parser k
|
||||
binaryLenKeyP = do
|
||||
len <- fromIntegral <$> word16P
|
||||
parseAll binaryPubKeyP <$?> A.take len
|
||||
|
||||
encodePubKey :: CryptoPublicKey pk => pk -> ByteString
|
||||
encodePubKey = toPubKey encodePubKey'
|
||||
{-# INLINE encodePubKey #-}
|
||||
@@ -926,8 +914,9 @@ cbNonce s
|
||||
randomCbNonce :: IO CbNonce
|
||||
randomCbNonce = CbNonce <$> getRandomBytes 24
|
||||
|
||||
cbNonceP :: Parser CbNonce
|
||||
cbNonceP = CbNonce <$> A.take 24
|
||||
instance Encoding CbNonce where
|
||||
smpEncode = unCbNonce
|
||||
smpP = CbNonce <$> A.take 24
|
||||
|
||||
xSalsa20 :: DhSecret X25519 -> ByteString -> ByteString -> (ByteString, ByteString)
|
||||
xSalsa20 (DhSecretX25519 shared) nonce msg = (rs, msg')
|
||||
|
||||
@@ -23,9 +23,9 @@ import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Word (Word32)
|
||||
import Network.Transport.Internal (encodeWord16, encodeWord32)
|
||||
import Simplex.Messaging.Crypto
|
||||
import Simplex.Messaging.Parsers (parseE, parseE', word16P, word32P)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (parseE, parseE')
|
||||
import Simplex.Messaging.Util (tryE)
|
||||
|
||||
data Ratchet a = Ratchet
|
||||
@@ -143,22 +143,16 @@ paddedHeaderLen = 128
|
||||
fullHeaderLen :: Int
|
||||
fullHeaderLen = paddedHeaderLen + authTagSize + ivSize @AES256
|
||||
|
||||
serializeMsgHeader' :: AlgorithmI a => MsgHeader a -> ByteString
|
||||
serializeMsgHeader' MsgHeader {msgVersion, msgLatestVersion, msgDHRs, msgPN, msgNs} =
|
||||
encodeWord16 msgVersion
|
||||
<> encodeWord16 msgLatestVersion
|
||||
<> encodeLenKey msgDHRs
|
||||
<> encodeWord32 msgPN
|
||||
<> encodeWord32 msgNs
|
||||
|
||||
msgHeaderP' :: AlgorithmI a => Parser (MsgHeader a)
|
||||
msgHeaderP' = do
|
||||
msgVersion <- word16P
|
||||
msgLatestVersion <- word16P
|
||||
msgDHRs <- binaryLenKeyP
|
||||
msgPN <- word32P
|
||||
msgNs <- word32P
|
||||
pure MsgHeader {msgVersion, msgLatestVersion, msgDHRs, msgPN, msgNs}
|
||||
instance AlgorithmI a => Encoding (MsgHeader a) where
|
||||
smpEncode MsgHeader {msgVersion, msgLatestVersion, msgDHRs, msgPN, msgNs} =
|
||||
smpEncode (msgVersion, msgLatestVersion, msgDHRs, msgPN, msgNs)
|
||||
smpP = do
|
||||
msgVersion <- smpP
|
||||
msgLatestVersion <- smpP
|
||||
msgDHRs <- smpP
|
||||
msgPN <- smpP
|
||||
msgNs <- smpP
|
||||
pure MsgHeader {msgVersion, msgLatestVersion, msgDHRs, msgPN, msgNs}
|
||||
|
||||
data EncHeader = EncHeader
|
||||
{ ehBody :: ByteString,
|
||||
@@ -213,7 +207,7 @@ rcEncrypt' rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcNs, rcAD} pa
|
||||
where
|
||||
-- header = HEADER(state.DHRs, state.PN, state.Ns)
|
||||
msgHeader =
|
||||
serializeMsgHeader'
|
||||
smpEncode
|
||||
MsgHeader
|
||||
{ msgVersion = rcVersion rc,
|
||||
msgLatestVersion = currentE2EVersion,
|
||||
@@ -352,7 +346,7 @@ rcDecrypt' rc@Ratchet {rcRcv, rcMKSkipped, rcAD} msg' = do
|
||||
decryptNextHeader hdr = (AdvanceRatchet,) <$> decryptHeader (rcNHKr rc) hdr
|
||||
decryptHeader k EncHeader {ehBody, ehAuthTag, ehIV} = do
|
||||
header <- decryptAEAD k ehIV rcAD ehBody ehAuthTag `catchE` \_ -> throwE CERatchetHeader
|
||||
parseE' CryptoHeaderError msgHeaderP' header
|
||||
parseE' CryptoHeaderError smpP header
|
||||
decryptMessage :: MessageKey -> EncMessage -> ExceptT CryptoError IO (Either CryptoError ByteString)
|
||||
decryptMessage (MessageKey mk iv) EncMessage {emHeader, emBody, emAuthTag} =
|
||||
-- DECRYPT(mk, ciphertext, CONCAT(AD, enc_header))
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
module Simplex.Messaging.Encoding (Encoding (..), Tail (..)) where
|
||||
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bits (shiftL, shiftR, (.|.))
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.ByteString.Internal (c2w, w2c)
|
||||
import Data.Int (Int64)
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Word (Word16, Word32)
|
||||
import Network.Transport.Internal (decodeWord16, decodeWord32, encodeWord16, encodeWord32)
|
||||
|
||||
class Encoding a where
|
||||
smpEncode :: a -> ByteString
|
||||
smpP :: Parser a
|
||||
|
||||
instance Encoding Char where
|
||||
smpEncode = B.singleton
|
||||
smpP = A.anyChar
|
||||
|
||||
instance Encoding Word16 where
|
||||
smpEncode = encodeWord16
|
||||
smpP = decodeWord16 <$> A.take 2
|
||||
|
||||
instance Encoding Word32 where
|
||||
smpEncode = encodeWord32
|
||||
smpP = decodeWord32 <$> A.take 4
|
||||
|
||||
instance Encoding Int64 where
|
||||
smpEncode i = w32 (i `shiftR` 32) <> w32 i
|
||||
smpP = do
|
||||
l <- w32P
|
||||
r <- w32P
|
||||
pure $ (l `shiftL` 32) .|. r
|
||||
|
||||
w32 :: Int64 -> ByteString
|
||||
w32 = smpEncode @Word32 . fromIntegral
|
||||
|
||||
w32P :: Parser Int64
|
||||
w32P = fromIntegral <$> smpP @Word32
|
||||
|
||||
-- ByteStrings are assumed no longer than 255 bytes
|
||||
instance Encoding ByteString where
|
||||
smpEncode s = B.cons (w2c len) s where len = fromIntegral $ B.length s
|
||||
smpP = A.take . fromIntegral . c2w =<< A.anyChar
|
||||
|
||||
newtype Tail = Tail {unTail :: ByteString}
|
||||
|
||||
instance Encoding Tail where
|
||||
smpEncode = unTail
|
||||
smpP = Tail <$> A.takeByteString
|
||||
|
||||
instance Encoding SystemTime where
|
||||
smpEncode = smpEncode . systemSeconds
|
||||
smpP = MkSystemTime <$> smpP <*> pure 0
|
||||
|
||||
instance (Encoding a, Encoding b) => Encoding (a, b) where
|
||||
smpEncode (a, b) = smpEncode a <> smpEncode b
|
||||
smpP = (,) <$> smpP <*> smpP
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c) => Encoding (a, b, c) where
|
||||
smpEncode (a, b, c) = smpEncode a <> smpEncode b <> smpEncode c
|
||||
smpP = (,,) <$> smpP <*> smpP <*> smpP
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c, Encoding d) => Encoding (a, b, c, d) where
|
||||
smpEncode (a, b, c, d) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d
|
||||
smpP = (,,,) <$> smpP <*> smpP <*> smpP <*> smpP
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e) => Encoding (a, b, c, d, e) where
|
||||
smpEncode (a, b, c, d, e) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e
|
||||
smpP = (,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP
|
||||
@@ -15,12 +15,10 @@ import Data.Char (isAlphaNum)
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Data.Time.ISO8601 (parseISO8601)
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.Word (Word16, Word32)
|
||||
import Database.SQLite.Simple (ResultError (..), SQLData (..))
|
||||
import Database.SQLite.Simple.FromField (FieldParser, returnError)
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Network.Transport.Internal (decodeWord16, decodeWord32)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
@@ -50,12 +48,6 @@ rawBase64UriP = A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
|
||||
tsISO8601P :: Parser UTCTime
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
|
||||
|
||||
word16P :: Parser Word16
|
||||
word16P = decodeWord16 <$> A.take 2
|
||||
|
||||
word32P :: Parser Word32
|
||||
word32P = decodeWord32 <$> A.take 4
|
||||
|
||||
parse :: Parser a -> e -> (ByteString -> Either e a)
|
||||
parse parser err = first (const err) . parseAll parser
|
||||
|
||||
|
||||
+341
-255
@@ -1,16 +1,18 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
@@ -28,23 +30,22 @@
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
module Simplex.Messaging.Protocol
|
||||
( -- * SMP protocol parameters
|
||||
clientVersion,
|
||||
smpClientVersion,
|
||||
maxMessageLength,
|
||||
e2eEncMessageLength,
|
||||
|
||||
-- * SMP protocol types
|
||||
Protocol,
|
||||
Command (..),
|
||||
CommandI (..),
|
||||
Party (..),
|
||||
ClientParty (..),
|
||||
Cmd (..),
|
||||
ClientCmd (..),
|
||||
BrokerMsg (..),
|
||||
SParty (..),
|
||||
PartyI (..),
|
||||
QueueIdsKeys (..),
|
||||
ErrorType (..),
|
||||
CommandError (..),
|
||||
Transmission,
|
||||
BrokerTransmission,
|
||||
SignedTransmission,
|
||||
SentRawTransmission,
|
||||
SignedRawTransmission,
|
||||
@@ -65,58 +66,48 @@ module Simplex.Messaging.Protocol
|
||||
SndPublicVerifyKey,
|
||||
NtfPrivateSignKey,
|
||||
NtfPublicVerifyKey,
|
||||
Encoded,
|
||||
MsgId,
|
||||
MsgBody,
|
||||
|
||||
-- * Parse and serialize
|
||||
serializeTransmission,
|
||||
serializeErrorType,
|
||||
encodeTransmission,
|
||||
transmissionP,
|
||||
errorTypeP,
|
||||
serializeEncMessage,
|
||||
encMessageP,
|
||||
serializeClientMessage,
|
||||
clientMessageP,
|
||||
encodeProtocol,
|
||||
|
||||
-- * TCP transport functions
|
||||
tPut,
|
||||
tGet,
|
||||
fromClient,
|
||||
fromServer,
|
||||
|
||||
-- * exports for tests
|
||||
CommandTag (..),
|
||||
BrokerMsgTag (..),
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Monad
|
||||
import Control.Applicative (optional)
|
||||
import Control.Monad.Except
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Constraint (Dict (..))
|
||||
import Data.Functor (($>))
|
||||
import Data.Kind
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.String
|
||||
import Data.Time.Clock
|
||||
import Data.Time.ISO8601
|
||||
import Data.Time.Clock.System (SystemTime)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word16)
|
||||
import GHC.Generics (Generic)
|
||||
import GHC.TypeLits (ErrorMessage (..), TypeError)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Network.Transport.Internal (encodeWord16)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Transport (THandle (..), Transport, TransportError (..), tGetBlock, tPutBlock)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
|
||||
clientVersion :: Word16
|
||||
clientVersion = 1
|
||||
smpClientVersion :: VersionRange
|
||||
smpClientVersion = mkVersionRange 1 1
|
||||
|
||||
maxMessageLength :: Int
|
||||
maxMessageLength = 15968
|
||||
@@ -124,19 +115,17 @@ maxMessageLength = 15968
|
||||
e2eEncMessageLength :: Int
|
||||
e2eEncMessageLength = 15842
|
||||
|
||||
-- | SMP protocol participants.
|
||||
data Party = Broker | Recipient | Sender | Notifier
|
||||
-- | SMP protocol clients
|
||||
data Party = Recipient | Sender | Notifier
|
||||
deriving (Show)
|
||||
|
||||
-- | Singleton types for SMP protocol participants.
|
||||
-- | Singleton types for SMP protocol clients
|
||||
data SParty :: Party -> Type where
|
||||
SBroker :: SParty Broker
|
||||
SRecipient :: SParty Recipient
|
||||
SSender :: SParty Sender
|
||||
SNotifier :: SParty Notifier
|
||||
|
||||
instance TestEquality SParty where
|
||||
testEquality SBroker SBroker = Just Refl
|
||||
testEquality SRecipient SRecipient = Just Refl
|
||||
testEquality SSender SSender = Just Refl
|
||||
testEquality SNotifier SNotifier = Just Refl
|
||||
@@ -146,35 +135,20 @@ deriving instance Show (SParty p)
|
||||
|
||||
class PartyI (p :: Party) where sParty :: SParty p
|
||||
|
||||
instance PartyI Broker where sParty = SBroker
|
||||
|
||||
instance PartyI Recipient where sParty = SRecipient
|
||||
|
||||
instance PartyI Sender where sParty = SSender
|
||||
|
||||
instance PartyI Notifier where sParty = SNotifier
|
||||
|
||||
data ClientParty = forall p. IsClient p => CP (SParty p)
|
||||
|
||||
deriving instance Show ClientParty
|
||||
|
||||
-- | Type for command or response of any participant.
|
||||
-- | Type for client command of any participant.
|
||||
data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p)
|
||||
|
||||
deriving instance Show Cmd
|
||||
|
||||
-- | Type for command or response of any participant.
|
||||
data ClientCmd = forall p. (PartyI p, IsClient p) => ClientCmd (SParty p) (Command p)
|
||||
|
||||
class CommandI c where
|
||||
serializeCommand :: c -> ByteString
|
||||
commandP :: Parser c
|
||||
|
||||
-- | Parsed SMP transmission without signature, size and session ID.
|
||||
type Transmission c = (CorrId, QueueId, c)
|
||||
|
||||
type BrokerTransmission = Transmission (Command Broker)
|
||||
|
||||
-- | signed parsed transmission, with original raw bytes and parsing error.
|
||||
type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c))
|
||||
|
||||
@@ -206,10 +180,10 @@ type SenderId = QueueId
|
||||
type NotifierId = QueueId
|
||||
|
||||
-- | SMP queue ID on the server.
|
||||
type QueueId = Encoded
|
||||
type QueueId = ByteString
|
||||
|
||||
-- | Parameterized type for SMP protocol commands from all participants.
|
||||
data Command (a :: Party) where
|
||||
-- | Parameterized type for SMP protocol commands from all clients.
|
||||
data Command (p :: Party) where
|
||||
-- SMP recipient commands
|
||||
NEW :: RcvPublicVerifyKey -> RcvPublicDhKey -> Command Recipient
|
||||
SUB :: Command Recipient
|
||||
@@ -223,33 +197,120 @@ data Command (a :: Party) where
|
||||
PING :: Command Sender
|
||||
-- SMP notification subscriber commands
|
||||
NSUB :: Command Notifier
|
||||
-- SMP broker commands (responses, messages, notifications)
|
||||
IDS :: QueueIdsKeys -> Command Broker
|
||||
MSG :: MsgId -> UTCTime -> MsgBody -> Command Broker
|
||||
NID :: NotifierId -> Command Broker
|
||||
NMSG :: Command Broker
|
||||
END :: Command Broker
|
||||
OK :: Command Broker
|
||||
ERR :: ErrorType -> Command Broker
|
||||
PONG :: Command Broker
|
||||
|
||||
deriving instance Show (Command a)
|
||||
deriving instance Show (Command p)
|
||||
|
||||
deriving instance Eq (Command a)
|
||||
deriving instance Eq (Command p)
|
||||
|
||||
type family IsClient p :: Constraint where
|
||||
IsClient Recipient = ()
|
||||
IsClient Sender = ()
|
||||
IsClient Notifier = ()
|
||||
IsClient p =
|
||||
(Int ~ Bool, TypeError (Text "Party " :<>: ShowType p :<>: Text " is not a Client"))
|
||||
data BrokerMsg where
|
||||
-- SMP broker messages (responses, client messages, notifications)
|
||||
IDS :: QueueIdsKeys -> BrokerMsg
|
||||
MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg
|
||||
NID :: NotifierId -> BrokerMsg
|
||||
NMSG :: BrokerMsg
|
||||
END :: BrokerMsg
|
||||
OK :: BrokerMsg
|
||||
ERR :: ErrorType -> BrokerMsg
|
||||
PONG :: BrokerMsg
|
||||
deriving (Eq, Show)
|
||||
|
||||
isClient :: SParty p -> Maybe (Dict (IsClient p))
|
||||
isClient = \case
|
||||
SRecipient -> Just Dict
|
||||
SSender -> Just Dict
|
||||
SNotifier -> Just Dict
|
||||
_ -> Nothing
|
||||
-- * SMP command tags
|
||||
|
||||
data CommandTag (p :: Party) where
|
||||
NEW_ :: CommandTag Recipient
|
||||
SUB_ :: CommandTag Recipient
|
||||
KEY_ :: CommandTag Recipient
|
||||
NKEY_ :: CommandTag Recipient
|
||||
ACK_ :: CommandTag Recipient
|
||||
OFF_ :: CommandTag Recipient
|
||||
DEL_ :: CommandTag Recipient
|
||||
SEND_ :: CommandTag Sender
|
||||
PING_ :: CommandTag Sender
|
||||
NSUB_ :: CommandTag Notifier
|
||||
|
||||
data CmdTag = forall p. PartyI p => CT (SParty p) (CommandTag p)
|
||||
|
||||
deriving instance Show (CommandTag p)
|
||||
|
||||
deriving instance Show CmdTag
|
||||
|
||||
data BrokerMsgTag
|
||||
= IDS_
|
||||
| MSG_
|
||||
| NID_
|
||||
| NMSG_
|
||||
| END_
|
||||
| OK_
|
||||
| ERR_
|
||||
| PONG_
|
||||
deriving (Show)
|
||||
|
||||
class ProtocolMsgTag t where
|
||||
decodeTag :: ByteString -> Maybe t
|
||||
|
||||
messageTagP :: ProtocolMsgTag t => Parser t
|
||||
messageTagP =
|
||||
maybe (fail "bad command") pure . decodeTag
|
||||
=<< (A.takeTill (== ' ') <* optional A.space)
|
||||
|
||||
instance PartyI p => Encoding (CommandTag p) where
|
||||
smpEncode = \case
|
||||
NEW_ -> "NEW"
|
||||
SUB_ -> "SUB"
|
||||
KEY_ -> "KEY"
|
||||
NKEY_ -> "NKEY"
|
||||
ACK_ -> "ACK"
|
||||
OFF_ -> "OFF"
|
||||
DEL_ -> "DEL"
|
||||
SEND_ -> "SEND"
|
||||
PING_ -> "PING"
|
||||
NSUB_ -> "NSUB"
|
||||
smpP = messageTagP
|
||||
|
||||
instance ProtocolMsgTag CmdTag where
|
||||
decodeTag = \case
|
||||
"NEW" -> Just $ CT SRecipient NEW_
|
||||
"SUB" -> Just $ CT SRecipient SUB_
|
||||
"KEY" -> Just $ CT SRecipient KEY_
|
||||
"NKEY" -> Just $ CT SRecipient NKEY_
|
||||
"ACK" -> Just $ CT SRecipient ACK_
|
||||
"OFF" -> Just $ CT SRecipient OFF_
|
||||
"DEL" -> Just $ CT SRecipient DEL_
|
||||
"SEND" -> Just $ CT SSender SEND_
|
||||
"PING" -> Just $ CT SSender PING_
|
||||
"NSUB" -> Just $ CT SNotifier NSUB_
|
||||
_ -> Nothing
|
||||
|
||||
instance Encoding CmdTag where
|
||||
smpEncode (CT _ t) = smpEncode t
|
||||
smpP = messageTagP
|
||||
|
||||
instance PartyI p => ProtocolMsgTag (CommandTag p) where
|
||||
decodeTag s = decodeTag s >>= (\(CT _ t) -> checkParty' t)
|
||||
|
||||
instance Encoding BrokerMsgTag where
|
||||
smpEncode = \case
|
||||
IDS_ -> "IDS"
|
||||
MSG_ -> "MSG"
|
||||
NID_ -> "NID"
|
||||
NMSG_ -> "NMSG"
|
||||
END_ -> "END"
|
||||
OK_ -> "OK"
|
||||
ERR_ -> "ERR"
|
||||
PONG_ -> "PONG"
|
||||
smpP = messageTagP
|
||||
|
||||
instance ProtocolMsgTag BrokerMsgTag where
|
||||
decodeTag = \case
|
||||
"IDS" -> Just IDS_
|
||||
"MSG" -> Just MSG_
|
||||
"NID" -> Just NID_
|
||||
"NMSG" -> Just NMSG_
|
||||
"END" -> Just END_
|
||||
"OK" -> Just OK_
|
||||
"ERR" -> Just ERR_
|
||||
"PONG" -> Just PONG_
|
||||
_ -> Nothing
|
||||
|
||||
-- | SMP message body format
|
||||
data EncMessage = EncMessage
|
||||
@@ -263,22 +324,18 @@ data PubHeader = PubHeader
|
||||
phE2ePubDhKey :: C.PublicKeyX25519
|
||||
}
|
||||
|
||||
serializePubHeader :: PubHeader -> ByteString
|
||||
serializePubHeader (PubHeader v k) = encodeWord16 v <> C.encodeLenKey' k
|
||||
instance Encoding PubHeader where
|
||||
smpEncode (PubHeader v k) = smpEncode (v, k)
|
||||
smpP = PubHeader <$> smpP <*> smpP
|
||||
|
||||
pubHeaderP :: Parser PubHeader
|
||||
pubHeaderP = PubHeader <$> word16P <*> C.binaryLenKeyP
|
||||
|
||||
serializeEncMessage :: EncMessage -> ByteString
|
||||
serializeEncMessage EncMessage {emHeader, emNonce, emBody} =
|
||||
serializePubHeader emHeader <> C.unCbNonce emNonce <> emBody
|
||||
|
||||
encMessageP :: Parser EncMessage
|
||||
encMessageP = do
|
||||
emHeader <- pubHeaderP
|
||||
emNonce <- C.cbNonceP
|
||||
emBody <- A.takeByteString
|
||||
pure EncMessage {emHeader, emNonce, emBody}
|
||||
instance Encoding EncMessage where
|
||||
smpEncode EncMessage {emHeader, emNonce, emBody} =
|
||||
smpEncode emHeader <> smpEncode emNonce <> emBody
|
||||
smpP = do
|
||||
emHeader <- smpP
|
||||
emNonce <- smpP
|
||||
emBody <- A.takeByteString
|
||||
pure EncMessage {emHeader, emNonce, emBody}
|
||||
|
||||
data ClientMessage = ClientMessage PrivHeader ByteString
|
||||
|
||||
@@ -286,26 +343,19 @@ data PrivHeader
|
||||
= PHConfirmation C.APublicVerifyKey
|
||||
| PHEmpty
|
||||
|
||||
serializePrivHeader :: PrivHeader -> ByteString
|
||||
serializePrivHeader = \case
|
||||
PHConfirmation k -> "K" <> C.encodeLenKey k
|
||||
PHEmpty -> " "
|
||||
instance Encoding PrivHeader where
|
||||
smpEncode = \case
|
||||
PHConfirmation k -> "K" <> smpEncode k
|
||||
PHEmpty -> " "
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'K' -> PHConfirmation <$> smpP
|
||||
' ' -> pure PHEmpty
|
||||
_ -> fail "invalid PrivHeader"
|
||||
|
||||
privHeaderP :: Parser PrivHeader
|
||||
privHeaderP =
|
||||
A.anyChar >>= \case
|
||||
'K' -> PHConfirmation <$> C.binaryLenKeyP
|
||||
' ' -> pure PHEmpty
|
||||
_ -> fail "invalid PrivHeader"
|
||||
|
||||
serializeClientMessage :: ClientMessage -> ByteString
|
||||
serializeClientMessage (ClientMessage h msg) = serializePrivHeader h <> msg
|
||||
|
||||
clientMessageP :: Parser ClientMessage
|
||||
clientMessageP = ClientMessage <$> privHeaderP <*> A.takeByteString
|
||||
|
||||
-- | Base-64 encoded string.
|
||||
type Encoded = ByteString
|
||||
instance Encoding ClientMessage where
|
||||
smpEncode (ClientMessage h msg) = smpEncode h <> msg
|
||||
smpP = ClientMessage <$> smpP <*> A.takeByteString
|
||||
|
||||
-- | Transmission correlation ID.
|
||||
newtype CorrId = CorrId {bs :: ByteString} deriving (Eq, Ord, Show)
|
||||
@@ -350,7 +400,7 @@ type NtfPrivateSignKey = C.APrivateSignKey
|
||||
type NtfPublicVerifyKey = C.APublicVerifyKey
|
||||
|
||||
-- | SMP message server ID.
|
||||
type MsgId = Encoded
|
||||
type MsgId = ByteString
|
||||
|
||||
-- | SMP message body.
|
||||
type MsgBody = ByteString
|
||||
@@ -379,8 +429,8 @@ data ErrorType
|
||||
|
||||
-- | SMP command error type.
|
||||
data CommandError
|
||||
= -- | server response sent from client or vice versa
|
||||
PROHIBITED
|
||||
= -- | unknown command
|
||||
UNKNOWN
|
||||
| -- | error parsing command
|
||||
SYNTAX
|
||||
| -- | transmission has no required credentials (signature or queue ID)
|
||||
@@ -398,141 +448,209 @@ instance Arbitrary CommandError where arbitrary = genericArbitraryU
|
||||
-- | SMP transmission parser.
|
||||
transmissionP :: Parser RawTransmission
|
||||
transmissionP = do
|
||||
signature <- segment
|
||||
signature <- smpP
|
||||
signed <- A.takeByteString
|
||||
either fail pure $ parseAll (trn signature signed) signed
|
||||
where
|
||||
segment = A.takeTill (== ' ') <* A.space
|
||||
trn signature signed = do
|
||||
sessId <- segment
|
||||
corrId <- segment
|
||||
queueId <- segment
|
||||
sessId <- smpP
|
||||
corrId <- smpP
|
||||
queueId <- smpP
|
||||
command <- A.takeByteString
|
||||
pure RawTransmission {signature, signed, sessId, corrId, queueId, command}
|
||||
|
||||
instance CommandI Cmd where
|
||||
serializeCommand (Cmd _ cmd) = serializeCommand cmd
|
||||
commandP =
|
||||
"NEW " *> newCmd
|
||||
<|> "IDS " *> idsResp
|
||||
<|> "SUB" $> Cmd SRecipient SUB
|
||||
<|> "KEY " *> keyCmd
|
||||
<|> "NKEY " *> nKeyCmd
|
||||
<|> "NID " *> nIdsResp
|
||||
<|> "ACK" $> Cmd SRecipient ACK
|
||||
<|> "OFF" $> Cmd SRecipient OFF
|
||||
<|> "DEL" $> Cmd SRecipient DEL
|
||||
<|> "SEND " *> sendCmd
|
||||
<|> "PING" $> Cmd SSender PING
|
||||
<|> "NSUB" $> Cmd SNotifier NSUB
|
||||
<|> "MSG " *> message
|
||||
<|> "NMSG" $> Cmd SBroker NMSG
|
||||
<|> "END" $> Cmd SBroker END
|
||||
<|> "OK" $> Cmd SBroker OK
|
||||
<|> "ERR " *> serverError
|
||||
<|> "PONG" $> Cmd SBroker PONG
|
||||
class Protocol msg where
|
||||
type Tag msg
|
||||
encodeProtocol :: msg -> ByteString
|
||||
protocolP :: Tag msg -> Parser msg
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg
|
||||
|
||||
instance PartyI p => Protocol (Command p) where
|
||||
type Tag (Command p) = CommandTag p
|
||||
encodeProtocol = \case
|
||||
NEW rKey dhKey -> e (NEW_, ' ', rKey, dhKey)
|
||||
SUB -> e SUB_
|
||||
KEY k -> e (KEY_, ' ', k)
|
||||
NKEY k -> e (NKEY_, ' ', k)
|
||||
ACK -> e ACK_
|
||||
OFF -> e OFF_
|
||||
DEL -> e DEL_
|
||||
SEND msg -> e (SEND_, ' ', Tail msg)
|
||||
PING -> e PING_
|
||||
NSUB -> e NSUB_
|
||||
where
|
||||
newCmd = Cmd SRecipient <$> (NEW <$> C.strPubKeyP <* A.space <*> C.strPubKeyP)
|
||||
idsResp = Cmd SBroker . IDS <$> qik
|
||||
qik = QIK <$> base64P <* A.space <*> base64P <* A.space <*> C.strPubKeyP
|
||||
nIdsResp = Cmd SBroker . NID <$> base64P
|
||||
keyCmd = Cmd SRecipient . KEY <$> C.strPubKeyP
|
||||
nKeyCmd = Cmd SRecipient . NKEY <$> C.strPubKeyP
|
||||
sendCmd = Cmd SSender . SEND <$> A.takeByteString
|
||||
message = do
|
||||
msgId <- base64P <* A.space
|
||||
ts <- tsISO8601P <* A.space
|
||||
Cmd SBroker . MSG msgId ts <$> A.takeByteString
|
||||
serverError = Cmd SBroker . ERR <$> errorTypeP
|
||||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
instance CommandI ClientCmd where
|
||||
serializeCommand (ClientCmd _ cmd) = serializeCommand cmd
|
||||
commandP = clientCmd <$?> commandP
|
||||
protocolP tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP (CT (sParty @p) tag)
|
||||
|
||||
checkCredentials (sig, _, queueId, _) cmd = case cmd of
|
||||
-- NEW must have signature but NOT queue ID
|
||||
NEW {}
|
||||
| isNothing sig -> Left $ CMD NO_AUTH
|
||||
| not (B.null queueId) -> Left $ CMD HAS_AUTH
|
||||
| otherwise -> Right cmd
|
||||
-- SEND must have queue ID, signature is not always required
|
||||
SEND _
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
| otherwise -> Right cmd
|
||||
-- PING must not have queue ID or signature
|
||||
PING
|
||||
| isNothing sig && B.null queueId -> Right cmd
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other client commands must have both signature and queue ID
|
||||
_
|
||||
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
instance Protocol Cmd where
|
||||
type Tag Cmd = CmdTag
|
||||
encodeProtocol (Cmd _ c) = encodeProtocol c
|
||||
|
||||
protocolP = \case
|
||||
CT SRecipient tag ->
|
||||
Cmd SRecipient <$> case tag of
|
||||
NEW_ -> NEW <$> _smpP <*> smpP
|
||||
SUB_ -> pure SUB
|
||||
KEY_ -> KEY <$> _smpP
|
||||
NKEY_ -> NKEY <$> _smpP
|
||||
ACK_ -> pure ACK
|
||||
OFF_ -> pure OFF
|
||||
DEL_ -> pure DEL
|
||||
CT SSender tag ->
|
||||
Cmd SSender <$> case tag of
|
||||
SEND_ -> SEND . unTail <$> _smpP
|
||||
PING_ -> pure PING
|
||||
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
|
||||
|
||||
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
type Tag BrokerMsg = BrokerMsgTag
|
||||
encodeProtocol = \case
|
||||
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
|
||||
MSG msgId ts msgBody -> e (MSG_, ' ', msgId, ts, Tail msgBody)
|
||||
NID nId -> e (NID_, ' ', nId)
|
||||
NMSG -> e NMSG_
|
||||
END -> e END_
|
||||
OK -> e OK_
|
||||
ERR err -> e (ERR_, ' ', err)
|
||||
PONG -> e PONG_
|
||||
where
|
||||
clientCmd :: Cmd -> Either String ClientCmd
|
||||
clientCmd (Cmd p cmd) = case isClient p of
|
||||
Just Dict -> Right (ClientCmd p cmd)
|
||||
_ -> Left "not a client command"
|
||||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
-- | Parse SMP command.
|
||||
parseCommand :: ByteString -> Either ErrorType Cmd
|
||||
parseCommand = parse commandP $ CMD SYNTAX
|
||||
protocolP = \case
|
||||
MSG_ -> MSG <$> _smpP <*> smpP <*> (unTail <$> smpP)
|
||||
IDS_ -> IDS <$> (QIK <$> _smpP <*> smpP <*> smpP)
|
||||
NID_ -> NID <$> _smpP
|
||||
NMSG_ -> pure NMSG
|
||||
END_ -> pure END
|
||||
OK_ -> pure OK
|
||||
ERR_ -> ERR <$> _smpP
|
||||
PONG_ -> pure PONG
|
||||
|
||||
instance PartyI p => CommandI (Command p) where
|
||||
commandP = command' <$?> commandP
|
||||
where
|
||||
command' :: Cmd -> Either String (Command p)
|
||||
command' (Cmd p cmd) = case testEquality p $ sParty @p of
|
||||
Just Refl -> Right cmd
|
||||
_ -> Left "bad command party"
|
||||
serializeCommand = \case
|
||||
NEW rKey dhKey -> B.unwords ["NEW", C.serializePubKey rKey, C.serializePubKey' dhKey]
|
||||
KEY sKey -> "KEY " <> C.serializePubKey sKey
|
||||
NKEY nKey -> "NKEY " <> C.serializePubKey nKey
|
||||
SUB -> "SUB"
|
||||
ACK -> "ACK"
|
||||
OFF -> "OFF"
|
||||
DEL -> "DEL"
|
||||
SEND msgBody -> "SEND " <> msgBody
|
||||
PING -> "PING"
|
||||
NSUB -> "NSUB"
|
||||
MSG msgId ts msgBody ->
|
||||
B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, msgBody]
|
||||
IDS (QIK rcvId sndId srvDh) ->
|
||||
B.unwords ["IDS", encode rcvId, encode sndId, C.serializePubKey' srvDh]
|
||||
NID nId -> "NID " <> encode nId
|
||||
ERR err -> "ERR " <> serializeErrorType err
|
||||
NMSG -> "NMSG"
|
||||
END -> "END"
|
||||
OK -> "OK"
|
||||
PONG -> "PONG"
|
||||
checkCredentials (_, _, queueId, _) cmd = case cmd of
|
||||
-- IDS response must not have queue ID
|
||||
IDS _ -> Right cmd
|
||||
-- ERR response does not always have queue ID
|
||||
ERR _ -> Right cmd
|
||||
-- PONG response must not have queue ID
|
||||
PONG
|
||||
| B.null queueId -> Right cmd
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other broker responses must have queue ID
|
||||
_
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
| otherwise -> Right cmd
|
||||
|
||||
-- | SMP error parser.
|
||||
errorTypeP :: Parser ErrorType
|
||||
errorTypeP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1
|
||||
_smpP :: Encoding a => Parser a
|
||||
_smpP = A.space *> smpP
|
||||
|
||||
-- | Serialize SMP error.
|
||||
serializeErrorType :: ErrorType -> ByteString
|
||||
serializeErrorType = bshow
|
||||
-- | Parse SMP protocol commands and broker messages
|
||||
parseProtocol :: (Protocol msg, ProtocolMsgTag (Tag msg)) => ByteString -> Either ErrorType msg
|
||||
parseProtocol s =
|
||||
let (tag, params) = B.break (== ' ') s
|
||||
in case decodeTag tag of
|
||||
Just cmd -> parse (protocolP cmd) (CMD SYNTAX) params
|
||||
Nothing -> Left $ CMD UNKNOWN
|
||||
|
||||
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
|
||||
checkParty c = case testEquality (sParty @p) (sParty @p') of
|
||||
Just Refl -> Right c
|
||||
Nothing -> Left "bad command party"
|
||||
|
||||
checkParty' :: forall t p p'. (PartyI p, PartyI p') => t p' -> Maybe (t p)
|
||||
checkParty' c = case testEquality (sParty @p) (sParty @p') of
|
||||
Just Refl -> Just c
|
||||
_ -> Nothing
|
||||
|
||||
instance Encoding ErrorType where
|
||||
smpEncode = \case
|
||||
BLOCK -> "BLOCK"
|
||||
SESSION -> "SESSION"
|
||||
CMD err -> "CMD " <> smpEncode err
|
||||
AUTH -> "AUTH"
|
||||
QUOTA -> "QUOTA"
|
||||
NO_MSG -> "NO_MSG"
|
||||
LARGE_MSG -> "LARGE_MSG"
|
||||
INTERNAL -> "INTERNAL"
|
||||
DUPLICATE_ -> "DUPLICATE_"
|
||||
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"BLOCK" -> pure BLOCK
|
||||
"SESSION" -> pure SESSION
|
||||
"CMD" -> CMD <$> _smpP
|
||||
"AUTH" -> pure AUTH
|
||||
"QUOTA" -> pure QUOTA
|
||||
"NO_MSG" -> pure NO_MSG
|
||||
"LARGE_MSG" -> pure LARGE_MSG
|
||||
"INTERNAL" -> pure INTERNAL
|
||||
"DUPLICATE_" -> pure DUPLICATE_
|
||||
_ -> fail "bad error type"
|
||||
|
||||
instance Encoding CommandError where
|
||||
smpEncode e = case e of
|
||||
UNKNOWN -> "UNKNOWN"
|
||||
SYNTAX -> "SYNTAX"
|
||||
NO_AUTH -> "NO_AUTH"
|
||||
HAS_AUTH -> "HAS_AUTH"
|
||||
NO_QUEUE -> "NO_QUEUE"
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"UNKNOWN" -> pure UNKNOWN
|
||||
"SYNTAX" -> pure SYNTAX
|
||||
"NO_AUTH" -> pure NO_AUTH
|
||||
"HAS_AUTH" -> pure HAS_AUTH
|
||||
"NO_QUEUE" -> pure NO_QUEUE
|
||||
_ -> fail "bad command error type"
|
||||
|
||||
-- | Send signed SMP transmission to TCP transport.
|
||||
tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
|
||||
tPut th (sig, t) = tPutBlock th $ C.serializeSignature sig <> " " <> t
|
||||
tPut th (sig, t) = tPutBlock th $ smpEncode (C.signatureBytes sig) <> t
|
||||
|
||||
serializeTransmission :: CommandI c => ByteString -> Transmission c -> ByteString
|
||||
serializeTransmission sessionId (CorrId corrId, queueId, command) =
|
||||
B.unwords [sessionId, corrId, encode queueId, serializeCommand command]
|
||||
|
||||
-- | Validate that it is an SMP client command, used with 'tGet' by 'Simplex.Messaging.Server'.
|
||||
fromClient :: Cmd -> Either ErrorType ClientCmd
|
||||
fromClient (Cmd p cmd) = case isClient p of
|
||||
Just Dict -> Right $ ClientCmd p cmd
|
||||
Nothing -> Left $ CMD PROHIBITED
|
||||
|
||||
-- | Validate that it is an SMP server command, used with 'tGet' by 'Simplex.Messaging.Client'.
|
||||
fromServer :: Cmd -> Either ErrorType (Command Broker)
|
||||
fromServer = \case
|
||||
Cmd SBroker cmd -> Right cmd
|
||||
_ -> Left $ CMD PROHIBITED
|
||||
encodeTransmission :: Protocol c => ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission sessionId (CorrId corrId, queueId, command) =
|
||||
smpEncode (sessionId, corrId, queueId) <> encodeProtocol command
|
||||
|
||||
-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding).
|
||||
tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmission)
|
||||
tGetParse th = (parseTransmission =<<) <$> tGetBlock th
|
||||
where
|
||||
parseTransmission = first (const TEBadBlock) . A.parseOnly transmissionP
|
||||
tGetParse th = (parse transmissionP TEBadBlock =<<) <$> tGetBlock th
|
||||
|
||||
-- | Receive client and server transmissions.
|
||||
--
|
||||
-- The first argument is used to limit allowed senders.
|
||||
-- 'fromClient' or 'fromServer' should be used here.
|
||||
tGet :: forall c m cmd. (Transport c, MonadIO m) => (Cmd -> Either ErrorType cmd) -> THandle c -> m (SignedTransmission cmd)
|
||||
tGet fromParty th@THandle {sessionId} = liftIO (tGetParse th) >>= decodeParseValidate
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet ::
|
||||
forall cmd c m.
|
||||
(Protocol cmd, ProtocolMsgTag (Tag cmd), Transport c, MonadIO m) =>
|
||||
THandle c ->
|
||||
m (SignedTransmission cmd)
|
||||
tGet th@THandle {sessionId} = liftIO (tGetParse th) >>= decodeParseValidate
|
||||
where
|
||||
decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd)
|
||||
decodeParseValidate = \case
|
||||
Right RawTransmission {signature, signed, sessId, corrId, queueId, command}
|
||||
| sessId == sessionId ->
|
||||
let decodedTransmission = liftM2 (,corrId,,command) (C.decodeSignature =<< decode signature) (decode queueId)
|
||||
let decodedTransmission = (,corrId,queueId,command) <$> C.decodeSignature signature
|
||||
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
|
||||
| otherwise -> pure (Nothing, "", (CorrId corrId, "", Left SESSION))
|
||||
Left _ -> tError ""
|
||||
@@ -542,37 +660,5 @@ tGet fromParty th@THandle {sessionId} = liftIO (tGetParse th) >>= decodeParseVal
|
||||
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> m (SignedTransmission cmd)
|
||||
tParseValidate signed t@(sig, corrId, queueId, command) = do
|
||||
let cmd = parseCommand command >>= tCredentials t >>= fromParty
|
||||
return (sig, signed, (CorrId corrId, queueId, cmd))
|
||||
|
||||
tCredentials :: SignedRawTransmission -> Cmd -> Either ErrorType Cmd
|
||||
tCredentials (sig, _, queueId, _) cmd = case cmd of
|
||||
-- IDS response must not have queue ID
|
||||
Cmd SBroker (IDS _) -> Right cmd
|
||||
-- ERR response does not always have queue ID
|
||||
Cmd SBroker (ERR _) -> Right cmd
|
||||
-- PONG response must not have queue ID
|
||||
Cmd SBroker PONG
|
||||
| B.null queueId -> Right cmd
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other responses must have queue ID
|
||||
Cmd SBroker _
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
| otherwise -> Right cmd
|
||||
-- NEW must have signature but NOT queue ID
|
||||
Cmd SRecipient NEW {}
|
||||
| isNothing sig -> Left $ CMD NO_AUTH
|
||||
| not (B.null queueId) -> Left $ CMD HAS_AUTH
|
||||
| otherwise -> Right cmd
|
||||
-- SEND must have queue ID, signature is not always required
|
||||
Cmd SSender (SEND _)
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
| otherwise -> Right cmd
|
||||
-- PING must not have queue ID or signature
|
||||
Cmd SSender PING
|
||||
| isNothing sig && B.null queueId -> Right cmd
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other client commands must have both signature and queue ID
|
||||
Cmd _ _
|
||||
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
let cmd = parseProtocol command >>= checkCredentials t
|
||||
pure (sig, signed, (CorrId corrId, queueId, cmd))
|
||||
|
||||
@@ -36,7 +36,7 @@ import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Time.Clock
|
||||
import Data.Time.Clock.System (getSystemTime)
|
||||
import Data.Type.Equality
|
||||
import Network.Socket (ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -154,7 +154,7 @@ cancelSub = \case
|
||||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
|
||||
receive th Client {rcvQ, sndQ} = forever $ do
|
||||
(sig, signed, (corrId, queueId, cmdOrError)) <- tGet fromClient th
|
||||
(sig, signed, (corrId, queueId, cmdOrError)) <- tGet th
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, queueId, ERR e)
|
||||
Right cmd -> do
|
||||
@@ -168,19 +168,19 @@ receive th Client {rcvQ, sndQ} = forever $ do
|
||||
send :: (Transport c, MonadUnliftIO m) => THandle c -> Client -> m ()
|
||||
send h Client {sndQ, sessionId} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
liftIO $ tPut h (Nothing, serializeTransmission sessionId t)
|
||||
liftIO $ tPut h (Nothing, encodeTransmission sessionId t)
|
||||
|
||||
verifyTransmission ::
|
||||
forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> ClientCmd -> m Bool
|
||||
forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m Bool
|
||||
verifyTransmission sig_ signed queueId cmd = do
|
||||
case cmd of
|
||||
ClientCmd SRecipient (NEW k _) -> pure $ verifySignature k
|
||||
ClientCmd SRecipient _ -> verifyCmd (CP SRecipient) $ verifySignature . recipientKey
|
||||
ClientCmd SSender (SEND _) -> verifyCmd (CP SSender) $ verifyMaybe . senderKey
|
||||
ClientCmd SSender PING -> pure True
|
||||
ClientCmd SNotifier NSUB -> verifyCmd (CP SNotifier) $ verifyMaybe . fmap snd . notifier
|
||||
Cmd SRecipient (NEW k _) -> pure $ verifySignature k
|
||||
Cmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey
|
||||
Cmd SSender (SEND _) -> verifyCmd SSender $ verifyMaybe . senderKey
|
||||
Cmd SSender PING -> pure True
|
||||
Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap snd . notifier
|
||||
where
|
||||
verifyCmd :: ClientParty -> (QueueRec -> Bool) -> m Bool
|
||||
verifyCmd :: SParty p -> (QueueRec -> Bool) -> m Bool
|
||||
verifyCmd party f = do
|
||||
st <- asks queueStore
|
||||
q <- atomically $ getQueue st party queueId
|
||||
@@ -217,16 +217,16 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
>>= processCommand
|
||||
>>= atomically . writeTBQueue sndQ
|
||||
where
|
||||
processCommand :: Transmission ClientCmd -> m BrokerTransmission
|
||||
processCommand :: Transmission Cmd -> m (Transmission BrokerMsg)
|
||||
processCommand (corrId, queueId, cmd) = do
|
||||
st <- asks queueStore
|
||||
case cmd of
|
||||
ClientCmd SSender command ->
|
||||
Cmd SSender command ->
|
||||
case command of
|
||||
SEND msgBody -> sendMessage st msgBody
|
||||
PING -> pure (corrId, "", PONG)
|
||||
ClientCmd SNotifier NSUB -> subscribeNotifications
|
||||
ClientCmd SRecipient command ->
|
||||
Cmd SNotifier NSUB -> subscribeNotifications
|
||||
Cmd SRecipient command ->
|
||||
case command of
|
||||
NEW rKey dhKey -> createQueue st rKey dhKey
|
||||
SUB -> subscribeQueue queueId
|
||||
@@ -236,7 +236,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
OFF -> suspendQueue_ st
|
||||
DEL -> delQueueAndMsgs st
|
||||
where
|
||||
createQueue :: QueueStore -> RcvPublicVerifyKey -> RcvPublicDhKey -> m BrokerTransmission
|
||||
createQueue :: QueueStore -> RcvPublicVerifyKey -> RcvPublicDhKey -> m (Transmission BrokerMsg)
|
||||
createQueue st recipientKey dhKey = do
|
||||
(rcvPublicDhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
let rcvDhSecret = C.dh' dhKey privDhKey
|
||||
@@ -254,7 +254,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
(corrId,queueId,) <$> addQueueRetry 3 qik qRec
|
||||
where
|
||||
addQueueRetry ::
|
||||
Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> m (Command 'Broker)
|
||||
Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> m BrokerMsg
|
||||
addQueueRetry 0 _ _ = pure $ ERR INTERNAL
|
||||
addQueueRetry n qik qRec = do
|
||||
ids@(rId, _) <- getIds
|
||||
@@ -268,7 +268,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
|
||||
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
|
||||
logCreateById s rId =
|
||||
atomically (getQueue st (CP SRecipient) rId) >>= \case
|
||||
atomically (getQueue st SRecipient rId) >>= \case
|
||||
Right q -> logCreateQueue s q
|
||||
_ -> pure ()
|
||||
|
||||
@@ -277,15 +277,15 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
n <- asks $ queueIdBytes . config
|
||||
liftM2 (,) (randomId n) (randomId n)
|
||||
|
||||
secureQueue_ :: QueueStore -> SndPublicVerifyKey -> m BrokerTransmission
|
||||
secureQueue_ :: QueueStore -> SndPublicVerifyKey -> m (Transmission BrokerMsg)
|
||||
secureQueue_ st sKey = do
|
||||
withLog $ \s -> logSecureQueue s queueId sKey
|
||||
atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey
|
||||
|
||||
addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> m BrokerTransmission
|
||||
addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> m (Transmission BrokerMsg)
|
||||
addQueueNotifier_ st nKey = (corrId,queueId,) <$> addNotifierRetry 3
|
||||
where
|
||||
addNotifierRetry :: Int -> m (Command 'Broker)
|
||||
addNotifierRetry :: Int -> m BrokerMsg
|
||||
addNotifierRetry 0 = pure $ ERR INTERNAL
|
||||
addNotifierRetry n = do
|
||||
nId <- randomId =<< asks (queueIdBytes . config)
|
||||
@@ -296,12 +296,12 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
withLog $ \s -> logAddNotifier s queueId nId nKey
|
||||
pure $ NID nId
|
||||
|
||||
suspendQueue_ :: QueueStore -> m BrokerTransmission
|
||||
suspendQueue_ :: QueueStore -> m (Transmission BrokerMsg)
|
||||
suspendQueue_ st = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
okResp <$> atomically (suspendQueue st queueId)
|
||||
|
||||
subscribeQueue :: RecipientId -> m BrokerTransmission
|
||||
subscribeQueue :: RecipientId -> m (Transmission BrokerMsg)
|
||||
subscribeQueue rId =
|
||||
atomically (getSubscription rId) >>= deliverMessage tryPeekMsg rId
|
||||
|
||||
@@ -316,7 +316,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
writeTVar subscriptions $ M.insert rId s subs
|
||||
return s
|
||||
|
||||
subscribeNotifications :: m BrokerTransmission
|
||||
subscribeNotifications :: m (Transmission BrokerMsg)
|
||||
subscribeNotifications = atomically $ do
|
||||
subs <- readTVar ntfSubscriptions
|
||||
when (isNothing $ M.lookup queueId subs) $ do
|
||||
@@ -324,7 +324,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
writeTVar ntfSubscriptions $ M.insert queueId () subs
|
||||
pure ok
|
||||
|
||||
acknowledgeMsg :: m BrokerTransmission
|
||||
acknowledgeMsg :: m (Transmission BrokerMsg)
|
||||
acknowledgeMsg =
|
||||
atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s))
|
||||
>>= \case
|
||||
@@ -334,14 +334,14 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a)
|
||||
withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId
|
||||
|
||||
sendMessage :: QueueStore -> MsgBody -> m BrokerTransmission
|
||||
sendMessage :: QueueStore -> MsgBody -> m (Transmission BrokerMsg)
|
||||
sendMessage st msgBody
|
||||
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
|
||||
| otherwise = do
|
||||
qr <- atomically $ getQueue st (CP SSender) queueId
|
||||
qr <- atomically $ getQueue st SSender queueId
|
||||
either (return . err) storeMessage qr
|
||||
where
|
||||
storeMessage :: QueueRec -> m BrokerTransmission
|
||||
storeMessage :: QueueRec -> m (Transmission BrokerMsg)
|
||||
storeMessage qr = case status qr of
|
||||
QueueOff -> return $ err AUTH
|
||||
QueueActive ->
|
||||
@@ -360,7 +360,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
mkMessage :: m (Either C.CryptoError Message)
|
||||
mkMessage = do
|
||||
msgId <- randomId =<< asks (msgIdBytes . config)
|
||||
ts <- liftIO getCurrentTime
|
||||
ts <- liftIO getSystemTime
|
||||
let c = C.cbEncrypt (rcvDhSecret qr) (C.cbNonce msgId) msgBody (maxMessageLength + 2)
|
||||
pure $ Message msgId ts <$> c
|
||||
|
||||
@@ -374,7 +374,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
unlessM (isFullTBQueue sndQ) $
|
||||
writeTBQueue q (CorrId "", nId, NMSG)
|
||||
|
||||
deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> Sub -> m BrokerTransmission
|
||||
deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> Sub -> m (Transmission BrokerMsg)
|
||||
deliverMessage tryPeek rId = \case
|
||||
Sub {subThread = NoSub} -> do
|
||||
ms <- asks msgStore
|
||||
@@ -406,10 +406,10 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
setDelivered :: STM (Maybe Bool)
|
||||
setDelivered = withSub rId $ \s -> tryPutTMVar (delivered s) ()
|
||||
|
||||
msgCmd :: Message -> Command 'Broker
|
||||
msgCmd :: Message -> BrokerMsg
|
||||
msgCmd Message {msgId, ts, msgBody} = MSG msgId ts msgBody
|
||||
|
||||
delQueueAndMsgs :: QueueStore -> m BrokerTransmission
|
||||
delQueueAndMsgs :: QueueStore -> m (Transmission BrokerMsg)
|
||||
delQueueAndMsgs st = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
ms <- asks msgStore
|
||||
@@ -418,13 +418,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
||||
Left e -> pure $ err e
|
||||
Right _ -> delMsgQueue ms queueId $> ok
|
||||
|
||||
ok :: BrokerTransmission
|
||||
ok :: Transmission BrokerMsg
|
||||
ok = (corrId, queueId, OK)
|
||||
|
||||
err :: ErrorType -> BrokerTransmission
|
||||
err :: ErrorType -> Transmission BrokerMsg
|
||||
err e = (corrId, queueId, ERR e)
|
||||
|
||||
okResp :: Either ErrorType () -> BrokerTransmission
|
||||
okResp :: Either ErrorType () -> Transmission BrokerMsg
|
||||
okResp = either err $ const ok
|
||||
|
||||
withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m ()
|
||||
@@ -432,7 +432,7 @@ withLog action = do
|
||||
env <- ask
|
||||
liftIO . mapM_ action $ storeLog (env :: Env)
|
||||
|
||||
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m Encoded
|
||||
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ByteString
|
||||
randomId n = do
|
||||
gVar <- asks idsDrg
|
||||
atomically (randomBytes n gVar)
|
||||
|
||||
@@ -56,8 +56,8 @@ data Server = Server
|
||||
data Client = Client
|
||||
{ subscriptions :: TVar (Map RecipientId Sub),
|
||||
ntfSubscriptions :: TVar (Map NotifierId ()),
|
||||
rcvQ :: TBQueue (Transmission ClientCmd),
|
||||
sndQ :: TBQueue BrokerTransmission,
|
||||
rcvQ :: TBQueue (Transmission Cmd),
|
||||
sndQ :: TBQueue (Transmission BrokerMsg),
|
||||
sessionId :: ByteString,
|
||||
connected :: TVar Bool
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
module Simplex.Messaging.Server.MsgStore where
|
||||
|
||||
import Data.Time.Clock
|
||||
import Data.Time.Clock.System (SystemTime)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Protocol (Encoded, MsgBody, RecipientId)
|
||||
import Simplex.Messaging.Protocol (MsgBody, MsgId, RecipientId)
|
||||
|
||||
data Message = Message
|
||||
{ msgId :: Encoded,
|
||||
ts :: UTCTime,
|
||||
{ msgId :: MsgId,
|
||||
ts :: SystemTime,
|
||||
msgBody :: MsgBody
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ data QueueStatus = QueueActive | QueueOff deriving (Eq, Show)
|
||||
|
||||
class MonadQueueStore s m where
|
||||
addQueue :: s -> QueueRec -> m (Either ErrorType ())
|
||||
getQueue :: s -> ClientParty -> QueueId -> m (Either ErrorType QueueRec)
|
||||
getQueue :: s -> SParty p -> QueueId -> m (Either ErrorType QueueRec)
|
||||
secureQueue :: s -> RecipientId -> SndPublicVerifyKey -> m (Either ErrorType QueueRec)
|
||||
addQueueNotifier :: s -> RecipientId -> NotifierId -> NtfPublicVerifyKey -> m (Either ErrorType QueueRec)
|
||||
suspendQueue :: s -> RecipientId -> m (Either ErrorType ())
|
||||
|
||||
@@ -42,8 +42,8 @@ instance MonadQueueStore QueueStore STM where
|
||||
}
|
||||
return $ Right ()
|
||||
|
||||
getQueue :: QueueStore -> ClientParty -> QueueId -> STM (Either ErrorType QueueRec)
|
||||
getQueue st (CP party) qId = do
|
||||
getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec)
|
||||
getQueue st party qId = do
|
||||
cs <- readTVar st
|
||||
pure $ case party of
|
||||
SRecipient -> getRcpQueue cs qId
|
||||
|
||||
@@ -27,6 +27,8 @@
|
||||
module Simplex.Messaging.Transport
|
||||
( -- * SMP transport parameters
|
||||
smpBlockSize,
|
||||
supportedSMPVersions,
|
||||
simplexMQVersion,
|
||||
|
||||
-- * Transport connection class
|
||||
Transport (..),
|
||||
@@ -55,7 +57,6 @@ module Simplex.Messaging.Transport
|
||||
tGetBlock,
|
||||
serializeTransportError,
|
||||
transportErrorP,
|
||||
currentSMPVersionStr,
|
||||
|
||||
-- * Trim trailing CR
|
||||
trimCR,
|
||||
@@ -68,7 +69,6 @@ import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except (throwE)
|
||||
import qualified Crypto.Store.X509 as SX
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bitraversable (bimapM)
|
||||
import Data.ByteString.Base64
|
||||
@@ -79,7 +79,7 @@ import Data.Default (def)
|
||||
import Data.Functor (($>))
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.String
|
||||
import Data.Word (Word16)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Data.X509.CertificateStore as XS
|
||||
import qualified Data.X509.Validation as XV
|
||||
@@ -91,8 +91,10 @@ import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import qualified Network.TLS.Extra as TE
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Parsers (parseAll, parseRead1, parseString)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (parse, parseRead1)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import Simplex.Messaging.Version
|
||||
import System.Exit (exitFailure)
|
||||
import System.IO.Error
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
@@ -101,9 +103,17 @@ import UnliftIO.Exception (Exception, IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
-- * Transport parameters
|
||||
|
||||
smpBlockSize :: Int
|
||||
smpBlockSize = 16384
|
||||
|
||||
supportedSMPVersions :: VersionRange
|
||||
supportedSMPVersions = mkVersionRange 1 1
|
||||
|
||||
simplexMQVersion :: String
|
||||
simplexMQVersion = "0.5.1"
|
||||
|
||||
-- * Transport connection class
|
||||
|
||||
class Transport c where
|
||||
@@ -379,43 +389,32 @@ trimCR s = if B.last s == '\r' then B.init s else s
|
||||
|
||||
-- * SMP transport
|
||||
|
||||
data SMPVersion = SMPVersion Int Int Int
|
||||
deriving (Eq, Ord)
|
||||
|
||||
instance IsString SMPVersion where
|
||||
fromString = parseString $ parseAll smpVersionP
|
||||
|
||||
currentSMPVersion :: SMPVersion
|
||||
currentSMPVersion = "0.5.1"
|
||||
|
||||
currentSMPVersionStr :: ByteString
|
||||
currentSMPVersionStr = serializeSMPVersion currentSMPVersion
|
||||
|
||||
serializeSMPVersion :: SMPVersion -> ByteString
|
||||
serializeSMPVersion (SMPVersion a b c) = B.intercalate "." [bshow a, bshow b, bshow c]
|
||||
|
||||
smpVersionP :: Parser SMPVersion
|
||||
smpVersionP =
|
||||
let ver = A.decimal <* A.char '.'
|
||||
in SMPVersion <$> ver <*> ver <*> A.decimal
|
||||
|
||||
-- | The handle for SMP encrypted transport connection over Transport .
|
||||
data THandle c = THandle
|
||||
{ connection :: c,
|
||||
sessionId :: ByteString,
|
||||
-- | agreed SMP server protocol version
|
||||
smpVersion :: Word16
|
||||
}
|
||||
|
||||
data ServerHandshake = ServerHandshake
|
||||
{ smpVersionRange :: VersionRange,
|
||||
sessionId :: ByteString
|
||||
}
|
||||
|
||||
data Handshake = Handshake
|
||||
{ sessionId :: ByteString,
|
||||
smpVersion :: SMPVersion
|
||||
newtype ClientHandshake = ClientHandshake
|
||||
{ -- | agreed SMP server protocol version
|
||||
smpVersion :: Word16
|
||||
}
|
||||
|
||||
serializeHandshake :: Handshake -> ByteString
|
||||
serializeHandshake Handshake {sessionId, smpVersion} =
|
||||
sessionId <> " " <> serializeSMPVersion smpVersion <> " "
|
||||
instance Encoding ClientHandshake where
|
||||
smpEncode ClientHandshake {smpVersion} = smpEncode smpVersion
|
||||
smpP = ClientHandshake <$> smpP
|
||||
|
||||
handshakeP :: Parser Handshake
|
||||
handshakeP = Handshake <$> A.takeWhile (/= ' ') <* A.space <*> smpVersionP <* A.space
|
||||
instance Encoding ServerHandshake where
|
||||
smpEncode ServerHandshake {smpVersionRange, sessionId} =
|
||||
smpEncode (smpVersionRange, sessionId)
|
||||
smpP = ServerHandshake <$> smpP <*> smpP
|
||||
|
||||
-- | Error of SMP encrypted transport over TCP.
|
||||
data TransportError
|
||||
@@ -473,12 +472,14 @@ tGetBlock THandle {connection = c} =
|
||||
-- | Server SMP transport handshake.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
serverHandshake :: Transport c => c -> ExceptT TransportError IO (THandle c)
|
||||
serverHandshake :: forall c. Transport c => c -> ExceptT TransportError IO (THandle c)
|
||||
serverHandshake c = do
|
||||
let th@THandle {sessionId} = tHandle c
|
||||
_ <- getPeerHello th
|
||||
sendHelloToPeer th sessionId
|
||||
pure th
|
||||
sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = supportedSMPVersions}
|
||||
ClientHandshake smpVersion <- getHandshake th
|
||||
if smpVersion `isCompatible` supportedSMPVersions
|
||||
then pure (th :: THandle c) {smpVersion}
|
||||
else throwE $ TEHandshake VERSION
|
||||
|
||||
-- | Client SMP transport handshake.
|
||||
--
|
||||
@@ -486,23 +487,21 @@ serverHandshake c = do
|
||||
clientHandshake :: forall c. Transport c => c -> ExceptT TransportError IO (THandle c)
|
||||
clientHandshake c = do
|
||||
let th@THandle {sessionId} = tHandle c
|
||||
sendHelloToPeer th ""
|
||||
Handshake {sessionId = sessId} <- getPeerHello th
|
||||
ServerHandshake {sessionId = sessId, smpVersionRange} <- getHandshake th
|
||||
if sessionId == sessId
|
||||
then pure th
|
||||
then case smpVersionRange `compatibleVersion` supportedSMPVersions of
|
||||
Just smpVersion -> do
|
||||
sendHandshake th $ ClientHandshake smpVersion
|
||||
pure (th :: THandle c) {smpVersion}
|
||||
Nothing -> throwE $ TEHandshake VERSION
|
||||
else throwE TEBadSession
|
||||
|
||||
sendHelloToPeer :: Transport c => THandle c -> ByteString -> ExceptT TransportError IO ()
|
||||
sendHelloToPeer th sessionId =
|
||||
let handshake = Handshake {sessionId, smpVersion = currentSMPVersion}
|
||||
in ExceptT . tPutBlock th $ serializeHandshake handshake
|
||||
sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO ()
|
||||
sendHandshake th = ExceptT . tPutBlock th . smpEncode
|
||||
|
||||
getPeerHello :: Transport c => THandle c -> ExceptT TransportError IO Handshake
|
||||
getPeerHello th = ExceptT $ (parseHandshake =<<) <$> tGetBlock th
|
||||
where
|
||||
parseHandshake :: ByteString -> Either TransportError Handshake
|
||||
parseHandshake = first (const $ TEHandshake PARSE) . A.parseOnly handshakeP
|
||||
getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp
|
||||
getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock th
|
||||
|
||||
tHandle :: Transport c => c -> THandle c
|
||||
tHandle c =
|
||||
THandle {connection = c, sessionId = encode $ tlsUnique c}
|
||||
THandle {connection = c, sessionId = tlsUnique c, smpVersion = 0}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
module Simplex.Messaging.Version
|
||||
( VersionRange (minVersion, maxVersion),
|
||||
pattern VersionRange,
|
||||
mkVersionRange,
|
||||
versionRange,
|
||||
compatibleVersion,
|
||||
isCompatible,
|
||||
)
|
||||
where
|
||||
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Encoding
|
||||
|
||||
pattern VersionRange :: Word16 -> Word16 -> VersionRange
|
||||
pattern VersionRange v1 v2 <- VRange v1 v2
|
||||
|
||||
{-# COMPLETE VersionRange #-}
|
||||
|
||||
data VersionRange = VRange
|
||||
{ minVersion :: Word16,
|
||||
maxVersion :: Word16
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | construct valid version range, to be used in constants
|
||||
mkVersionRange :: Word16 -> Word16 -> VersionRange
|
||||
mkVersionRange v1 v2
|
||||
| v1 <= v2 = VRange v1 v2
|
||||
| otherwise = error "invalid version range"
|
||||
|
||||
versionRange :: Word16 -> Word16 -> Maybe VersionRange
|
||||
versionRange v1 v2
|
||||
| v1 <= v2 = Just $ VRange v1 v2
|
||||
| otherwise = Nothing
|
||||
|
||||
instance Encoding VersionRange where
|
||||
smpEncode (VRange v1 v2) = smpEncode (v1, v2)
|
||||
smpP =
|
||||
maybe (fail "invalid version range") pure
|
||||
=<< versionRange <$> smpP <*> smpP
|
||||
|
||||
compatibleVersion :: VersionRange -> VersionRange -> Maybe Word16
|
||||
compatibleVersion (VersionRange min1 max1) (VersionRange min2 max2)
|
||||
| min1 <= max2 && min2 <= max1 = Just $ min max1 max2
|
||||
| otherwise = Nothing
|
||||
|
||||
isCompatible :: Word16 -> VersionRange -> Bool
|
||||
isCompatible v (VersionRange v1 v2) = v1 <= v && v <= v2
|
||||
Reference in New Issue
Block a user