From cf3d0dfdc376ba3eb182b6add10cb54174998695 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Wed, 15 Dec 2021 08:06:34 +0000 Subject: [PATCH] Transaction fields for size, session IDs, refactor (#222) * add SMP session IDs/tls-unique to transmission * refactor SMP transmissions: precise transmission types in server & client * use correct session IDs * remove TSession --- protocol/simplex-messaging.md | 3 +- src/Simplex/Messaging/Client.hs | 73 +++--- src/Simplex/Messaging/Protocol.hs | 310 ++++++++++++++---------- src/Simplex/Messaging/Server.hs | 138 +++++------ src/Simplex/Messaging/Server/Env/STM.hs | 13 +- src/Simplex/Messaging/Transport.hs | 14 ++ tests/SMPClient.hs | 8 +- tests/ServerTests.hs | 62 ++--- tests/Test.hs | 2 +- 9 files changed, 351 insertions(+), 272 deletions(-) diff --git a/protocol/simplex-messaging.md b/protocol/simplex-messaging.md index fb553964a..d4ea09208 100644 --- a/protocol/simplex-messaging.md +++ b/protocol/simplex-messaging.md @@ -353,7 +353,8 @@ Commands syntax below is provided using [ABNF][8] with [case-sensitive strings e Each transmission between the client and the server must have this format/syntax (after the decryption): ```abnf -transmission = [signature] SP signed SP pad ; pad to the fixed block size +transmission = [signature] SP signedSize SP signed SP pad ; pad to the fixed block size +signedSize = 1*DIGIT signed = sessionIdentifier SP [corrId] SP [queueId] SP cmd ; corrId is required in client commands and server responses, ; corrId is empty in server notifications. cmd = ping / recipientCmd / send / subscribeNotifications / serverMsg diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 974d39f5f..34b4df6eb 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -64,8 +64,8 @@ import Numeric.Natural import Simplex.Messaging.Agent.Protocol (SMPServer (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol -import Simplex.Messaging.Transport (ATransport (..), THandle (..), TLS, TProxy, Transport (..), TransportError, clientHandshake, runTransportClient) -import Simplex.Messaging.Transport.WebSockets (WS) +import Simplex.Messaging.Transport (ATransport (..), SessionId (..), THandle (..), TLS, TProxy, Transport (..), TransportError, clientHandshake, runTransportClient) +-- import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Util (bshow, liftError, raceAny_) import System.Timeout (timeout) @@ -78,12 +78,14 @@ import System.Timeout (timeout) data SMPClient = SMPClient { action :: Async (), connected :: TVar Bool, + sndSessionId :: SessionId, + rcvSessionId :: SessionId, smpServer :: SMPServer, tcpTimeout :: Int, clientCorrId :: TVar Natural, sentCommands :: TVar (Map CorrId Request), sndQ :: TBQueue SentRawTransmission, - rcvQ :: TBQueue SignedTransmissionOrError, + rcvQ :: TBQueue (SignedTransmission (Command 'Broker)), msgQ :: TBQueue SMPServerTransmission, blockSize :: Int } @@ -126,7 +128,7 @@ data Request = Request responseVar :: TMVar Response } -type Response = Either SMPClientError Cmd +type Response = Either SMPClientError (Command 'Broker) -- | Connects to 'SMPServer' using passed client configuration -- and queue for messages and notifications. @@ -147,6 +149,8 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing, smpBlock return SMPClient { action = undefined, + sndSessionId = undefined, + rcvSessionId = undefined, blockSize = undefined, connected, smpServer, @@ -167,7 +171,8 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing, smpBlock `finally` atomically (putTMVar thVar $ Left SMPNetworkError) bSize <- tcpTimeout `timeout` atomically (takeTMVar thVar) pure $ case bSize of - Just (Right blockSize) -> Right c {action, blockSize} + Just (Right THandle {sndSessionId, rcvSessionId, blockSize}) -> + Right c {action, sndSessionId, rcvSessionId, blockSize} Just (Left e) -> Left e Nothing -> Left SMPNetworkError @@ -177,14 +182,14 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing, smpBlock -- Just "80" -> ("80", transport @WS) Just p -> (p, transport @TLS) - client :: forall c. Transport c => TProxy c -> SMPClient -> TMVar (Either SMPClientError Int) -> c -> IO () + client :: forall c. Transport c => TProxy c -> SMPClient -> TMVar (Either SMPClientError (THandle c)) -> c -> IO () client _ c thVar h = runExceptT (clientHandshake h smpBlockSize $ keyHash smpServer) >>= \case Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e Right th -> do atomically $ do writeTVar (connected c) True - putTMVar thVar . Right $ blockSize (th :: THandle c) + putTMVar thVar $ Right th raceAny_ [send c th, process c, receive c th, ping c] `finally` disconnected @@ -197,11 +202,11 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing, smpBlock ping :: SMPClient -> IO () ping c = forever $ do threadDelay smpPing - runExceptT $ sendSMPCommand c Nothing "" (Cmd SSender PING) + runExceptT $ sendSMPCommand c Nothing "" (ClientCmd SSender PING) process :: SMPClient -> IO () process SMPClient {rcvQ, sentCommands} = forever $ do - (_, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ + (_, _, (_, corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ if B.null $ bs corrId then sendMsg qId respOrErr else do @@ -214,13 +219,13 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing, smpBlock if queueId == qId then case respOrErr of Left e -> Left $ SMPResponseError e - Right (Cmd _ (ERR e)) -> Left $ SMPServerError e + Right (ERR e) -> Left $ SMPServerError e Right r -> Right r else Left SMPUnexpectedResponse - sendMsg :: QueueId -> Either ErrorType Cmd -> IO () + sendMsg :: QueueId -> Either ErrorType (Command 'Broker) -> IO () sendMsg qId = \case - Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd) + Right cmd -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd) -- TODO send everything else to errQ and log in agent _ -> return () @@ -265,8 +270,8 @@ createSMPQueue :: ExceptT SMPClientError IO QueueIdsKeys createSMPQueue c rpKey rKey dhKey = -- TODO add signing this request too - requires changes in the server - sendSMPCommand c (Just rpKey) "" (Cmd SRecipient $ NEW rKey dhKey) >>= \case - Cmd _ (IDS qik) -> pure qik + sendSMPCommand c (Just rpKey) "" (ClientCmd SRecipient $ NEW rKey dhKey) >>= \case + IDS qik -> pure qik _ -> throwE SMPUnexpectedResponse -- | Subscribe to the SMP queue. @@ -274,9 +279,9 @@ createSMPQueue c rpKey rKey dhKey = -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO () subscribeSMPQueue c@SMPClient {smpServer, msgQ} rpKey rId = - sendSMPCommand c (Just rpKey) rId (Cmd SRecipient SUB) >>= \case - Cmd _ OK -> return () - Cmd _ cmd@MSG {} -> + sendSMPCommand c (Just rpKey) rId (ClientCmd SRecipient SUB) >>= \case + OK -> return () + cmd@MSG {} -> lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd) _ -> throwE SMPUnexpectedResponse @@ -284,21 +289,21 @@ subscribeSMPQueue c@SMPClient {smpServer, msgQ} rpKey rId = -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO () -subscribeSMPQueueNotifications = okSMPCommand $ Cmd SNotifier NSUB +subscribeSMPQueueNotifications = okSMPCommand $ ClientCmd SNotifier NSUB -- | Secure the SMP queue by adding a sender public key. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO () -secureSMPQueue c rpKey rId senderKey = okSMPCommand (Cmd SRecipient $ KEY senderKey) c rpKey rId +secureSMPQueue c rpKey rId senderKey = okSMPCommand (ClientCmd SRecipient $ KEY senderKey) c rpKey rId -- | Enable notifications for the queue for push notifications server. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> ExceptT SMPClientError IO NotifierId enableSMPQueueNotifications c rpKey rId notifierKey = - sendSMPCommand c (Just rpKey) rId (Cmd SRecipient $ NKEY notifierKey) >>= \case - Cmd _ (NID nId) -> pure nId + sendSMPCommand c (Just rpKey) rId (ClientCmd SRecipient $ NKEY notifierKey) >>= \case + NID nId -> pure nId _ -> throwE SMPUnexpectedResponse -- | Send SMP message. @@ -306,8 +311,8 @@ enableSMPQueueNotifications c rpKey rId notifierKey = -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgBody -> ExceptT SMPClientError IO () sendSMPMessage c spKey sId msg = - sendSMPCommand c spKey sId (Cmd SSender $ SEND msg) >>= \case - Cmd _ OK -> return () + sendSMPCommand c spKey sId (ClientCmd SSender $ SEND msg) >>= \case + OK -> pure () _ -> throwE SMPUnexpectedResponse -- | Acknowledge message delivery (server deletes the message). @@ -315,9 +320,9 @@ sendSMPMessage c spKey sId msg = -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO () ackSMPMessage c@SMPClient {smpServer, msgQ} rpKey rId = - sendSMPCommand c (Just rpKey) rId (Cmd SRecipient ACK) >>= \case - Cmd _ OK -> return () - Cmd _ cmd@MSG {} -> + sendSMPCommand c (Just rpKey) rId (ClientCmd SRecipient ACK) >>= \case + OK -> return () + cmd@MSG {} -> lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd) _ -> throwE SMPUnexpectedResponse @@ -326,25 +331,25 @@ ackSMPMessage c@SMPClient {smpServer, msgQ} rpKey rId = -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO () -suspendSMPQueue = okSMPCommand $ Cmd SRecipient OFF +suspendSMPQueue = okSMPCommand $ ClientCmd SRecipient OFF -- | Irreversibly delete SMP queue and all messages in it. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO () -deleteSMPQueue = okSMPCommand $ Cmd SRecipient DEL +deleteSMPQueue = okSMPCommand $ ClientCmd SRecipient DEL -okSMPCommand :: Cmd -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO () +okSMPCommand :: ClientCmd -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO () okSMPCommand cmd c pKey qId = sendSMPCommand c (Just pKey) qId cmd >>= \case - Cmd _ OK -> return () + OK -> return () _ -> throwE SMPUnexpectedResponse --- | Send any SMP command ('Cmd' type). -sendSMPCommand :: SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd -sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, tcpTimeout} pKey qId cmd = do +-- | Send any SMP command ('ClientCmd' type). +sendSMPCommand :: SMPClient -> Maybe C.APrivateSignKey -> QueueId -> ClientCmd -> ExceptT SMPClientError IO (Command 'Broker) +sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sndSessionId, tcpTimeout} pKey qId cmd = do corrId <- lift_ getNextCorrId - t <- signTransmission $ serializeTransmission (corrId, qId, cmd) + t <- signTransmission $ serializeTransmission (sndSessionId, corrId, qId, cmd) ExceptT $ sendRecv corrId t where lift_ :: STM a -> ExceptT SMPClientError IO a diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 6fad11f12..ab75e396c 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -7,7 +8,9 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} @@ -26,16 +29,17 @@ module Simplex.Messaging.Protocol ( -- * SMP protocol types Command (..), + CommandI (..), Party (..), Cmd (..), + ClientCmd (..), SParty (..), QueueIdsKeys (..), ErrorType (..), CommandError (..), Transmission, + BrokerTransmission, SignedTransmission, - SignedTransmissionOrError, - RawTransmission, SentRawTransmission, SignedRawTransmission, CorrId (..), @@ -57,10 +61,8 @@ module Simplex.Messaging.Protocol -- * Parse and serialize serializeTransmission, - serializeCommand, serializeErrorType, transmissionP, - commandP, errorTypeP, -- * TCP transport functions @@ -76,20 +78,24 @@ import Control.Monad import Control.Monad.Except import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Bifunctor (first) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Constraint (Dict (..)) import Data.Functor (($>)) import Data.Kind import Data.Maybe (isNothing) import Data.String import Data.Time.Clock import Data.Time.ISO8601 +import Data.Type.Equality import GHC.Generics (Generic) +import GHC.TypeLits (ErrorMessage (..), TypeError) import Generic.Random (genericArbitraryU) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers -import Simplex.Messaging.Transport (THandle, Transport, TransportError (..), tGetEncrypted, tPutEncrypted) +import Simplex.Messaging.Transport (SessionId (..), THandle (..), Transport, TransportError (..), tGetEncrypted, tPutEncrypted) import Simplex.Messaging.Util import Test.QuickCheck (Arbitrary (..)) @@ -104,29 +110,57 @@ data SParty :: Party -> Type where SSender :: SParty Sender SNotifier :: SParty Notifier -deriving instance Show (SParty a) +instance TestEquality SParty where + testEquality SBroker SBroker = Just Refl + testEquality SRecipient SRecipient = Just Refl + testEquality SSender SSender = Just Refl + testEquality SNotifier SNotifier = Just Refl + testEquality _ _ = Nothing + +deriving instance Show (SParty p) + +class PartyI (p :: Party) where sParty :: SParty p + +instance PartyI Broker where sParty = SBroker + +instance PartyI Recipient where sParty = SRecipient + +instance PartyI Sender where sParty = SSender + +instance PartyI Notifier where sParty = SNotifier -- | Type for command or response of any participant. -data Cmd = forall a. Cmd (SParty a) (Command a) +data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p) deriving instance Show Cmd +-- | Type for command or response of any participant. +data ClientCmd = forall p. (PartyI p, ClientParty p) => ClientCmd (SParty p) (Command p) + +class CommandI c where + serializeCommand :: c -> ByteString + commandP :: Parser c + -- | SMP transmission without signature. -type Transmission = (CorrId, QueueId, Cmd) +type Transmission c = (SessionId, CorrId, QueueId, c) --- | SMP transmission with signature. -type SignedTransmission = (Maybe C.ASignature, Transmission) +type BrokerTransmission = Transmission (Command Broker) -type TransmissionOrError = (CorrId, QueueId, Either ErrorType Cmd) - --- | signed parsed transmission, with parsing error. -type SignedTransmissionOrError = (Maybe C.ASignature, TransmissionOrError) +-- | signed parsed transmission, with original raw bytes and parsing error. +type SignedTransmission c = (Maybe C.ASignature, ByteString, Transmission (Either ErrorType c)) -- | unparsed SMP transmission with signature. -type RawTransmission = (ByteString, ByteString, ByteString, ByteString) +data RawTransmission = RawTransmission + { signature :: ByteString, + signed :: ByteString, + sessId :: ByteString, + corrId :: ByteString, + queueId :: ByteString, + command :: ByteString + } -- | unparsed sent SMP transmission with signature. -type SignedRawTransmission = (Maybe C.ASignature, ByteString, ByteString, ByteString) +type SignedRawTransmission = (Maybe C.ASignature, ByteString, ByteString, ByteString, ByteString) -- | unparsed sent SMP transmission with signature. type SentRawTransmission = (Maybe C.ASignature, ByteString) @@ -172,12 +206,24 @@ deriving instance Show (Command a) deriving instance Eq (Command a) +type family ClientParty p :: Constraint where + ClientParty Recipient = () + ClientParty Sender = () + ClientParty Notifier = () + ClientParty p = + (Int ~ Bool, TypeError (Text "Party " :<>: ShowType p :<>: Text " is not a Client")) + +clientParty :: SParty p -> Maybe (Dict (ClientParty p)) +clientParty = \case + SRecipient -> Just Dict + SSender -> Just Dict + SNotifier -> Just Dict + _ -> Nothing + -- | Base-64 encoded string. type Encoded = ByteString -- | Transmission correlation ID. --- --- A newtype to avoid accidentally changing order of transmission parts. newtype CorrId = CorrId {bs :: ByteString} deriving (Eq, Ord, Show) instance IsString CorrId where @@ -231,6 +277,8 @@ type MsgBody = ByteString data ErrorType = -- | incorrect block format, encoding or signature size BLOCK + | -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929) + SESSION | -- | SMP command is unknown or has invalid syntax CMD CommandError | -- | command authorization error - bad signature or non-existing SMP queue @@ -268,89 +316,107 @@ instance Arbitrary CommandError where arbitrary = genericArbitraryU -- | SMP transmission parser. transmissionP :: Parser RawTransmission transmissionP = do - sig <- segment - corrId <- segment - queueId <- segment - command <- A.takeByteString - return (sig, corrId, queueId, command) + signature <- segment + len <- A.decimal <* A.space + signed <- A.take len <* A.space + either fail pure $ parseAll (trn signature signed) signed where - segment = A.takeTill (== ' ') <* " " + segment = A.takeTill (== ' ') <* A.space + trn signature signed = do + sessId <- segment + corrId <- segment + queueId <- segment + command <- A.takeByteString + pure RawTransmission {signature, signed, sessId, corrId, queueId, command} --- | SMP command parser. -commandP :: Parser Cmd -commandP = - "NEW " *> newCmd - <|> "IDS " *> idsResp - <|> "SUB" $> Cmd SRecipient SUB - <|> "KEY " *> keyCmd - <|> "NKEY " *> nKeyCmd - <|> "NID " *> nIdsResp - <|> "ACK" $> Cmd SRecipient ACK - <|> "OFF" $> Cmd SRecipient OFF - <|> "DEL" $> Cmd SRecipient DEL - <|> "SEND " *> sendCmd - <|> "PING" $> Cmd SSender PING - <|> "NSUB" $> Cmd SNotifier NSUB - <|> "MSG " *> message - <|> "NMSG" $> Cmd SBroker NMSG - <|> "END" $> Cmd SBroker END - <|> "OK" $> Cmd SBroker OK - <|> "ERR " *> serverError - <|> "PONG" $> Cmd SBroker PONG - where - newCmd = Cmd SRecipient <$> (NEW <$> C.strKeyP <* A.space <*> C.strKeyP) - idsResp = Cmd SBroker . IDS <$> qik - qik = do - rcvId <- base64P <* A.space - rcvSrvVerifyKey <- C.strKeyP <* A.space - rcvPublicDHKey <- C.strKeyP <* A.space - sndId <- base64P <* A.space - sndSrvVerifyKey <- C.strKeyP - pure QIK {rcvId, rcvSrvVerifyKey, rcvPublicDHKey, sndId, sndSrvVerifyKey} - nIdsResp = Cmd SBroker . NID <$> base64P - keyCmd = Cmd SRecipient . KEY <$> C.strKeyP - nKeyCmd = Cmd SRecipient . NKEY <$> C.strKeyP - sendCmd = do - size <- A.decimal <* A.space - Cmd SSender . SEND <$> A.take size <* A.space - message = do - msgId <- base64P <* A.space - ts <- tsISO8601P <* A.space - size <- A.decimal <* A.space - Cmd SBroker . MSG msgId ts <$> A.take size <* A.space - serverError = Cmd SBroker . ERR <$> errorTypeP +instance CommandI Cmd where + serializeCommand (Cmd _ cmd) = serializeCommand cmd + commandP = + "NEW " *> newCmd + <|> "IDS " *> idsResp + <|> "SUB" $> Cmd SRecipient SUB + <|> "KEY " *> keyCmd + <|> "NKEY " *> nKeyCmd + <|> "NID " *> nIdsResp + <|> "ACK" $> Cmd SRecipient ACK + <|> "OFF" $> Cmd SRecipient OFF + <|> "DEL" $> Cmd SRecipient DEL + <|> "SEND " *> sendCmd + <|> "PING" $> Cmd SSender PING + <|> "NSUB" $> Cmd SNotifier NSUB + <|> "MSG " *> message + <|> "NMSG" $> Cmd SBroker NMSG + <|> "END" $> Cmd SBroker END + <|> "OK" $> Cmd SBroker OK + <|> "ERR " *> serverError + <|> "PONG" $> Cmd SBroker PONG + where + newCmd = Cmd SRecipient <$> (NEW <$> C.strKeyP <* A.space <*> C.strKeyP) + idsResp = Cmd SBroker . IDS <$> qik + qik = do + rcvId <- base64P <* A.space + rcvSrvVerifyKey <- C.strKeyP <* A.space + rcvPublicDHKey <- C.strKeyP <* A.space + sndId <- base64P <* A.space + sndSrvVerifyKey <- C.strKeyP + pure QIK {rcvId, rcvSrvVerifyKey, rcvPublicDHKey, sndId, sndSrvVerifyKey} + nIdsResp = Cmd SBroker . NID <$> base64P + keyCmd = Cmd SRecipient . KEY <$> C.strKeyP + nKeyCmd = Cmd SRecipient . NKEY <$> C.strKeyP + sendCmd = do + size <- A.decimal <* A.space + Cmd SSender . SEND <$> A.take size <* A.space + message = do + msgId <- base64P <* A.space + ts <- tsISO8601P <* A.space + size <- A.decimal <* A.space + Cmd SBroker . MSG msgId ts <$> A.take size <* A.space + serverError = Cmd SBroker . ERR <$> errorTypeP --- TODO ignore the end of block, no need to parse it +instance CommandI ClientCmd where + serializeCommand (ClientCmd _ cmd) = serializeCommand cmd + commandP = clientCmd <$?> commandP + where + clientCmd :: Cmd -> Either String ClientCmd + clientCmd (Cmd p cmd) = case clientParty p of + Just Dict -> Right (ClientCmd p cmd) + _ -> Left "not a client command" -- | Parse SMP command. parseCommand :: ByteString -> Either ErrorType Cmd -parseCommand = parse (commandP <* " " <* A.takeByteString) $ CMD SYNTAX +parseCommand = parse commandP $ CMD SYNTAX --- | Serialize SMP command. -serializeCommand :: Cmd -> ByteString -serializeCommand = \case - Cmd SRecipient (NEW rKey dhKey) -> B.unwords ["NEW", C.serializeKey rKey, C.serializeKey dhKey] - Cmd SRecipient (KEY sKey) -> "KEY " <> C.serializeKey sKey - Cmd SRecipient (NKEY nKey) -> "NKEY " <> C.serializeKey nKey - Cmd SRecipient SUB -> "SUB" - Cmd SRecipient ACK -> "ACK" - Cmd SRecipient OFF -> "OFF" - Cmd SRecipient DEL -> "DEL" - Cmd SSender (SEND msgBody) -> "SEND " <> serializeMsg msgBody - Cmd SSender PING -> "PING" - Cmd SNotifier NSUB -> "NSUB" - Cmd SBroker (MSG msgId ts msgBody) -> - B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeMsg msgBody] - Cmd SBroker (IDS QIK {rcvId, rcvSrvVerifyKey = rsKey, rcvPublicDHKey = dhKey, sndId, sndSrvVerifyKey = ssKey}) -> - B.unwords ["IDS", encode rcvId, C.serializeKey rsKey, C.serializeKey dhKey, encode sndId, C.serializeKey ssKey] - Cmd SBroker (NID nId) -> "NID " <> encode nId - Cmd SBroker (ERR err) -> "ERR " <> serializeErrorType err - Cmd SBroker NMSG -> "NMSG" - Cmd SBroker END -> "END" - Cmd SBroker OK -> "OK" - Cmd SBroker PONG -> "PONG" - where - serializeMsg msgBody = bshow (B.length msgBody) <> " " <> msgBody <> " " +instance PartyI p => CommandI (Command p) where + commandP = command' <$?> commandP + where + command' :: Cmd -> Either String (Command p) + command' (Cmd p cmd) = case testEquality p $ sParty @p of + Just Refl -> Right cmd + _ -> Left "bad command party" + serializeCommand = \case + NEW rKey dhKey -> B.unwords ["NEW", C.serializeKey rKey, C.serializeKey dhKey] + KEY sKey -> "KEY " <> C.serializeKey sKey + NKEY nKey -> "NKEY " <> C.serializeKey nKey + SUB -> "SUB" + ACK -> "ACK" + OFF -> "OFF" + DEL -> "DEL" + SEND msgBody -> "SEND " <> serializeBody msgBody + PING -> "PING" + NSUB -> "NSUB" + MSG msgId ts msgBody -> + B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeBody msgBody] + IDS QIK {rcvId, rcvSrvVerifyKey = rsKey, rcvPublicDHKey = dhKey, sndId, sndSrvVerifyKey = ssKey} -> + B.unwords ["IDS", encode rcvId, C.serializeKey rsKey, C.serializeKey dhKey, encode sndId, C.serializeKey ssKey] + NID nId -> "NID " <> encode nId + ERR err -> "ERR " <> serializeErrorType err + NMSG -> "NMSG" + END -> "END" + OK -> "OK" + PONG -> "PONG" + +serializeBody :: ByteString -> ByteString +serializeBody s = bshow (B.length s) <> " " <> s <> " " -- | SMP error parser. errorTypeP :: Parser ErrorType @@ -362,56 +428,56 @@ serializeErrorType = bshow -- | Send signed SMP transmission to TCP transport. tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) -tPut th (sig, t) = - tPutEncrypted th $ C.serializeSignature sig <> " " <> t <> " " +tPut th (sig, t) = tPutEncrypted th $ C.serializeSignature sig <> " " <> serializeBody t --- | Serialize SMP transmission. -serializeTransmission :: Transmission -> ByteString -serializeTransmission (CorrId corrId, queueId, command) = - B.intercalate " " [corrId, encode queueId, serializeCommand command] +serializeTransmission :: CommandI c => Transmission c -> ByteString +serializeTransmission (SessionId sessId, CorrId corrId, queueId, command) = + B.unwords [sessId, corrId, encode queueId, serializeCommand command] -- | Validate that it is an SMP client command, used with 'tGet' by 'Simplex.Messaging.Server'. -fromClient :: Cmd -> Either ErrorType Cmd -fromClient = \case - Cmd SBroker _ -> Left $ CMD PROHIBITED - cmd -> Right cmd +fromClient :: Cmd -> Either ErrorType ClientCmd +fromClient (Cmd p cmd) = case clientParty p of + Just Dict -> Right $ ClientCmd p cmd + Nothing -> Left $ CMD PROHIBITED -- | Validate that it is an SMP server command, used with 'tGet' by 'Simplex.Messaging.Client'. -fromServer :: Cmd -> Either ErrorType Cmd +fromServer :: Cmd -> Either ErrorType (Command Broker) fromServer = \case - cmd@(Cmd SBroker _) -> Right cmd + Cmd SBroker cmd -> Right cmd _ -> Left $ CMD PROHIBITED --- | Receive and parse transmission from the TCP transport. +-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding). tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmission) -tGetParse th = (>>= parse transmissionP TEBadBlock) <$> tGetEncrypted th +tGetParse th = (first (const TEBadBlock) . A.parseOnly transmissionP =<<) <$> tGetEncrypted th -- | Receive client and server transmissions. -- -- The first argument is used to limit allowed senders. -- 'fromClient' or 'fromServer' should be used here. -tGet :: forall c m. (Transport c, MonadIO m) => (Cmd -> Either ErrorType Cmd) -> THandle c -> m SignedTransmissionOrError -tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate +tGet :: forall c m cmd. (Transport c, MonadIO m) => (Cmd -> Either ErrorType cmd) -> THandle c -> m (SignedTransmission cmd) +tGet fromParty th@THandle {rcvSessionId, sndSessionId} = liftIO (tGetParse th) >>= decodeParseValidate where - decodeParseValidate :: Either TransportError RawTransmission -> m SignedTransmissionOrError + decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd) decodeParseValidate = \case - Right (sig, corrId, queueId, command) -> - let decodedTransmission = liftM2 (,corrId,,command) (C.decodeSignature =<< decode sig) (decode queueId) - in either (const $ tError corrId) tParseValidate decodedTransmission + Right RawTransmission {signature, signed, sessId, corrId, queueId, command} + | SessionId sessId == rcvSessionId -> + let decodedTransmission = liftM2 (,sessId,corrId,,command) (C.decodeSignature =<< decode signature) (decode queueId) + in either (const $ tError corrId) (tParseValidate signed) decodedTransmission + | otherwise -> pure (Nothing, "", (sndSessionId, CorrId corrId, "", Left SESSION)) Left _ -> tError "" - tError :: ByteString -> m SignedTransmissionOrError - tError corrId = return (Nothing, (CorrId corrId, "", Left BLOCK)) + tError :: ByteString -> m (SignedTransmission cmd) + tError corrId = pure (Nothing, "", (sndSessionId, CorrId corrId, "", Left BLOCK)) - tParseValidate :: SignedRawTransmission -> m SignedTransmissionOrError - tParseValidate t@(sig, corrId, queueId, command) = do - let cmd = parseCommand command >>= fromParty >>= tCredentials t - return (sig, (CorrId corrId, queueId, cmd)) + tParseValidate :: ByteString -> SignedRawTransmission -> m (SignedTransmission cmd) + tParseValidate signed t@(sig, sessId, corrId, queueId, command) = do + let cmd = parseCommand command >>= tCredentials t >>= fromParty + return (sig, signed, (SessionId sessId, CorrId corrId, queueId, cmd)) tCredentials :: SignedRawTransmission -> Cmd -> Either ErrorType Cmd - tCredentials (sig, _, queueId, _) cmd = case cmd of + tCredentials (sig, _, _, queueId, _) cmd = case cmd of -- IDS response must not have queue ID - Cmd SBroker IDS {} -> Right cmd + Cmd SBroker (IDS _) -> Right cmd -- ERR response does not always have queue ID Cmd SBroker (ERR _) -> Right cmd -- PONG response must not have queue ID diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index a17a1f3cd..b506b184a 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -7,7 +7,6 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -107,7 +106,7 @@ runSMPServerBlocking started cfg@ServerConfig {transports} = do join <$> mapM (endPreviousSubscriptions qId) (M.lookup qId serverSubs) endPreviousSubscriptions :: QueueId -> Client -> STM (Maybe s) endPreviousSubscriptions qId c = do - writeTBQueue (rcvQ c) (CorrId B.empty, qId, Cmd SBroker END) + writeTBQueue (sndQ c) (SessionId "", CorrId "", qId, END) stateTVar (clientSubs c) $ \ss -> (M.lookup qId ss, M.delete qId ss) runClient :: (Transport c, MonadUnliftIO m, MonadReader Env m) => TProxy c -> c -> m () @@ -119,9 +118,9 @@ runClient _ h = do Left _ -> pure () runClientTransport :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> m () -runClientTransport th = do +runClientTransport th@THandle {sndSessionId} = do q <- asks $ tbqSize . config - c <- atomically $ newClient q + c <- atomically $ newClient q sndSessionId s <- asks server raceAny_ [send th c, client c s, receive th c] `finally` cancelSubscribers c @@ -136,58 +135,50 @@ cancelSub = \case _ -> return () receive :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m () -receive h Client {rcvQ} = forever $ do - (signature, (corrId, queueId, cmdOrError)) <- tGet fromClient h - t <- case cmdOrError of - Left e -> return . mkResp corrId queueId $ ERR e - Right cmd -> verifyTransmission (signature, (corrId, queueId, cmd)) - atomically $ writeTBQueue rcvQ t +receive h@THandle {sndSessionId} Client {rcvQ, sndQ} = forever $ do + (sig, signed, (sessId, corrId, queueId, cmdOrError)) <- tGet fromClient h + case cmdOrError of + Left e -> write sndQ (sndSessionId, corrId, queueId, ERR e) + Right cmd -> do + verified <- verifyTransmission sig signed queueId cmd + if verified + then write rcvQ (sessId, corrId, queueId, cmd) + else write sndQ (sndSessionId, corrId, queueId, ERR AUTH) + where + write q t = atomically $ writeTBQueue q t send :: (Transport c, MonadUnliftIO m) => THandle c -> Client -> m () send h Client {sndQ} = forever $ do t <- atomically $ readTBQueue sndQ + -- TODO sign it here? liftIO $ tPut h (Nothing, serializeTransmission t) -mkResp :: CorrId -> QueueId -> Command 'Broker -> Transmission -mkResp corrId queueId command = (corrId, queueId, Cmd SBroker command) - -verifyTransmission :: forall m. (MonadUnliftIO m, MonadReader Env m) => SignedTransmission -> m Transmission -verifyTransmission (sig_, t@(corrId, queueId, cmd)) = do - (corrId,queueId,) <$> case cmd of - Cmd SBroker _ -> return $ smpErr INTERNAL -- it can only be client command, because `fromClient` was used - Cmd SRecipient (NEW k _) -> pure $ verifySignature k - Cmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey - Cmd SSender (SEND _) -> verifyCmd SSender $ verifyMaybe . senderKey - Cmd SSender PING -> return cmd - Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap snd . notifier +verifyTransmission :: + forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> ClientCmd -> m Bool +verifyTransmission sig_ signed queueId cmd = do + case cmd of + ClientCmd SRecipient (NEW k _) -> pure $ verifySignature k + ClientCmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey + ClientCmd SSender (SEND _) -> verifyCmd SSender $ verifyMaybe . senderKey + ClientCmd SSender PING -> pure True + ClientCmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap snd . notifier where - verifyCmd :: SParty p -> (QueueRec -> Cmd) -> m Cmd + verifyCmd :: SParty p -> (QueueRec -> Bool) -> m Bool verifyCmd party f = do st <- asks queueStore q <- atomically $ getQueue st party queueId - pure $ either (const $ dummyVerify_ sig_ authErr) f q - verifyMaybe :: Maybe C.APublicVerifyKey -> Cmd - verifyMaybe (Just k) = verifySignature k - verifyMaybe _ = maybe cmd (const authErr) sig_ - verifySignature :: C.APublicVerifyKey -> Cmd - verifySignature key = case sig_ of - Just s -> if verify key s then cmd else authErr - _ -> authErr + pure $ either (const $ maybe False dummyVerify sig_ `seq` False) f q + verifyMaybe :: Maybe C.APublicVerifyKey -> Bool + verifyMaybe = maybe (isNothing sig_) verifySignature + verifySignature :: C.APublicVerifyKey -> Bool + verifySignature key = maybe False (verify key) sig_ verify :: C.APublicVerifyKey -> C.ASignature -> Bool verify (C.APublicVerifyKey a k) sig@(C.ASignature a' s) = case (testEquality a a', C.signatureSize k == C.signatureSize s) of - (Just Refl, True) -> cryptoVerify k s - _ -> dummyVerify sig False - cryptoVerify :: C.SignatureAlgorithm a => C.PublicKey a -> C.Signature a -> Bool - cryptoVerify k s = C.verify' k s (serializeTransmission t) - dummyVerify_ :: Maybe C.ASignature -> a -> a - dummyVerify_ = \case - Just s -> dummyVerify s - _ -> id - dummyVerify :: C.ASignature -> a -> a - dummyVerify (C.ASignature _ s) = seq $ cryptoVerify (dummyPublicKey s) s - smpErr = Cmd SBroker . ERR - authErr = smpErr AUTH + (Just Refl, True) -> C.verify' k s signed + _ -> dummyVerify sig `seq` False + dummyVerify :: C.ASignature -> Bool + dummyVerify (C.ASignature _ s) = C.verify' (dummyPublicKey s) s signed -- These dummy keys are used with `dummyVerify` function to mitigate timing attacks -- by having the same time of the response whether a queue exists or nor, for all valid key/signature sizes @@ -221,22 +212,21 @@ dummyKey512 :: C.PublicKey 'C.RSA dummyKey512 = "MIICoDANBgkqhkiG9w0BAQEFAAOCAo0AMIICiAKCAgEArkCY9DuverJ4mmzDektv9aZMFyeRV46WZK9NsOBKEc+1ncqMs+LhLti9asKNgUBRbNzmbOe0NYYftrUpwnATaenggkTFxxbJ4JGJuGYbsEdFWkXSvrbWGtM8YUmn5RkAGme12xQ89bSM4VoJAGnrYPHwmcQd+KYCPZvTUsxaxgrJTX65ejHN9BsAn8XtGViOtHTDJO9yUMD2WrJvd7wnNa+0ugEteDLzMU++xS98VC+uA1vfauUqi3yXVchdfrLdVUuM+JE0gUEXCgzjuHkaoHiaGNiGhdPYoAJJdOKQOIHAKdk7Th6OPhirPhc9XYNB4O8JDthKhNtfokvFIFlC4QBRzJhpLIENaEBDt08WmgpOnecZB/CuxkqqOrNa8j5K5jNrtXAI67W46VEC2jeQy/gZwb64Zit2A4D00xXzGbQTPGj4ehcEMhLx5LSCygViEf0w0tN3c3TEyUcgPzvECd2ZVpQLr9Z4a07Ebr+YSuxcHhjg4Rg1VyJyOTTvaCBGm5X2B3+tI4NUttmikIHOYpBnsLmHY2BgfH2KcrIsDyAhInXmTFr/L2+erFarUnlfATd2L8Ti43TNHDedO6k6jI5Gyi62yPwjqPLEIIK8l+pIeNfHJ3pPmjhHBfzFcQLMMMXffHWNK8kWklrQXK+4j4HiPcTBvlO1FEtG9nEIZhUCgYA4a6WtI2k5YNli1C89GY5rGUY7RP71T6RWri/D3Lz9T7GvU+FemAyYmsvCQwqijUOur0uLvwSP8VdxpSUcrjJJSWur2hrPWzWlu0XbNaeizxpFeKbQP+zSrWJ1z8RwfAeUjShxt8q1TuqGqY10wQyp3nyiTGvS+KwZVj5h5qx8NQ==" client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m () -client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server {subscribedQ, ntfSubscribedQ, notifiers} = +client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sndSessionId} Server {subscribedQ, ntfSubscribedQ, notifiers} = forever $ atomically (readTBQueue rcvQ) >>= processCommand - >>= atomically . writeTBQueue sndQ' + >>= atomically . writeTBQueue sndQ where - processCommand :: Transmission -> m Transmission - processCommand (corrId, queueId, cmd) = do + processCommand :: Transmission ClientCmd -> m BrokerTransmission + processCommand (_, corrId, queueId, cmd) = do st <- asks queueStore case cmd of - Cmd SBroker _ -> pure (corrId, queueId, cmd) - Cmd SSender command -> case command of + ClientCmd SSender command -> case command of SEND msgBody -> sendMessage st msgBody - PING -> return (corrId, queueId, Cmd SBroker PONG) - Cmd SNotifier NSUB -> subscribeNotifications - Cmd SRecipient command -> case command of + PING -> pure (sndSessionId, corrId, queueId, PONG) + ClientCmd SNotifier NSUB -> subscribeNotifications + ClientCmd SRecipient command -> case command of NEW rKey dhKey -> createQueue st rKey dhKey SUB -> subscribeQueue queueId ACK -> acknowledgeMsg @@ -245,7 +235,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server OFF -> suspendQueue_ st DEL -> delQueueAndMsgs st where - createQueue :: QueueStore -> RcvPublicVerifyKey -> RcvPublicDhKey -> m Transmission + createQueue :: QueueStore -> RcvPublicVerifyKey -> RcvPublicDhKey -> m BrokerTransmission createQueue st recipientKey dhKey = checkKeySize recipientKey $ do C.SignAlg a <- asks $ trnSignAlg . config (rcvPublicDHKey, privDhKey) <- liftIO $ C.generateKeyPair' 0 @@ -291,12 +281,12 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server n <- asks $ queueIdBytes . config liftM2 (,) (randomId n) (randomId n) - secureQueue_ :: QueueStore -> SndPublicVerifyKey -> m Transmission + secureQueue_ :: QueueStore -> SndPublicVerifyKey -> m BrokerTransmission secureQueue_ st sKey = do withLog $ \s -> logSecureQueue s queueId sKey atomically . checkKeySize sKey $ either ERR (const OK) <$> secureQueue st queueId sKey - addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> m Transmission + addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> m BrokerTransmission addQueueNotifier_ st nKey = checkKeySize nKey $ addNotifierRetry 3 where addNotifierRetry :: Int -> m (Command 'Broker) @@ -310,19 +300,19 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server withLog $ \s -> logAddNotifier s queueId nId nKey pure $ NID nId - checkKeySize :: Monad m' => C.APublicVerifyKey -> m' (Command 'Broker) -> m' Transmission + checkKeySize :: Monad m' => C.APublicVerifyKey -> m' (Command 'Broker) -> m' BrokerTransmission checkKeySize key action = - mkResp corrId queueId + (sndSessionId,corrId,queueId,) <$> if C.validKeySize key then action else pure . ERR $ CMD KEY_SIZE - suspendQueue_ :: QueueStore -> m Transmission + suspendQueue_ :: QueueStore -> m BrokerTransmission suspendQueue_ st = do withLog (`logDeleteQueue` queueId) okResp <$> atomically (suspendQueue st queueId) - subscribeQueue :: RecipientId -> m Transmission + subscribeQueue :: RecipientId -> m BrokerTransmission subscribeQueue rId = atomically (getSubscription rId) >>= deliverMessage tryPeekMsg rId @@ -337,7 +327,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server writeTVar subscriptions $ M.insert rId s subs return s - subscribeNotifications :: m Transmission + subscribeNotifications :: m BrokerTransmission subscribeNotifications = atomically $ do subs <- readTVar ntfSubscriptions when (isNothing $ M.lookup queueId subs) $ do @@ -345,7 +335,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server writeTVar ntfSubscriptions $ M.insert queueId () subs pure ok - acknowledgeMsg :: m Transmission + acknowledgeMsg :: m BrokerTransmission acknowledgeMsg = atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s)) >>= \case @@ -355,12 +345,12 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a) withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId - sendMessage :: QueueStore -> MsgBody -> m Transmission + sendMessage :: QueueStore -> MsgBody -> m BrokerTransmission sendMessage st msgBody = do qr <- atomically $ getQueue st SSender queueId either (return . err) storeMessage qr where - storeMessage :: QueueRec -> m Transmission + storeMessage :: QueueRec -> m BrokerTransmission storeMessage qr = case status qr of QueueOff -> return $ err AUTH QueueActive -> do @@ -387,11 +377,11 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server mapM_ (writeNtf nId) . M.lookup nId =<< readTVar notifiers writeNtf :: NotifierId -> Client -> STM () - writeNtf nId Client {sndQ} = + writeNtf nId Client {sndQ = q, sndSessionId = sessId} = unlessM (isFullTBQueue sndQ) $ - writeTBQueue sndQ $ mkResp (CorrId B.empty) nId NMSG + writeTBQueue q (sessId, CorrId "", nId, NMSG) - deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> Sub -> m Transmission + deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> Sub -> m BrokerTransmission deliverMessage tryPeek rId = \case Sub {subThread = NoSub} -> do ms <- asks msgStore @@ -399,7 +389,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server q <- atomically $ getMsgQueue ms rId quota atomically (tryPeek q) >>= \case Nothing -> forkSub q $> ok - Just msg -> atomically setDelivered $> mkResp corrId rId (msgCmd msg) + Just msg -> atomically setDelivered $> (sndSessionId, corrId, rId, msgCmd msg) _ -> return ok where forkSub :: MsgQueue -> m () @@ -413,7 +403,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server subscriber :: MsgQueue -> m () subscriber q = atomically $ do msg <- peekMsg q - writeTBQueue sndQ' $ mkResp (CorrId B.empty) rId (msgCmd msg) + writeTBQueue sndQ (sndSessionId, CorrId "", rId, msgCmd msg) setSub (\s -> s {subThread = NoSub}) void setDelivered @@ -426,7 +416,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server msgCmd :: Message -> Command 'Broker msgCmd Message {msgId, ts, msgBody} = MSG msgId ts msgBody - delQueueAndMsgs :: QueueStore -> m Transmission + delQueueAndMsgs :: QueueStore -> m BrokerTransmission delQueueAndMsgs st = do withLog (`logDeleteQueue` queueId) ms <- asks msgStore @@ -435,13 +425,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ = sndQ'} Server Left e -> return $ err e Right _ -> delMsgQueue ms queueId $> ok - ok :: Transmission - ok = mkResp corrId queueId OK + ok :: BrokerTransmission + ok = (sndSessionId, corrId, queueId, OK) - err :: ErrorType -> Transmission - err = mkResp corrId queueId . ERR + err :: ErrorType -> BrokerTransmission + err e = (sndSessionId, corrId, queueId, ERR e) - okResp :: Either ErrorType () -> Transmission + okResp :: Either ErrorType () -> BrokerTransmission okResp = either err $ const ok withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m () diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 455bdeefc..cb8b80d75 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -19,7 +19,7 @@ import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.QueueStore (QueueRec (..)) import Simplex.Messaging.Server.QueueStore.STM import Simplex.Messaging.Server.StoreLog -import Simplex.Messaging.Transport (ATransport, loadServerCredential) +import Simplex.Messaging.Transport (ATransport, SessionId, loadServerCredential) import System.IO (IOMode (..)) import UnliftIO.STM @@ -58,8 +58,9 @@ data Server = Server data Client = Client { subscriptions :: TVar (Map RecipientId Sub), ntfSubscriptions :: TVar (Map NotifierId ()), - rcvQ :: TBQueue Transmission, - sndQ :: TBQueue Transmission + rcvQ :: TBQueue (Transmission ClientCmd), + sndQ :: TBQueue BrokerTransmission, + sndSessionId :: SessionId } data SubscriptionThread = NoSub | SubPending | SubThread ThreadId @@ -77,13 +78,13 @@ newServer qSize = do notifiers <- newTVar M.empty return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers} -newClient :: Natural -> STM Client -newClient qSize = do +newClient :: Natural -> SessionId -> STM Client +newClient qSize sndSessionId = do subscriptions <- newTVar M.empty ntfSubscriptions <- newTVar M.empty rcvQ <- newTBQueue qSize sndQ <- newTBQueue qSize - return Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} + return Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sndSessionId} newSubscription :: STM Sub newSubscription = do diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index f67d726dc..c4a81c840 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -40,6 +40,7 @@ module Simplex.Messaging.Transport -- * SMP encrypted transport THandle (..), + SessionId (..), TransportError (..), serverHandshake, clientHandshake, @@ -325,9 +326,16 @@ smpVersionP = let ver = A.decimal <* A.char '.' in SMPVersion <$> ver <*> ver <*> ver <*> A.decimal +-- | Session identifier (base64 encoded here, to avoid encoding every time it is sent) +-- It should be set from TLS finished and passed in the initial handshake +newtype SessionId = SessionId {unSessionId :: ByteString} + deriving (Eq, Show) + -- | The handle for SMP encrypted transport connection over Transport . data THandle c = THandle { connection :: c, + sndSessionId :: SessionId, + rcvSessionId :: SessionId, sndKey :: SessionKey, rcvKey :: SessionKey, blockSize :: Int @@ -349,6 +357,8 @@ data ClientHandshake = ClientHandshake data TransportError = -- | error parsing transport block TEBadBlock + | -- | incorrect session ID + TEBadSession | -- | block encryption error TEEncrypt | -- | block decryption error @@ -387,6 +397,7 @@ instance Arbitrary HandshakeError where arbitrary = genericArbitraryU transportErrorP :: Parser TransportError transportErrorP = "BLOCK" $> TEBadBlock + <|> "SESSION" $> TEBadSession <|> "AES_ENCRYPT" $> TEEncrypt <|> "AES_DECRYPT" $> TEDecrypt <|> TEHandshake <$> parseRead1 @@ -397,6 +408,7 @@ serializeTransportError = \case TEEncrypt -> "AES_ENCRYPT" TEDecrypt -> "AES_DECRYPT" TEBadBlock -> "BLOCK" + TEBadSession -> "SESSION" TEHandshake e -> bshow e -- | Encrypt and send block to SMP encrypted transport. @@ -582,6 +594,8 @@ transportHandle c sk rk blockSize = do pure THandle { connection = c, + sndSessionId = SessionId "", + rcvSessionId = SessionId "", sndKey = sk {counter = sndCounter}, rcvKey = rk {counter = rcvCounter}, blockSize diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 0dcb682dc..ec2f89f3b 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -171,11 +171,11 @@ smpTest4 _ test' = smpTestN 4 _test _test _ = error "expected 4 handles" tPutRaw :: Transport c => THandle c -> SignedRawTransmission -> IO () -tPutRaw h (sig, corrId, queueId, command) = do - let t = B.intercalate " " [corrId, queueId, command] +tPutRaw h (sig, sessId, corrId, queueId, command) = do + let t = B.unwords [sessId, corrId, queueId, command] void $ tPut h (sig, t) tGetRaw :: Transport c => THandle c -> IO SignedRawTransmission tGetRaw h = do - (Nothing, (CorrId corrId, qId, Right cmd)) <- tGet fromServer h - pure (Nothing, corrId, encode qId, serializeCommand cmd) + (Nothing, _, (SessionId sessId, CorrId corrId, qId, Right cmd)) <- tGet fromServer h + pure (Nothing, sessId, corrId, encode qId, serializeCommand cmd) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index caf33b604..8c0a240ae 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -40,18 +41,19 @@ serverTests t = do describe "Timing of AUTH error" $ testTiming t describe "Message notifications" $ testMessageNotifications t -pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmissionOrError -pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker command))) +pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmission (Command 'Broker) +pattern Resp corrId queueId command <- ("", _, (_, corrId, queueId, Right command)) pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> Command 'Broker pattern Ids rId sId srvDh <- IDS (QIK rId _ srvDh sId _) -sendRecv :: Transport c => THandle c -> (Maybe C.ASignature, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError -sendRecv h (sgn, corrId, qId, cmd) = tPutRaw h (sgn, corrId, encode qId, cmd) >> tGet fromServer h +sendRecv :: Transport c => THandle c -> (Maybe C.ASignature, ByteString, ByteString, ByteString) -> IO (SignedTransmission (Command 'Broker)) +sendRecv h@THandle {sndSessionId = SessionId sessId} (sgn, corrId, qId, cmd) = + tPutRaw h (sgn, sessId, corrId, encode qId, cmd) >> tGet fromServer h -signSendRecv :: Transport c => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError -signSendRecv h pk (corrId, qId, cmd) = do - let t = B.intercalate " " [corrId, encode qId, cmd] +signSendRecv :: Transport c => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, ByteString) -> IO (SignedTransmission (Command 'Broker)) +signSendRecv h@THandle {sndSessionId = SessionId sessId} pk (corrId, qId, cmd) = do + let t = B.intercalate " " [sessId, corrId, encode qId, cmd] Right sig <- runExceptT $ C.sign pk t _ <- tPut h (Just sig, t) tGet fromServer h @@ -482,40 +484,40 @@ sampleSig = "gM8qn2Vx3GkhIp2hgrji9uhfXKpgtKDmc0maxdP8GvbORUxMCTlLG8Q/gNcl3pQVOzm syntaxTests :: ATransport -> Spec syntaxTests (ATransport t) = do - it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX") + it "unknown command" $ ("", "", "abcd", "1234", "HELLO") >#> ("", "", "abcd", "1234", "ERR CMD SYNTAX") describe "NEW" $ do - it "no parameters" $ (sampleSig, "bcda", "", "NEW") >#> ("", "bcda", "", "ERR CMD SYNTAX") - it "many parameters" $ (sampleSig, "cdab", "", B.unwords ["NEW 1", samplePubKey, sampleDhPubKey]) >#> ("", "cdab", "", "ERR CMD SYNTAX") - it "no signature" $ ("", "dabc", "", B.unwords ["NEW", samplePubKey, sampleDhPubKey]) >#> ("", "dabc", "", "ERR CMD NO_AUTH") - it "queue ID" $ (sampleSig, "abcd", "12345678", B.unwords ["NEW", samplePubKey, sampleDhPubKey]) >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH") + it "no parameters" $ (sampleSig, "", "bcda", "", "NEW") >#> ("", "", "bcda", "", "ERR CMD SYNTAX") + it "many parameters" $ (sampleSig, "", "cdab", "", B.unwords ["NEW 1", samplePubKey, sampleDhPubKey]) >#> ("", "", "cdab", "", "ERR CMD SYNTAX") + it "no signature" $ ("", "", "dabc", "", B.unwords ["NEW", samplePubKey, sampleDhPubKey]) >#> ("", "", "dabc", "", "ERR CMD NO_AUTH") + it "queue ID" $ (sampleSig, "", "abcd", "12345678", B.unwords ["NEW", samplePubKey, sampleDhPubKey]) >#> ("", "", "abcd", "12345678", "ERR CMD HAS_AUTH") describe "KEY" $ do - it "valid syntax" $ (sampleSig, "bcda", "12345678", "KEY " <> samplePubKey) >#> ("", "bcda", "12345678", "ERR AUTH") - it "no parameters" $ (sampleSig, "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") - it "many parameters" $ (sampleSig, "dabc", "12345678", "KEY 1 " <> samplePubKey) >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") - it "no signature" $ ("", "abcd", "12345678", "KEY " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH") - it "no queue ID" $ (sampleSig, "bcda", "", "KEY " <> samplePubKey) >#> ("", "bcda", "", "ERR CMD NO_AUTH") + it "valid syntax" $ (sampleSig, "", "bcda", "12345678", "KEY " <> samplePubKey) >#> ("", "", "bcda", "12345678", "ERR AUTH") + it "no parameters" $ (sampleSig, "", "cdab", "12345678", "KEY") >#> ("", "", "cdab", "12345678", "ERR CMD SYNTAX") + it "many parameters" $ (sampleSig, "", "dabc", "12345678", "KEY 1 " <> samplePubKey) >#> ("", "", "dabc", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "", "abcd", "12345678", "KEY " <> samplePubKey) >#> ("", "", "abcd", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ (sampleSig, "", "bcda", "", "KEY " <> samplePubKey) >#> ("", "", "bcda", "", "ERR CMD NO_AUTH") noParamsSyntaxTest "SUB" noParamsSyntaxTest "ACK" noParamsSyntaxTest "OFF" noParamsSyntaxTest "DEL" describe "SEND" $ do - it "valid syntax 1" $ (sampleSig, "cdab", "12345678", "SEND 5 hello ") >#> ("", "cdab", "12345678", "ERR AUTH") - it "valid syntax 2" $ (sampleSig, "dabc", "12345678", "SEND 11 hello there ") >#> ("", "dabc", "12345678", "ERR AUTH") - it "no parameters" $ (sampleSig, "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") - it "no queue ID" $ (sampleSig, "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR CMD NO_QUEUE") - it "bad message body 1" $ (sampleSig, "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") - it "bad message body 2" $ (sampleSig, "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") - it "bigger body" $ (sampleSig, "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") + it "valid syntax 1" $ (sampleSig, "", "cdab", "12345678", "SEND 5 hello ") >#> ("", "", "cdab", "12345678", "ERR AUTH") + it "valid syntax 2" $ (sampleSig, "", "dabc", "12345678", "SEND 11 hello there ") >#> ("", "", "dabc", "12345678", "ERR AUTH") + it "no parameters" $ (sampleSig, "", "abcd", "12345678", "SEND") >#> ("", "", "abcd", "12345678", "ERR CMD SYNTAX") + it "no queue ID" $ (sampleSig, "", "bcda", "", "SEND 5 hello ") >#> ("", "", "bcda", "", "ERR CMD NO_QUEUE") + it "bad message body 1" $ (sampleSig, "", "cdab", "12345678", "SEND 11 hello ") >#> ("", "", "cdab", "12345678", "ERR CMD SYNTAX") + it "bad message body 2" $ (sampleSig, "", "dabc", "12345678", "SEND hello ") >#> ("", "", "dabc", "12345678", "ERR CMD SYNTAX") + it "bigger body" $ (sampleSig, "", "abcd", "12345678", "SEND 4 hello ") >#> ("", "", "abcd", "12345678", "ERR CMD SYNTAX") describe "PING" $ do - it "valid syntax" $ ("", "abcd", "", "PING") >#> ("", "abcd", "", "PONG") + it "valid syntax" $ ("", "", "abcd", "", "PING") >#> ("", "", "abcd", "", "PONG") describe "broker response not allowed" $ do - it "OK" $ (sampleSig, "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR CMD PROHIBITED") + it "OK" $ (sampleSig, "", "bcda", "12345678", "OK") >#> ("", "", "bcda", "12345678", "ERR CMD PROHIBITED") where noParamsSyntaxTest :: ByteString -> Spec noParamsSyntaxTest cmd = describe (B.unpack cmd) $ do - it "valid syntax" $ (sampleSig, "abcd", "12345678", cmd) >#> ("", "abcd", "12345678", "ERR AUTH") - it "wrong terminator" $ (sampleSig, "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR CMD SYNTAX") - it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR CMD NO_AUTH") - it "no queue ID" $ (sampleSig, "dabc", "", cmd) >#> ("", "dabc", "", "ERR CMD NO_AUTH") + it "valid syntax" $ (sampleSig, "", "abcd", "12345678", cmd) >#> ("", "", "abcd", "12345678", "ERR AUTH") + it "wrong terminator" $ (sampleSig, "", "bcda", "12345678", cmd <> "=") >#> ("", "", "bcda", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "", "cdab", "12345678", cmd) >#> ("", "", "cdab", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ (sampleSig, "", "dabc", "", cmd) >#> ("", "", "dabc", "", "ERR CMD NO_AUTH") (>#>) :: SignedRawTransmission -> SignedRawTransmission -> Expectation command >#> response = smpServerTest t command `shouldReturn` response diff --git a/tests/Test.hs b/tests/Test.hs index 0ae4b523d..9f946592f 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -4,7 +4,7 @@ import AgentTests (agentTests) import ProtocolErrorTests import ServerTests import Simplex.Messaging.Transport (TLS, Transport (..)) -import Simplex.Messaging.Transport.WebSockets (WS) +-- import Simplex.Messaging.Transport.WebSockets (WS) import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive) import Test.Hspec