diff --git a/simplexmq.cabal b/simplexmq.cabal index 535e8fd2e..9242f6168 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -160,6 +160,7 @@ library Simplex.Messaging.Transport.WebSockets Simplex.Messaging.Util Simplex.Messaging.Version + Simplex.Messaging.Version.Internal Simplex.RemoteControl.Client Simplex.RemoteControl.Discovery Simplex.RemoteControl.Discovery.Multicast diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index f45389462..5666b63ff 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -52,8 +52,8 @@ import Simplex.FileTransfer.Client.Main import Simplex.FileTransfer.Crypto import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..)) -import qualified Simplex.FileTransfer.Protocol as XFTP import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import qualified Simplex.FileTransfer.Transport as XFTP import Simplex.FileTransfer.Types import Simplex.FileTransfer.Util (removePath, uniqueCombine) import Simplex.Messaging.Agent.Client diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 84c99eb48..d9c4c058a 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -57,7 +57,7 @@ import UnliftIO.Directory data XFTPClient = XFTPClient { http2Client :: HTTP2Client, transportSession :: TransportSession FileResponse, - thParams :: THandleParams, + thParams :: THandleParams XFTPVersion, config :: XFTPClientConfig } diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index e9988b56a..2ba75f027 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -1,9 +1,11 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} @@ -14,9 +16,7 @@ module Simplex.FileTransfer.Protocol where -import Control.Applicative ((<|>)) import qualified Data.Aeson.TH as J -import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -25,11 +25,11 @@ import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe (isNothing) import Data.Type.Equality import Data.Word (Word32) +import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, pattern VersionXFTP, xftpClientHandshake) import Simplex.Messaging.Client (authTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (ntfClientHandshake) import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol ( BasicAuth, @@ -56,11 +56,10 @@ import Simplex.Messaging.Protocol tParse, ) import Simplex.Messaging.Transport (THandleParams (..), TransportError (..)) -import Simplex.Messaging.Util (bshow, (<$?>)) -import Simplex.Messaging.Version +import Simplex.Messaging.Util ((<$?>)) -currentXFTPVersion :: Version -currentXFTPVersion = 1 +currentXFTPVersion :: VersionXFTP +currentXFTPVersion = VersionXFTP 1 xftpBlockSize :: Int xftpBlockSize = 16384 @@ -142,10 +141,10 @@ instance ProtocolMsgTag FileCmdTag where instance FilePartyI p => ProtocolMsgTag (FileCommandTag p) where decodeTag s = decodeTag s >>= (\(FCT _ t) -> checkParty' t) -instance Protocol XFTPErrorType FileResponse where +instance Protocol XFTPVersion XFTPErrorType FileResponse where type ProtoCommand FileResponse = FileCmd type ProtoType FileResponse = 'PXFTP - protocolClientHandshake = ntfClientHandshake + protocolClientHandshake = xftpClientHandshake protocolPing = FileCmd SFRecipient PING protocolError = \case FRErr e -> Just e @@ -175,7 +174,7 @@ data FileInfo = FileInfo type XFTPFileId = ByteString -instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where +instance FilePartyI p => ProtocolEncoding XFTPVersion XFTPErrorType (FileCommand p) where type Tag (FileCommand p) = FileCommandTag p encodeProtocol _v = \case FNEW file rKeys auth_ -> e (FNEW_, ' ', file, rKeys, auth_) @@ -191,7 +190,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where protocolP v tag = (\(FileCmd _ c) -> checkParty c) <$?> protocolP v (FCT (sFileParty @p) tag) - fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse + fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse {-# INLINE fromProtocolError #-} checkCredentials (auth, _, fileId, _) cmd = case cmd of @@ -208,7 +207,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where | isNothing auth || B.null fileId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance ProtocolEncoding XFTPErrorType FileCmd where +instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where type Tag FileCmd = FileCmdTag encodeProtocol _v (FileCmd _ c) = encodeProtocol _v c @@ -225,7 +224,7 @@ instance ProtocolEncoding XFTPErrorType FileCmd where FACK_ -> pure FACK PING_ -> pure PING - fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse + fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse {-# INLINE fromProtocolError #-} checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c @@ -276,7 +275,7 @@ data FileResponse | FRPong deriving (Show) -instance ProtocolEncoding XFTPErrorType FileResponse where +instance ProtocolEncoding XFTPVersion XFTPErrorType FileResponse where type Tag FileResponse = FileResponseTag encodeProtocol _v = \case FRSndIds fId rIds -> e (FRSndIds_, ' ', fId, rIds) @@ -319,82 +318,6 @@ instance ProtocolEncoding XFTPErrorType FileResponse where | B.null entId = Right cmd | otherwise = Left $ CMD HAS_AUTH -data XFTPErrorType - = -- | incorrect block format, encoding or signature size - BLOCK - | -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929) - SESSION - | -- | SMP command is unknown or has invalid syntax - CMD {cmdErr :: CommandError} - | -- | command authorization error - bad signature or non-existing SMP queue - AUTH - | -- | incorrent file size - SIZE - | -- | storage quota exceeded - QUOTA - | -- | incorrent file digest - DIGEST - | -- | file encryption/decryption failed - CRYPTO - | -- | no expected file body in request/response or no file on the server - NO_FILE - | -- | unexpected file body - HAS_FILE - | -- | file IO error - FILE_IO - | -- | bad redirect data - REDIRECT {redirectError :: String} - | -- | internal server error - INTERNAL - | -- | used internally, never returned by the server (to be removed) - DUPLICATE_ -- not part of SMP protocol, used internally - deriving (Eq, Read, Show) - -instance StrEncoding XFTPErrorType where - strEncode = \case - CMD e -> "CMD " <> bshow e - REDIRECT e -> "REDIRECT " <> bshow e - e -> bshow e - strP = - "CMD " *> (CMD <$> parseRead1) - <|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString) - <|> parseRead1 - -instance Encoding XFTPErrorType where - smpEncode = \case - BLOCK -> "BLOCK" - SESSION -> "SESSION" - CMD err -> "CMD " <> smpEncode err - AUTH -> "AUTH" - SIZE -> "SIZE" - QUOTA -> "QUOTA" - DIGEST -> "DIGEST" - CRYPTO -> "CRYPTO" - NO_FILE -> "NO_FILE" - HAS_FILE -> "HAS_FILE" - FILE_IO -> "FILE_IO" - REDIRECT err -> "REDIRECT " <> smpEncode err - INTERNAL -> "INTERNAL" - DUPLICATE_ -> "DUPLICATE_" - - smpP = - A.takeTill (== ' ') >>= \case - "BLOCK" -> pure BLOCK - "SESSION" -> pure SESSION - "CMD" -> CMD <$> _smpP - "AUTH" -> pure AUTH - "SIZE" -> pure SIZE - "QUOTA" -> pure QUOTA - "DIGEST" -> pure DIGEST - "CRYPTO" -> pure CRYPTO - "NO_FILE" -> pure NO_FILE - "HAS_FILE" -> pure HAS_FILE - "FILE_IO" -> pure FILE_IO - "REDIRECT" -> REDIRECT <$> _smpP - "INTERNAL" -> pure INTERNAL - "DUPLICATE_" -> pure DUPLICATE_ - _ -> fail "bad error type" - checkParty :: forall t p p'. (FilePartyI p, FilePartyI p') => t p' -> Either String (t p) checkParty c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Right c @@ -405,12 +328,12 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Just c _ -> Nothing -xftpEncodeAuthTransmission :: ProtocolEncoding e c => THandleParams -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString +xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString xftpEncodeAuthTransmission thParams pKey (corrId, fId, msg) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg) xftpEncodeBatch1 . (,tToSend) =<< authTransmission Nothing (Just pKey) corrId tForAuth -xftpEncodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> Either TransportError ByteString +xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> Transmission c -> Either TransportError ByteString xftpEncodeTransmission thParams (corrId, fId, msg) = do let t = encodeTransmission thParams (corrId, fId, msg) xftpEncodeBatch1 (Nothing, t) @@ -419,7 +342,7 @@ xftpEncodeTransmission thParams (corrId, fId, msg) = do xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 t) xftpBlockSize -xftpDecodeTransmission :: ProtocolEncoding e c => THandleParams -> ByteString -> Either XFTPErrorType (SignedTransmission e c) +xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> ByteString -> Either XFTPErrorType (SignedTransmission e c) xftpDecodeTransmission thParams t = do t' <- first (const BLOCK) $ C.unPad t case tParse thParams t' of @@ -427,5 +350,3 @@ xftpDecodeTransmission thParams t = do _ -> Left BLOCK $(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty) - -$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType) diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 158429d79..ae202c2b0 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -69,7 +69,7 @@ type M a = ReaderT XFTPEnv IO a data XFTPTransportRequest = XFTPTransportRequest - { thParams :: THandleParams, + { thParams :: THandleParams XFTPVersion, reqBody :: HTTP2Body, request :: H.Request, sendResponse :: H.Response -> IO () diff --git a/src/Simplex/FileTransfer/Server/Store.hs b/src/Simplex/FileTransfer/Server/Store.hs index 8c198690e..a3681944e 100644 --- a/src/Simplex/FileTransfer/Server/Store.hs +++ b/src/Simplex/FileTransfer/Server/Store.hs @@ -28,7 +28,8 @@ import Data.Int (Int64) import Data.Set (Set) import qualified Data.Set as S import Data.Time.Clock.System (SystemTime (..)) -import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPErrorType (..), XFTPFileId) +import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPFileId) +import Simplex.FileTransfer.Transport (XFTPErrorType (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (RcvPublicAuthKey, RecipientId, SenderId) diff --git a/src/Simplex/FileTransfer/Transport.hs b/src/Simplex/FileTransfer/Transport.hs index 90b1a8a44..464a75ac8 100644 --- a/src/Simplex/FileTransfer/Transport.hs +++ b/src/Simplex/FileTransfer/Transport.hs @@ -1,10 +1,19 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} module Simplex.FileTransfer.Transport ( supportedFileServerVRange, + xftpClientHandshake, -- stub + XFTPVersion, + VersionXFTP, + pattern VersionXFTP, + XFTPErrorType (..), XFTPRcvChunkSpec (..), ReceiveFileError (..), receiveFile, @@ -14,22 +23,31 @@ module Simplex.FileTransfer.Transport ) where +import Control.Applicative ((<|>)) import qualified Control.Exception as E import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class +import qualified Data.Aeson.TH as J +import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) import qualified Data.ByteArray as BA import Data.ByteString.Builder (Builder, byteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB -import Data.Word (Word32) -import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) +import Data.Word (Word16, Word32) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC +import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Parsers +import Simplex.Messaging.Protocol (CommandError) +import Simplex.Messaging.Transport (HandshakeError (..), THandle, TransportError (..)) import Simplex.Messaging.Transport.HTTP2.File +import Simplex.Messaging.Util (bshow) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import System.IO (Handle, IOMode (..), withFile) data XFTPRcvChunkSpec = XFTPRcvChunkSpec @@ -39,8 +57,26 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec } deriving (Show) -supportedFileServerVRange :: VersionRange -supportedFileServerVRange = mkVersionRange 1 1 +data XFTPVersion + +instance VersionScope XFTPVersion + +type VersionXFTP = Version XFTPVersion + +type VersionRangeXFTP = VersionRange XFTPVersion + +pattern VersionXFTP :: Word16 -> VersionXFTP +pattern VersionXFTP v = Version v + +initialXFTPVersion :: VersionXFTP +initialXFTPVersion = VersionXFTP 1 + +supportedFileServerVRange :: VersionRangeXFTP +supportedFileServerVRange = mkVersionRange initialXFTPVersion initialXFTPVersion + +-- XFTP protocol does not support handshake +xftpClientHandshake :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c) +xftpClientHandshake _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO () sendEncFile h send = go @@ -97,3 +133,81 @@ receiveFile_ receive XFTPRcvChunkSpec {filePath, chunkSize, chunkDigest} = do ExceptT $ withFile filePath WriteMode (`receive` chunkSize) digest' <- liftIO $ LC.sha256Hash <$> LB.readFile filePath when (digest' /= chunkDigest) $ throwError DIGEST + +data XFTPErrorType + = -- | incorrect block format, encoding or signature size + BLOCK + | -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929) + SESSION + | -- | SMP command is unknown or has invalid syntax + CMD {cmdErr :: CommandError} + | -- | command authorization error - bad signature or non-existing SMP queue + AUTH + | -- | incorrent file size + SIZE + | -- | storage quota exceeded + QUOTA + | -- | incorrent file digest + DIGEST + | -- | file encryption/decryption failed + CRYPTO + | -- | no expected file body in request/response or no file on the server + NO_FILE + | -- | unexpected file body + HAS_FILE + | -- | file IO error + FILE_IO + | -- | bad redirect data + REDIRECT {redirectError :: String} + | -- | internal server error + INTERNAL + | -- | used internally, never returned by the server (to be removed) + DUPLICATE_ -- not part of SMP protocol, used internally + deriving (Eq, Read, Show) + +instance StrEncoding XFTPErrorType where + strEncode = \case + CMD e -> "CMD " <> bshow e + REDIRECT e -> "REDIRECT " <> bshow e + e -> bshow e + strP = + "CMD " *> (CMD <$> parseRead1) + <|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString) + <|> parseRead1 + +instance Encoding XFTPErrorType where + smpEncode = \case + BLOCK -> "BLOCK" + SESSION -> "SESSION" + CMD err -> "CMD " <> smpEncode err + AUTH -> "AUTH" + SIZE -> "SIZE" + QUOTA -> "QUOTA" + DIGEST -> "DIGEST" + CRYPTO -> "CRYPTO" + NO_FILE -> "NO_FILE" + HAS_FILE -> "HAS_FILE" + FILE_IO -> "FILE_IO" + REDIRECT err -> "REDIRECT " <> smpEncode err + INTERNAL -> "INTERNAL" + DUPLICATE_ -> "DUPLICATE_" + + smpP = + A.takeTill (== ' ') >>= \case + "BLOCK" -> pure BLOCK + "SESSION" -> pure SESSION + "CMD" -> CMD <$> _smpP + "AUTH" -> pure AUTH + "SIZE" -> pure SIZE + "QUOTA" -> pure QUOTA + "DIGEST" -> pure DIGEST + "CRYPTO" -> pure CRYPTO + "NO_FILE" -> pure NO_FILE + "HAS_FILE" -> pure HAS_FILE + "FILE_IO" -> pure FILE_IO + "REDIRECT" -> REDIRECT <$> _smpP + "INTERNAL" -> pure INTERNAL + "DUPLICATE_" -> pure DUPLICATE_ + _ -> fail "bad error type" + +$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 3f0ca12b4..e8af60b96 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -165,11 +165,11 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth) +import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, VersionSMPC, XFTPServerWithAuth) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (THandleParams (sessionId)) +import Simplex.Messaging.Transport (SMPVersion, THandleParams (sessionId)) import Simplex.Messaging.Util import Simplex.Messaging.Version import Simplex.RemoteControl.Client @@ -684,7 +684,7 @@ joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do _ -> getSMPServer c userId joinConnSrv c userId connId enableNtfs cReq cInfo pqEnc subMode srv -startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> CR.PQEncryption -> m (Compatible Version, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) +startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> CR.PQEncryption -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqEncryption = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config case ( qUri `compatibleVersion` smpClientVRange, @@ -1114,7 +1114,7 @@ enqueueMessageB c pqEnc_ reqs = do let sqs' = filter isActiveSndQ sqs pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId, CR.PQEncryption)) + storeSentMsg :: DB.Connection -> VersionSMPA -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId, CR.PQEncryption)) storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId @@ -1922,7 +1922,7 @@ cleanupManager c@AgentClient {subQ} = do -- | make sure to ACK or throw in each message processing branch -- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL -processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m () +processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission SMPVersion BrokerMsg -> m () processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, rId, cmd) = do (rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId) processSMP rq conn $ toConnData conn @@ -2063,7 +2063,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> pure Nothing _ -> prohibited >> ack _ -> prohibited >> ack - updateConnVersion :: Connection c -> ConnData -> Version -> m (Connection c) + updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> m (Connection c) updateConnVersion conn' cData' msgAgentVersion = do aVRange <- asks $ smpAgentVRange . config let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion @@ -2126,7 +2126,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, parseMessage :: Encoding a => ByteString -> m a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) - smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> Version -> Version -> m () + smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> m () smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config @@ -2380,7 +2380,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, createRatchet db connId rc -- compare public keys `k1` in AgentRatchetKey messages sent by self and other party -- to determine ratchet initilization ordering - initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m () + initRatchet :: CR.VersionRangeE2E -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m () initRatchet e2eEncryptVRange (pk1, pk2, pKem) | rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 pk2 pKem e2eOtherPartyParams @@ -2431,7 +2431,7 @@ confirmQueueAsync c cData sq srv connInfo e2eEncryption_ pqEnc subMode = do storeConfirmation c cData sq e2eEncryption_ (Just pqEnc) =<< mkAgentConfirmation c cData sq srv connInfo subMode submitPendingMsg c cData sq -confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> SubscriptionMode -> m () +confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> SubscriptionMode -> m () confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ pqEnc_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg @@ -2478,7 +2478,7 @@ enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do submitPendingMsg c cData sq pure $ unId msgId where - storeRatchetKey :: Version -> m InternalId + storeRatchetKey :: VersionSMPA -> m InternalId storeRatchetKey agentVersion = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 23caa2254..f60ddea26 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -167,8 +167,8 @@ import Network.Socket (HostName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError) import qualified Simplex.FileTransfer.Client as X import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb) -import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse, XFTPErrorType (DIGEST)) -import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse) +import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion) import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..)) import Simplex.FileTransfer.Util (uniqueCombine) import Simplex.Messaging.Agent.Env.SQLite @@ -187,6 +187,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse) import Simplex.Messaging.Protocol @@ -215,9 +216,12 @@ import Simplex.Messaging.Protocol UserProtocol, XFTPServer, XFTPServerWithAuth, + VersionSMPC, + VersionRangeSMPC, sameSrvAddr', ) import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport.Client (TransportHost) @@ -253,7 +257,7 @@ data AgentClient = AgentClient { active :: TVar Bool, rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), - msgQ :: TBQueue (ServerTransmission BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), smpServers :: TMap UserId (NonEmpty SMPServerWithAuth), smpClients :: TMap SMPTransportSession SMPClientVar, ntfServers :: TVar [NtfServer], @@ -467,7 +471,7 @@ agentClientStore AgentClient {agentEnv = Env {store}} = store agentDRG :: AgentClient -> TVar ChaChaDRG agentDRG AgentClient {agentEnv = Env {random}} = random -class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err where +class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where type Client msg = c | c -> msg getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg) clientProtocolError :: err -> AgentErrorType @@ -476,8 +480,8 @@ class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err wher clientTransportHost :: Client msg -> TransportHost clientSessionTs :: Client msg -> UTCTime -instance ProtocolServerClient ErrorType BrokerMsg where - type Client BrokerMsg = ProtocolClient ErrorType BrokerMsg +instance ProtocolServerClient SMPVersion ErrorType BrokerMsg where + type Client BrokerMsg = ProtocolClient SMPVersion ErrorType BrokerMsg getProtocolServerClient = getSMPServerClient clientProtocolError = SMP closeProtocolServerClient = closeProtocolClient @@ -485,8 +489,8 @@ instance ProtocolServerClient ErrorType BrokerMsg where clientTransportHost = transportHost' clientSessionTs = sessionTs -instance ProtocolServerClient ErrorType NtfResponse where - type Client NtfResponse = ProtocolClient ErrorType NtfResponse +instance ProtocolServerClient NTFVersion ErrorType NtfResponse where + type Client NtfResponse = ProtocolClient NTFVersion ErrorType NtfResponse getProtocolServerClient = getNtfServerClient clientProtocolError = NTF closeProtocolServerClient = closeProtocolClient @@ -494,7 +498,7 @@ instance ProtocolServerClient ErrorType NtfResponse where clientTransportHost = transportHost' clientSessionTs = sessionTs -instance ProtocolServerClient XFTPErrorType FileResponse where +instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where type Client FileResponse = XFTPClient getProtocolServerClient = getXFTPServerClient clientProtocolError = XFTP @@ -683,8 +687,8 @@ waitForProtocolClient c (_, srv, _) v = do -- clientConnected arg is only passed for SMP server newProtocolClient :: - forall err msg m. - (AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => + forall v err msg m. + (AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> TMap (TransportSession msg) (ClientVar msg) -> @@ -706,10 +710,10 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = putTMVar (sessionVar v) (Left e) throwError e -- signal error to caller -hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone +hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost -getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig +getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v) getClientConfig AgentClient {useNetworkConfig} cfgSel = do cfg <- asks $ cfgSel . config networkConfig <- readTVarIO useNetworkConfig @@ -754,19 +758,19 @@ throwWhenNoDelivery c sq = unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $ throwSTM ThreadKilled -closeProtocolServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = atomically (clientsSel c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient_ c) -reconnectServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +reconnectServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () reconnectServerClients c clientsSel = readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c) -closeClient :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO () +closeClient :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO () closeClient c clientSel tSess = atomically (TM.lookupDelete tSess $ clientSel c) >>= mapM_ (closeClient_ c) -closeClient_ :: ProtocolServerClient err msg => AgentClient -> ClientVar msg -> IO () +closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO () closeClient_ c v = do NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case @@ -798,7 +802,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure where newLock = createLock >>= \l -> TM.insert key l locks $> l -withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a +withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a withClient_ c tSess@(userId, srv, _) statCmd action = do cl <- getProtocolServerClient c tSess (action cl <* stat cl "OK") `catchAgentError` logServerError cl @@ -810,18 +814,18 @@ withClient_ c tSess@(userId, srv, _) statCmd action = do stat cl $ strEncode e throwError e -withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a +withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do logServer "-->" c srv entId cmdStr res <- withClient_ c tSess cmdStr action logServer "<--" c srv entId "OK" return res -withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a -withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client +withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client -withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a -withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client +withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a withSMPClient c q cmdStr action = do @@ -837,7 +841,7 @@ withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityI withNtfClient c srv = withLogClient c (0, srv, Nothing) withXFTPClient :: - (AgentMonad m, ProtocolServerClient err msg) => + (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> (UserId, ProtoServer msg, EntityId) -> ByteString -> @@ -1001,7 +1005,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig -newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri) +newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri) newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do C.AuthAlg a <- asks (rcvAuthAlg . config) g <- asks random @@ -1151,7 +1155,7 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database" -sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () +sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do tSess <- mkTransportSession c userId smpServer senderId withLogClient_ c tSess senderId "SEND " $ \smp -> do @@ -1334,7 +1338,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} -- add encoding as AgentInvitation'? -agentCbEncryptOnce :: AgentMonad m => Version -> C.PublicKeyX25519 -> ByteString -> m ByteString +agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString agentCbEncryptOnce clientVersion dhRcvPubKey msg = do g <- asks random (dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g @@ -1518,7 +1522,7 @@ incStat AgentClient {agentStats} n k = do Just v -> modifyTVar' v (+ n) _ -> newTVar n >>= \v -> TM.insert k v agentStats -incClientStat :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO () +incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO () incClientStat c userId pc = incClientStatN c userId pc 1 incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO () @@ -1528,7 +1532,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do where statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res} -incClientStatN :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO () +incClientStatN :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO () incClientStatN c userId pc n cmd res = do atomically $ incStat c n statsKey where diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index e603e50b8..7a879bb22 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -56,16 +56,17 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (supportedE2EEncryptVRange) +import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRange) import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig) +import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Protocol (NtfServer, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) +import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) +import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (TLS, Transport (..)) import Simplex.Messaging.Transport.Client (defaultSMPPort) import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors) -import Simplex.Messaging.Version import System.Random (StdGen, newStdGen) import UnliftIO (Async, SomeException) import UnliftIO.STM @@ -87,8 +88,8 @@ data AgentConfig = AgentConfig sndAuthAlg :: C.AuthAlg, connIdBytes :: Int, tbqSize :: Natural, - smpCfg :: ProtocolClientConfig, - ntfCfg :: ProtocolClientConfig, + smpCfg :: ProtocolClientConfig SMPVersion, + ntfCfg :: ProtocolClientConfig NTFVersion, xftpCfg :: XFTPClientConfig, reconnectInterval :: RetryInterval, messageRetryInterval :: RetryInterval2, @@ -116,9 +117,9 @@ data AgentConfig = AgentConfig caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, - e2eEncryptVRange :: VersionRange, - smpAgentVRange :: VersionRange, - smpClientVRange :: VersionRange + e2eEncryptVRange :: VersionRangeE2E, + smpAgentVRange :: VersionRangeSMPA, + smpClientVRange :: VersionRangeSMPC } defaultReconnectInterval :: RetryInterval diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index c17ecf4cf..b4383cf9b 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -4,6 +4,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} @@ -33,6 +34,9 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md module Simplex.Messaging.Agent.Protocol ( -- * Protocol parameters + VersionSMPA, + VersionRangeSMPA, + pattern VersionSMPA, ratchetSyncSMPAgentVersion, deliveryRcptsSMPAgentVersion, supportedSMPAgentVRange, @@ -175,11 +179,12 @@ import Data.Time.Clock.System (SystemTime) import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () -import Data.Word (Word32) +import Data.Word (Word16, Word32) import Database.SQLite.Simple.FromField import Database.SQLite.Simple.ToField import Simplex.FileTransfer.Description -import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType) +import Simplex.FileTransfer.Protocol (FileParty (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.Messaging.Agent.QueryString import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOff, RcvE2ERatchetParams, RcvE2ERatchetParamsUri, SndE2ERatchetParams) @@ -200,6 +205,10 @@ import Simplex.Messaging.Protocol SMPServerWithAuth, SndPublicAuthKey, SubscriptionMode, + SMPClientVersion, + VersionSMPC, + VersionRangeSMPC, + initialSMPClientVersion, legacyEncodeServer, legacyServerP, legacyStrEncodeServer, @@ -215,6 +224,7 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Simplex.RemoteControl.Types import Text.Read import UnliftIO.Exception (Exception) @@ -225,19 +235,30 @@ import UnliftIO.Exception (Exception) -- 3 - support ratchet renegotiation (6/30/2023) -- 4 - delivery receipts (7/13/2023) -duplexHandshakeSMPAgentVersion :: Version -duplexHandshakeSMPAgentVersion = 2 +data SMPAgentVersion -ratchetSyncSMPAgentVersion :: Version -ratchetSyncSMPAgentVersion = 3 +instance VersionScope SMPAgentVersion -deliveryRcptsSMPAgentVersion :: Version -deliveryRcptsSMPAgentVersion = 4 +type VersionSMPA = Version SMPAgentVersion -currentSMPAgentVersion :: Version -currentSMPAgentVersion = 4 +type VersionRangeSMPA = VersionRange SMPAgentVersion -supportedSMPAgentVRange :: VersionRange +pattern VersionSMPA :: Word16 -> VersionSMPA +pattern VersionSMPA v = Version v + +duplexHandshakeSMPAgentVersion :: VersionSMPA +duplexHandshakeSMPAgentVersion = VersionSMPA 2 + +ratchetSyncSMPAgentVersion :: VersionSMPA +ratchetSyncSMPAgentVersion = VersionSMPA 3 + +deliveryRcptsSMPAgentVersion :: VersionSMPA +deliveryRcptsSMPAgentVersion = VersionSMPA 4 + +currentSMPAgentVersion :: VersionSMPA +currentSMPAgentVersion = VersionSMPA 4 + +supportedSMPAgentVRange :: VersionRangeSMPA supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentSMPAgentVersion -- it is shorter to allow all handshake headers, @@ -651,7 +672,7 @@ instance StrEncoding SndQueueInfo where pure SndQueueInfo {sndServer, sndSwitchStatus} data ConnectionStats = ConnectionStats - { connAgentVersion :: Version, + { connAgentVersion :: VersionSMPA, rcvQueuesInfo :: [RcvQueueInfo], sndQueuesInfo :: [SndQueueInfo], ratchetSyncState :: RatchetSyncState, @@ -786,27 +807,27 @@ data SMPConfirmation = SMPConfirmation -- | optional reply queues included in confirmation (added in agent protocol v2) smpReplyQueues :: [SMPQueueInfo], -- | SMP client version - smpClientVersion :: Version + smpClientVersion :: VersionSMPC } deriving (Show) data AgentMsgEnvelope = AgentConfirmation - { agentVersion :: Version, + { agentVersion :: VersionSMPA, e2eEncryption_ :: Maybe (SndE2ERatchetParams 'C.X448), encConnInfo :: ByteString } | AgentMsgEnvelope - { agentVersion :: Version, + { agentVersion :: VersionSMPA, encAgentMessage :: ByteString } | AgentInvitation -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet, - { agentVersion :: Version, + { agentVersion :: VersionSMPA, connReq :: ConnectionRequestUri 'CMInvitation, connInfo :: ByteString -- this message is only encrypted with per-queue E2E, not with double ratchet, } | AgentRatchetKey - { agentVersion :: Version, + { agentVersion :: VersionSMPA, e2eEncryption :: RcvE2ERatchetParams 'C.X448, info :: ByteString } @@ -1212,16 +1233,16 @@ sameQueue :: SMPQueue q => (SMPServer, SMP.QueueId) -> q -> Bool sameQueue addr q = sameQAddress addr (qAddress q) {-# INLINE sameQueue #-} -data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress} +data SMPQueueInfo = SMPQueueInfo {clientVersion :: VersionSMPC, queueAddress :: SMPQueueAddress} deriving (Eq, Show) instance Encoding SMPQueueInfo where smpEncode (SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey}) - | clientVersion > 1 = smpEncode (clientVersion, smpServer, senderId, dhPublicKey) + | clientVersion > initialSMPClientVersion = smpEncode (clientVersion, smpServer, senderId, dhPublicKey) | otherwise = smpEncode clientVersion <> legacyEncodeServer smpServer <> smpEncode (senderId, dhPublicKey) smpP = do clientVersion <- smpP - smpServer <- if clientVersion > 1 then smpP else updateSMPServerHosts <$> legacyServerP + smpServer <- if clientVersion > initialSMPClientVersion then smpP else updateSMPServerHosts <$> legacyServerP (senderId, dhPublicKey) <- smpP pure $ SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey} @@ -1229,20 +1250,20 @@ instance Encoding SMPQueueInfo where -- But this is created to allow backward and forward compatibility where SMPQueueUri -- could have more fields to convert to different versions of SMPQueueInfo in a different way, -- and this instance would become non-trivial. -instance VersionI SMPQueueInfo where - type VersionRangeT SMPQueueInfo = SMPQueueUri +instance VersionI SMPClientVersion SMPQueueInfo where + type VersionRangeT SMPClientVersion SMPQueueInfo = SMPQueueUri version = clientVersion toVersionRangeT (SMPQueueInfo _v addr) vr = SMPQueueUri vr addr -instance VersionRangeI SMPQueueUri where - type VersionT SMPQueueUri = SMPQueueInfo +instance VersionRangeI SMPClientVersion SMPQueueUri where + type VersionT SMPClientVersion SMPQueueUri = SMPQueueInfo versionRange = clientVRange toVersionT (SMPQueueUri _vr addr) v = SMPQueueInfo v addr -- | SMP queue information sent out-of-band. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#out-of-band-messages -data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRange, queueAddress :: SMPQueueAddress} +data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRangeSMPC, queueAddress :: SMPQueueAddress} deriving (Eq, Show) data SMPQueueAddress = SMPQueueAddress @@ -1291,7 +1312,7 @@ instance StrEncoding SMPQueueUri where smpServer = if maxVersion vr < srvHostnamesSMPClientVersion then updateSMPServerHosts srv' else srv' pure $ SMPQueueUri vr SMPQueueAddress {smpServer, senderId, dhPublicKey} where - unversioned = (versionToRange 1,[],) <$> strP <* A.endOfInput + unversioned = (versionToRange initialSMPClientVersion,[],) <$> strP <* A.endOfInput versioned = do dhKey_ <- optional strP query <- optional (A.char '/') *> A.char '?' *> strP @@ -1321,7 +1342,7 @@ deriving instance Show AConnectionRequestUri data ConnReqUriData = ConnReqUriData { crScheme :: ServiceScheme, - crAgentVRange :: VersionRange, + crAgentVRange :: VersionRangeSMPA, crSmpQueues :: NonEmpty SMPQueueUri, crClientData :: Maybe CRClientData } diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 07112a836..971b38905 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -44,10 +44,10 @@ import Simplex.Messaging.Protocol RcvPrivateAuthKey, SndPrivateAuthKey, SndPublicAuthKey, + VersionSMPC, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util ((<$?>)) -import Simplex.Messaging.Version -- * Queue types @@ -94,7 +94,7 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue dbReplaceQueueId :: Maybe Int64, rcvSwchStatus :: Maybe RcvSwitchStatus, -- | SMP client version - smpClientVersion :: Version, + smpClientVersion :: VersionSMPC, -- | credentials used in context of notifications clientNtfCreds :: Maybe ClientNtfCreds, deleteErrors :: Int @@ -157,7 +157,7 @@ data StoredSndQueue (q :: QueueStored) = SndQueue dbReplaceQueueId :: Maybe Int64, sndSwchStatus :: Maybe SndSwitchStatus, -- | SMP client version - smpClientVersion :: Version + smpClientVersion :: VersionSMPC } deriving (Show) @@ -304,7 +304,7 @@ deriving instance Show SomeConn data ConnData = ConnData { connId :: ConnId, userId :: UserId, - connAgentVersion :: Version, + connAgentVersion :: VersionSMPA, enableNtfs :: Bool, lastExternalSndId :: PrevExternalSndId, deleted :: Bool, diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 63ac4a280..a2d01c201 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -279,7 +279,7 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, ($>>=), (<$$>)) -import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.Exit (exitFailure) import System.FilePath (takeDirectory) @@ -702,7 +702,7 @@ setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = d |] (host, port, rcvId) -setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> Version -> IO () +setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO () setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion = DB.executeNamed db @@ -803,7 +803,7 @@ setRcvQueueNtfCreds db connId clientNtfCreds = Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) Nothing -> (Nothing, Nothing, Nothing, Nothing) -type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version) +type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC) smpConfirmation :: SMPConfirmationRow -> SMPConfirmation smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) = @@ -812,7 +812,7 @@ smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersi e2ePubKey, connInfo, smpReplyQueues = fromMaybe [] smpReplyQueues_, - smpClientVersion = fromMaybe 1 smpClientVersion_ + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ } createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId) @@ -889,7 +889,7 @@ removeConfirmations db connId = |] [":conn_id" := connId] -setConnectionVersion :: DB.Connection -> ConnId -> Version -> IO () +setConnectionVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnectionVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) @@ -1776,6 +1776,10 @@ instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEn instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 +instance ToField (Version v) where toField (Version v) = toField v + +instance FromField (Version v) where fromField f = Version <$> fromField f + instance ToField CR.PQEncryption where toField (CR.PQEncryption pqEnc) = toField pqEnc instance FromField CR.PQEncryption where fromField f = CR.PQEncryption <$> fromField f @@ -1948,7 +1952,7 @@ setConnDeleted db waitDelivery connId | otherwise = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) -setConnAgentVersion :: DB.Connection -> ConnId -> Version -> IO () +setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnAgentVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) @@ -2007,12 +2011,12 @@ rcvQueueQuery = toRcvQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe Version, Int) + :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> RcvQueue toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = let server = SMPServer host port keyHash - smpClientVersion = fromMaybe 1 smpClientVersion_ + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} _ -> Nothing @@ -2048,7 +2052,7 @@ sndQueueQuery = toSndQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId) :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, Version) -> + :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> SndQueue toSndQueue ( (userId, keyHash, connId, host, port, sndId) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 2cbdced35..d8b202761 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -76,7 +76,7 @@ module Simplex.Messaging.Client PCTransmission, mkTransmission, authTransmission, - clientStub, + smpClientStub, ) where @@ -117,14 +117,14 @@ import System.Timeout (timeout) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- -- Use 'getSMPClient' to connect to an SMP server and create a client handle. -data ProtocolClient err msg = ProtocolClient +data ProtocolClient v err msg = ProtocolClient { action :: Maybe (Async ()), - thParams :: THandleParams, + thParams :: THandleParams v, sessionTs :: UTCTime, - client_ :: PClient err msg + client_ :: PClient v err msg } -data PClient err msg = PClient +data PClient v err msg = PClient { connected :: TVar Bool, transportSession :: TransportSession msg, transportHost :: TransportHost, @@ -135,11 +135,11 @@ data PClient err msg = PClient sentCommands :: TMap CorrId (Request err msg), sndQ :: TBQueue ByteString, rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)), - msgQ :: Maybe (TBQueue (ServerTransmission msg)) + msgQ :: Maybe (TBQueue (ServerTransmission v msg)) } -clientStub :: TVar ChaChaDRG -> ByteString -> Version -> Maybe THandleAuth -> STM (ProtocolClient err msg) -clientStub g sessionId thVersion thAuth = do +smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg) +smpClientStub g sessionId thVersion thAuth = do connected <- newTVar False clientCorrId <- C.newRandomDRG g sentCommands <- TM.empty @@ -174,13 +174,13 @@ clientStub g sessionId thVersion thAuth = do } } -type SMPClient = ProtocolClient ErrorType BrokerMsg +type SMPClient = ProtocolClient SMPVersion ErrorType BrokerMsg -- | Type for client command data type ClientCommand msg = (Maybe C.APrivateAuthKey, EntityId, ProtoCommand msg) -- | Type synonym for transmission from some SPM server queue. -type ServerTransmission msg = (TransportSession msg, Version, SessionId, EntityId, msg) +type ServerTransmission v msg = (TransportSession msg, Version v, SessionId, EntityId, msg) data HostMode = -- | prefer (or require) onion hosts when connecting via SOCKS proxy @@ -241,7 +241,7 @@ transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} = TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing} -- | protocol client configuration. -data ProtocolClientConfig = ProtocolClientConfig +data ProtocolClientConfig v = ProtocolClientConfig { -- | size of TBQueue to use for server commands and responses qSize :: Natural, -- | default server port if port is not specified in ProtocolServer @@ -249,13 +249,13 @@ data ProtocolClientConfig = ProtocolClientConfig -- | network configuration networkConfig :: NetworkConfig, -- | client-server protocol version range - serverVRange :: VersionRange, + serverVRange :: VersionRange v, -- | delay between sending batches of commands (microseconds) batchDelay :: Maybe Int } -- | Default protocol client configuration. -defaultClientConfig :: VersionRange -> ProtocolClientConfig +defaultClientConfig :: VersionRange v -> ProtocolClientConfig v defaultClientConfig serverVRange = ProtocolClientConfig { qSize = 64, @@ -265,7 +265,7 @@ defaultClientConfig serverVRange = batchDelay = Nothing } -defaultSMPClientConfig :: ProtocolClientConfig +defaultSMPClientConfig :: ProtocolClientConfig SMPVersion defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange data Request err msg = Request @@ -292,15 +292,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts onionHost = find isOnionHost hosts publicHost = find (not . isOnionHost) hosts -protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String +protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err msg -> String protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_ where snd3 (_, s, _) = s -transportHost' :: ProtocolClient err msg -> TransportHost +transportHost' :: ProtocolClient v err msg -> TransportHost transportHost' = transportHost . client_ -transportSession' :: ProtocolClient err msg -> TransportSession msg +transportSession' :: ProtocolClient v err msg -> TransportSession msg transportSession' = transportSession . client_ type UserId = Int64 @@ -313,7 +313,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId) -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall err msg. Protocol err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) +getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, serverVRange, batchDelay} msgQ disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> @@ -322,7 +322,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig - mkProtocolClient :: TransportHost -> STM (PClient err msg) + mkProtocolClient :: TransportHost -> STM (PClient v err msg) mkProtocolClient transportHost = do connected <- newTVar False pingErrorCount <- newTVar 0 @@ -345,7 +345,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize msgQ } - runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) + runClient :: (ServiceName, ATransport) -> TransportHost -> PClient v err msg -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) runClient (port', ATransport t) useHost c = do cVar <- newEmptyTMVarIO let tcConfig = transportClientConfig networkConfig @@ -366,10 +366,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize "80" -> ("80", transport @WS) p -> (p, transport @TLS) - client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO () + client :: forall c. Transport c => TProxy c -> PClient v err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient v err msg)) -> c -> IO () client _ c cVar h = do ks <- atomically $ C.generateKeyPair g - runExceptT (protocolClientHandshake @err @msg h ks (keyHash srv) serverVRange) >>= \case + runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange) >>= \case Left e -> atomically . putTMVar cVar . Left $ PCETransportError e Right th@THandle {params} -> do sessionTs <- getCurrentTime @@ -380,16 +380,16 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0]) `finally` disconnected c' - send :: Transport c => ProtocolClient err msg -> THandle c -> IO () + send :: Transport c => ProtocolClient v err msg -> THandle v c -> IO () send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPutLog h - receive :: Transport c => ProtocolClient err msg -> THandle c -> IO () + receive :: Transport c => ProtocolClient v err msg -> THandle v c -> IO () receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ - ping :: ProtocolClient err msg -> IO () + ping :: ProtocolClient v err msg -> IO () ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do threadDelay' smpPingInterval - runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case + runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @v @err @msg) >>= \case Left PCEResponseTimeout -> do cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1) when (maxCnt == 0 || cnt < maxCnt) $ ping c @@ -397,10 +397,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize where maxCnt = smpPingCount networkConfig - process :: ProtocolClient err msg -> IO () + process :: ProtocolClient v err msg -> IO () process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c) - processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO () + processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO () processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) = if B.null $ bs corrId then sendMsg respOrErr @@ -428,7 +428,7 @@ proxyUsername :: TransportSession msg -> ByteString proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_ -- | Disconnects client from the server and terminates client threads. -closeProtocolClient :: ProtocolClient err msg -> IO () +closeProtocolClient :: ProtocolClient v err msg -> IO () closeProtocolClient = mapM_ uninterruptibleCancel . action -- | SMP client error type. @@ -517,7 +517,7 @@ processSUBResponse c (Response rId r) = case r of writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c) -serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg +serverTransmission :: ProtocolClient v err msg -> RecipientId -> msg -> ServerTransmission v msg serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} entityId message = (transportSession, thVersion, sessionId, entityId, message) @@ -635,7 +635,7 @@ sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg) -- | Send multiple commands with batching and collect responses -sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) +sendProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs validate . concat =<< mapM (sendBatch c) bs @@ -652,12 +652,12 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz where diff = L.length cs - length rs -streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () +streamProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs mapM_ (cb <=< sendBatch c) bs -sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg] +sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg] sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do case b of TBError e Request {entityId} -> do @@ -673,7 +673,7 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do (: []) <$> getResponse c r -- | Send Protocol command -sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg +sendProtocolCommand :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} pKey entId cmd = ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd) where @@ -690,7 +690,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THand | otherwise = tEncode t -- TODO switch to timeout or TimeManager that supports Int64 -getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg) +getResponse :: ProtocolClient v err msg -> Request err msg -> IO (Response err msg) getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Request {entityId, responseVar} = do response <- timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case @@ -698,7 +698,7 @@ getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Requ Nothing -> pure $ Left PCEResponseTimeout pure Response {entityId, response} -mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> IO (PCTransmission err msg) +mkTransmission :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg) mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entId, cmd) = do corrId <- atomically getNextCorrId let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entId, cmd) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 068a52782..73f47648b 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -65,7 +65,7 @@ type SMPSub = (SMPSubParty, QueueId) -- type SMPServerSub = (SMPServer, SMPSub) data SMPClientAgentConfig = SMPClientAgentConfig - { smpCfg :: ProtocolClientConfig, + { smpCfg :: ProtocolClientConfig SMPVersion, reconnectInterval :: RetryInterval, msgQSize :: Natural, agentQSize :: Natural, @@ -91,7 +91,7 @@ defaultSMPClientAgentConfig = data SMPClientAgent = SMPClientAgent { agentCfg :: SMPClientAgentConfig, - msgQ :: TBQueue (ServerTransmission BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), agentQ :: TBQueue SMPClientAgentEvent, randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 345119fca..a6faf49c7 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} @@ -14,10 +15,64 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} --- {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -module Simplex.Messaging.Crypto.Ratchet where +module Simplex.Messaging.Crypto.Ratchet + ( Ratchet (..), + RatchetX448, + SkippedMsgDiff (..), + SkippedMsgKeys, + InitialKeys (..), + PQEncryption (..), + pattern PQEncOn, + pattern PQEncOff, + AUseKEM (..), + RatchetKEMState (..), + SRatchetKEMState (..), + RcvPrivRKEMParams, + APrivRKEMParams (..), + RcvE2ERatchetParamsUri, + RcvE2ERatchetParams, + SndE2ERatchetParams, + AE2ERatchetParams (..), + E2ERatchetParamsUri (..), + E2ERatchetParams (..), + VersionE2E, + VersionRangeE2E, + pattern VersionE2E, + kdfX3DHE2EEncryptVersion, + pqRatchetVersion, + currentE2EEncryptVersion, + supportedE2EEncryptVRange, + generateRcvE2EParams, + generateSndE2EParams, + initialPQEncryption, + connPQEncryption, + joinContactInitialKeys, + replyKEM_, + pqX3dhSnd, + pqX3dhRcv, + initSndRatchet, + initRcvRatchet, + rcEncrypt, + rcDecrypt, + -- used in tests + MsgHeader (..), + RatchetVersions (..), + RatchetInitParams (..), + UseKEM (..), + RKEMParams (..), + ARKEMParams (..), + SndRatchet (..), + RcvRatchet (..), + RatchetKEM (..), + RatchetKEMAccepted (..), + RatchetKey (..), + ratchetVersions, + fullHeaderLen, + applySMDiff, + ) +where import Control.Applicative ((<|>)) import Control.Monad.Except @@ -44,7 +99,7 @@ import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust) import Data.Type.Equality import Data.Typeable (Typeable) -import Data.Word (Word32) +import Data.Word (Word16, Word32) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.QueryString @@ -55,22 +110,34 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE') import Simplex.Messaging.Util ((<$?>), ($>>=)) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import UnliftIO.STM -- e2e encryption headers version history: -- 1 - binary protocol encoding (1/1/2022) -- 2 - use KDF in x3dh (10/20/2022) -kdfX3DHE2EEncryptVersion :: Version -kdfX3DHE2EEncryptVersion = 2 +data E2EVersion -pqRatchetVersion :: Version -pqRatchetVersion = 3 +instance VersionScope E2EVersion -currentE2EEncryptVersion :: Version -currentE2EEncryptVersion = 3 +type VersionE2E = Version E2EVersion -supportedE2EEncryptVRange :: VersionRange +type VersionRangeE2E = VersionRange E2EVersion + +pattern VersionE2E :: Word16 -> VersionE2E +pattern VersionE2E v = Version v + +kdfX3DHE2EEncryptVersion :: VersionE2E +kdfX3DHE2EEncryptVersion = VersionE2E 2 + +pqRatchetVersion :: VersionE2E +pqRatchetVersion = VersionE2E 3 + +currentE2EEncryptVersion :: VersionE2E +currentE2EEncryptVersion = VersionE2E 3 + +supportedE2EEncryptVRange :: VersionRangeE2E supportedE2EEncryptVRange = mkVersionRange kdfX3DHE2EEncryptVersion currentE2EEncryptVersion data RatchetKEMState @@ -120,7 +187,7 @@ instance RatchetKEMStateI s => Encoding (RKEMParams s) where RKParamsAccepted ct k -> smpEncode ('A', ct, k) smpP = (\(ARKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP -instance Encoding (ARKEMParams) where +instance Encoding ARKEMParams where smpEncode (ARKP _ ps) = smpEncode ps smpP = smpP >>= \case @@ -129,7 +196,7 @@ instance Encoding (ARKEMParams) where _ -> fail "bad ratchet KEM params" data E2ERatchetParams (s :: RatchetKEMState) (a :: Algorithm) - = E2ERatchetParams Version (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) + = E2ERatchetParams VersionE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) deriving (Show) data AE2ERatchetParams (a :: Algorithm) @@ -159,12 +226,12 @@ instance (RatchetKEMStateI s, AlgorithmI a) => Encoding (E2ERatchetParams s a) w instance AlgorithmI a => Encoding (AE2ERatchetParams a) where smpEncode (AE2ERatchetParams _ ps) = smpEncode ps - smpP = (\(AnyE2ERatchetParams s _ ps) -> (AE2ERatchetParams s) <$> checkAlgorithm ps) <$?> smpP + smpP = (\(AnyE2ERatchetParams s _ ps) -> AE2ERatchetParams s <$> checkAlgorithm ps) <$?> smpP instance Encoding AnyE2ERatchetParams where smpEncode (AnyE2ERatchetParams _ _ ps) = smpEncode ps smpP = do - v :: Version <- smpP + v :: VersionE2E <- smpP APublicDhKey a k1 <- smpP APublicDhKey a' k2 <- smpP case testEquality a a' of @@ -174,25 +241,25 @@ instance Encoding AnyE2ERatchetParams where Just (ARKP s kem) -> pure $ AnyE2ERatchetParams s a $ E2ERatchetParams v k1 k2 (Just kem) Nothing -> pure $ AnyE2ERatchetParams SRKSProposed a $ E2ERatchetParams v k1 k2 Nothing where - kemP :: Version -> Parser (Maybe (ARKEMParams)) + kemP :: VersionE2E -> Parser (Maybe ARKEMParams) kemP v | v >= pqRatchetVersion = smpP | otherwise = pure Nothing -instance VersionI (E2ERatchetParams s a) where - type VersionRangeT (E2ERatchetParams s a) = E2ERatchetParamsUri s a +instance VersionI E2EVersion (E2ERatchetParams s a) where + type VersionRangeT E2EVersion (E2ERatchetParams s a) = E2ERatchetParamsUri s a version (E2ERatchetParams v _ _ _) = v toVersionRangeT (E2ERatchetParams _ k1 k2 kem_) vr = E2ERatchetParamsUri vr k1 k2 kem_ -instance VersionRangeI (E2ERatchetParamsUri s a) where - type VersionT (E2ERatchetParamsUri s a) = (E2ERatchetParams s a) +instance VersionRangeI E2EVersion (E2ERatchetParamsUri s a) where + type VersionT E2EVersion (E2ERatchetParamsUri s a) = (E2ERatchetParams s a) versionRange (E2ERatchetParamsUri vr _ _ _) = vr toVersionT (E2ERatchetParamsUri _ k1 k2 kem_) v = E2ERatchetParams v k1 k2 kem_ type RcvE2ERatchetParamsUri a = E2ERatchetParamsUri 'RKSProposed a data E2ERatchetParamsUri (s :: RatchetKEMState) (a :: Algorithm) - = E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) + = E2ERatchetParamsUri VersionRangeE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) deriving (Show) data AE2ERatchetParamsUri (a :: Algorithm) @@ -228,13 +295,13 @@ instance (RatchetKEMStateI s, AlgorithmI a) => StrEncoding (E2ERatchetParamsUri instance AlgorithmI a => StrEncoding (AE2ERatchetParamsUri a) where strEncode (AE2ERatchetParamsUri _ ps) = strEncode ps - strP = (\(AnyE2ERatchetParamsUri s _ ps) -> (AE2ERatchetParamsUri s) <$> checkAlgorithm ps) <$?> strP + strP = (\(AnyE2ERatchetParamsUri s _ ps) -> AE2ERatchetParamsUri s <$> checkAlgorithm ps) <$?> strP instance StrEncoding AnyE2ERatchetParamsUri where strEncode (AnyE2ERatchetParamsUri _ _ ps) = strEncode ps strP = do query <- strP - vr :: VersionRange <- queryParam "v" query + vr :: VersionRangeE2E <- queryParam "v" query keys <- L.toList <$> queryParam "x3dh" query case keys of [APublicDhKey a k1, APublicDhKey a' k2] -> case testEquality a a' of @@ -248,7 +315,7 @@ instance StrEncoding AnyE2ERatchetParamsUri where kemP vr query | maxVersion vr >= pqRatchetVersion = queryParam_ "kem_key" query - $>>= \k -> (Just . kemParams k <$> queryParam_ "kem_ct" query) + $>>= \k -> Just . kemParams k <$> queryParam_ "kem_ct" query | otherwise = pure Nothing kemParams k = \case Nothing -> ARKP SRKSProposed $ RKParamsProposed k @@ -272,7 +339,7 @@ instance RatchetKEMStateI s => Encoding (PrivRKEMParams s) where PrivateRKParamsAccepted ct shared k -> smpEncode ('A', ct, shared, k) smpP = (\(APRKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP -instance Encoding (APrivRKEMParams) where +instance Encoding APrivRKEMParams where smpEncode (APRKP _ ps) = smpEncode ps smpP = smpP >>= \case @@ -290,7 +357,7 @@ data UseKEM (s :: RatchetKEMState) where data AUseKEM = forall s. RatchetKEMStateI s => AUseKEM (SRatchetKEMState s) (UseKEM s) -generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a) +generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a) generateE2EParams g v useKEM_ = do (k1, pk1) <- atomically $ generateKeyPair g (k2, pk2) <- atomically $ generateKeyPair g @@ -309,7 +376,7 @@ generateE2EParams g v useKEM_ = do _ -> pure Nothing -- used by party initiating connection, Bob in double-ratchet spec -generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a) +generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a) generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_ where proposeKEM_ :: PQEncryption -> Maybe (UseKEM 'RKSProposed) @@ -319,7 +386,7 @@ generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_ -- used by party accepting connection, Alice in double-ratchet spec -generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a) +generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a) generateSndE2EParams g v = \case Nothing -> do (pk1, pk2, _, e2eParams) <- generateE2EParams g v Nothing @@ -385,7 +452,7 @@ type RatchetX448 = Ratchet 'X448 data Ratchet a = Ratchet { -- ratchet version range sent in messages (current .. max supported ratchet version) - rcVersion :: VersionRange, + rcVersion :: RatchetVersions, -- associated data - must be the same in both parties ratchets rcAD :: Str, rcDHRs :: PrivateKey a, @@ -405,6 +472,29 @@ data Ratchet a = Ratchet } deriving (Show) +data RatchetVersions = RVersions + { current :: VersionE2E, + maxSupported :: VersionE2E + } + deriving (Eq, Show) + +instance ToJSON RatchetVersions where + -- TODO v5.7 or v5.8 change to the default record encoding + toJSON (RVersions v1 v2) = toJSON (v1, v2) + toEncoding (RVersions v1 v2) = toEncoding (v1, v2) + +instance FromJSON RatchetVersions where + -- TODO v6.0 replace with the default record parser + -- this parser supports JSON record encoding for forward compatibility + parseJSON v = (tupleP <|> recordP v) >>= toRV + where + tupleP = parseJSON v + recordP = J.withObject "RatchetVersions" $ \o -> (,) <$> o J..: "current" <*> o J..: "maxSupported" + toRV (v1, v2) = maybe (fail "bad version range") (pure . ratchetVersions) $ safeVersionRange v1 v2 + +ratchetVersions :: VersionRangeE2E -> RatchetVersions +ratchetVersions (VersionRange v1 v2) = RVersions {current = v1, maxSupported = v2} + data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, rcCKs :: RatchetKey, @@ -494,12 +584,12 @@ instance FromField MessageKey where fromField = blobFieldDecoder smpDecode -- // above added for KEM -- @ initSndRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a -initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do + forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a +initSndRatchet v rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted) in Ratchet - { rcVersion, + { rcVersion = ratchetVersions v, rcAD = assocData, rcDHRs, rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, @@ -523,10 +613,10 @@ initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey -- Please note that the public part of rcDHRs was sent to the sender -- as part of the connection request and random salt was received from the sender. initRcvRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a -initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM = + forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a +initRcvRatchet v rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM = Ratchet - { rcVersion, + { rcVersion = ratchetVersions v, rcAD = assocData, rcDHRs, -- rcKEM: @@ -552,7 +642,7 @@ initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK -- ct = state.PQRct // added for KEM #1 data MsgHeader a = MsgHeader { -- | max supported ratchet version - msgMaxVersion :: Version, + msgMaxVersion :: VersionE2E, msgDHRs :: PublicKey a, msgKEM :: Maybe ARKEMParams, msgPN :: Word32, @@ -560,11 +650,6 @@ data MsgHeader a = MsgHeader } deriving (Show) -data AMsgHeader - = forall a. - (AlgorithmI a, DhAlgorithm a) => - AMsgHeader (SAlgorithm a) (MsgHeader a) - -- to allow extension without increasing the size, the actual header length is: -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 -- TODO PQ this must be version-dependent @@ -590,7 +675,7 @@ instance AlgorithmI a => Encoding (MsgHeader a) where pure MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} data EncMessageHeader = EncMessageHeader - { ehVersion :: Version, + { ehVersion :: VersionE2E, ehIV :: IV, ehAuthTag :: AuthTag, ehBody :: ByteString @@ -611,12 +696,12 @@ data EncRatchetMessage = EncRatchetMessage emBody :: ByteString } -encodeEncRatchetMessage :: Version -> EncRatchetMessage -> ByteString +encodeEncRatchetMessage :: VersionE2E -> EncRatchetMessage -> ByteString encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} | v >= pqRatchetVersion = smpEncode (Large emHeader, emAuthTag, Tail emBody) | otherwise = smpEncode (emHeader, emAuthTag, Tail emBody) -encRatchetMessageP :: Version -> Parser EncRatchetMessage +encRatchetMessageP :: VersionE2E -> Parser EncRatchetMessage encRatchetMessageP v = do emHeader <- if v >= pqRatchetVersion then unLarge <$> smpP else smpP (emAuthTag, Tail emBody) <- smpP @@ -694,10 +779,10 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, -- enc_header = HENCRYPT(state.HKs, header) (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV paddedHeaderLen rcAD msgHeader -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) - -- TODO PQ versioning in Ratchet should change somehow - let emHeader = smpEncode EncMessageHeader {ehVersion = maxVersion rcVersion, ehBody, ehAuthTag, ehIV} + -- TODO PQ versioning in Ratchet should change: we should use "current" version here + let emHeader = smpEncode EncMessageHeader {ehVersion = maxSupported rcVersion, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg - let msg' = encodeEncRatchetMessage (maxVersion rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag} + let msg' = encodeEncRatchetMessage (maxSupported rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag} -- state.Ns += 1 rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1} rc'' = case pqMode_ of @@ -719,7 +804,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, msgHeader = smpEncode MsgHeader - { msgMaxVersion = maxVersion rcVersion, + { msgMaxVersion = maxSupported rcVersion, msgDHRs = publicKey rcDHRs, msgKEM = msgKEMParams <$> rcKEM, msgPN = rcPN, @@ -752,7 +837,7 @@ rcDecrypt :: ExceptT CryptoError IO (DecryptResult a) rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do -- TODO PQ versioning should change - encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxVersion rcVersion) msg' + encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxSupported rcVersion) msg' encHdr <- parseE CryptoHeaderError smpP emHeader -- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD) decryptSkipped encHdr encMsg >>= \case diff --git a/src/Simplex/Messaging/Notifications/Client.hs b/src/Simplex/Messaging/Notifications/Client.hs index d69114b68..72a92c278 100644 --- a/src/Simplex/Messaging/Notifications/Client.hs +++ b/src/Simplex/Messaging/Notifications/Client.hs @@ -10,15 +10,15 @@ import Data.Word (Word16) import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Notifications.Transport (supportedClientNTFVRange) +import Simplex.Messaging.Notifications.Transport (NTFVersion, supportedClientNTFVRange) import Simplex.Messaging.Protocol (ErrorType) import Simplex.Messaging.Util (bshow) -type NtfClient = ProtocolClient ErrorType NtfResponse +type NtfClient = ProtocolClient NTFVersion ErrorType NtfResponse type NtfClientError = ProtocolClientError ErrorType -defaultNTFClientConfig :: ProtocolClientConfig +defaultNTFClientConfig :: ProtocolClientConfig NTFVersion defaultNTFClientConfig = defaultClientConfig supportedClientNTFVRange ntfRegisterToken :: NtfClient -> C.APrivateAuthKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519) diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 73c2dada6..943c30c5a 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -28,7 +28,7 @@ import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (ntfClientHandshake) +import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake) import Simplex.Messaging.Parsers (fromTextField_) import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..)) import Simplex.Messaging.Util (eitherToMaybe, (<$?>)) @@ -147,7 +147,7 @@ instance Encoding ANewNtfEntity where 'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP) _ -> fail "bad ANewNtfEntity" -instance Protocol ErrorType NtfResponse where +instance Protocol NTFVersion ErrorType NtfResponse where type ProtoCommand NtfResponse = NtfCmd type ProtoType NtfResponse = 'PNTF protocolClientHandshake = ntfClientHandshake @@ -184,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e) deriving instance Show NtfCmd -instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where +instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) where type Tag (NtfCommand e) = NtfCommandTag e encodeProtocol _v = \case TNEW newTkn -> e (TNEW_, ' ', newTkn) @@ -203,7 +203,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag) - fromProtocolError = fromProtocolError @ErrorType @NtfResponse + fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse {-# INLINE fromProtocolError #-} checkCredentials (auth, _, entityId, _) cmd = case cmd of @@ -223,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where | not (B.null entityId) = Left $ CMD HAS_AUTH | otherwise = Right cmd -instance ProtocolEncoding ErrorType NtfCmd where +instance ProtocolEncoding NTFVersion ErrorType NtfCmd where type Tag NtfCmd = NtfCmdTag encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c @@ -243,7 +243,7 @@ instance ProtocolEncoding ErrorType NtfCmd where SDEL_ -> pure SDEL PING_ -> pure PING - fromProtocolError = fromProtocolError @ErrorType @NtfResponse + fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse {-# INLINE fromProtocolError #-} checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c @@ -290,7 +290,7 @@ data NtfResponse | NRPong deriving (Show) -instance ProtocolEncoding ErrorType NtfResponse where +instance ProtocolEncoding NTFVersion ErrorType NtfResponse where type Tag NtfResponse = NtfResponseTag encodeProtocol _v = \case NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey) diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 754aa6d62..7ae657fd1 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -338,7 +338,7 @@ updateTknStatus NtfTknData {ntfTknId, tknStatus} status = do old <- atomically $ stateTVar tknStatus (,status) when (old /= status) $ withNtfLog $ \sl -> logTokenStatus sl ntfTknId status -runNtfClientTransport :: Transport c => THandle c -> M () +runNtfClientTransport :: Transport c => THandleNTF c -> M () runNtfClientTransport th@THandle {params} = do qSize <- asks $ clientQSize . config ts <- liftIO getSystemTime @@ -355,7 +355,7 @@ runNtfClientTransport th@THandle {params} = do clientDisconnected :: NtfServerClient -> IO () clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False -receive :: Transport c => THandle c -> NtfServerClient -> M () +receive :: Transport c => THandleNTF c -> NtfServerClient -> M () receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do ts <- liftIO $ tGet th forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do @@ -370,7 +370,7 @@ receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ where write q t = atomically $ writeTBQueue q t -send :: Transport c => THandle c -> NtfServerClient -> IO () +send :: Transport c => THandleNTF c -> NtfServerClient -> IO () send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do t <- atomically $ readTBQueue sndQ void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)] diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 9e3013a8d..0d722dcc3 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -24,6 +24,7 @@ import Numeric.Natural import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF) import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Stats import Simplex.Messaging.Notifications.Server.Store @@ -34,7 +35,6 @@ import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport, THandleParams) import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams) -import Simplex.Messaging.Version (VersionRange) import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -60,7 +60,7 @@ data NtfServerConfig = NtfServerConfig logStatsStartTime :: Int64, serverStatsLogFile :: FilePath, serverStatsBackupFile :: Maybe FilePath, - ntfServerVRange :: VersionRange, + ntfServerVRange :: VersionRangeNTF, transportConfig :: TransportServerConfig } @@ -161,13 +161,13 @@ data NtfRequest data NtfServerClient = NtfServerClient { rcvQ :: TBQueue NtfRequest, sndQ :: TBQueue (Transmission NtfResponse), - ntfThParams :: THandleParams, + ntfThParams :: THandleParams NTFVersion, connected :: TVar Bool, rcvActiveAt :: TVar SystemTime, sndActiveAt :: TVar SystemTime } -newNtfServerClient :: Natural -> THandleParams -> SystemTime -> STM NtfServerClient +newNtfServerClient :: Natural -> THandleParams NTFVersion -> SystemTime -> STM NtfServerClient newNtfServerClient qSize ntfThParams ts = do rcvQ <- newTBQueue qSize sndQ <- newTBQueue qSize diff --git a/src/Simplex/Messaging/Notifications/Transport.hs b/src/Simplex/Messaging/Notifications/Transport.hs index 00fd811a2..bc68fab03 100644 --- a/src/Simplex/Messaging/Notifications/Transport.hs +++ b/src/Simplex/Messaging/Notifications/Transport.hs @@ -3,6 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Notifications.Transport where @@ -12,33 +13,51 @@ import Control.Monad.Except import Data.Attoparsec.ByteString.Char8 (Parser) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Word (Word16) import qualified Data.X509 as X import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Transport import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Simplex.Messaging.Util (liftEitherWith) ntfBlockSize :: Int ntfBlockSize = 512 -authBatchCmdsNTFVersion :: Version -authBatchCmdsNTFVersion = 2 +data NTFVersion -currentClientNTFVersion :: Version -currentClientNTFVersion = 1 +instance VersionScope NTFVersion -currentServerNTFVersion :: Version -currentServerNTFVersion = 1 +type VersionNTF = Version NTFVersion -supportedClientNTFVRange :: VersionRange -supportedClientNTFVRange = mkVersionRange 1 currentClientNTFVersion +type VersionRangeNTF = VersionRange NTFVersion -supportedServerNTFVRange :: VersionRange -supportedServerNTFVRange = mkVersionRange 1 currentServerNTFVersion +pattern VersionNTF :: Word16 -> VersionNTF +pattern VersionNTF v = Version v + +initialNTFVersion :: VersionNTF +initialNTFVersion = VersionNTF 1 + +authBatchCmdsNTFVersion :: VersionNTF +authBatchCmdsNTFVersion = VersionNTF 2 + +currentClientNTFVersion :: VersionNTF +currentClientNTFVersion = VersionNTF 1 + +currentServerNTFVersion :: VersionNTF +currentServerNTFVersion = VersionNTF 1 + +supportedClientNTFVRange :: VersionRangeNTF +supportedClientNTFVRange = mkVersionRange initialNTFVersion currentClientNTFVersion + +supportedServerNTFVRange :: VersionRangeNTF +supportedServerNTFVRange = mkVersionRange initialNTFVersion currentServerNTFVersion + +type THandleNTF c = THandle NTFVersion c data NtfServerHandshake = NtfServerHandshake - { ntfVersionRange :: VersionRange, + { ntfVersionRange :: VersionRangeNTF, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. authPubKey :: Maybe (X.SignedExact X.PubKey) @@ -46,7 +65,7 @@ data NtfServerHandshake = NtfServerHandshake data NtfClientHandshake = NtfClientHandshake { -- | agreed SMP notifications server protocol version - ntfVersion :: Version, + ntfVersion :: VersionNTF, -- | server identity - CA certificate fingerprint keyHash :: C.KeyHash, -- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. @@ -66,12 +85,12 @@ instance Encoding NtfServerHandshake where authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey} -encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds :: Encoding a => VersionNTF -> Maybe a -> ByteString encodeAuthEncryptCmds v k | v >= authBatchCmdsNTFVersion = maybe "" smpEncode k | otherwise = "" -authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP :: VersionNTF -> Parser a -> Parser (Maybe a) authEncryptCmdsP v p = if v >= authBatchCmdsNTFVersion then Just <$> p else pure Nothing instance Encoding NtfClientHandshake where @@ -83,16 +102,16 @@ instance Encoding NtfClientHandshake where authPubKey <- ntfAuthPubKeyP ntfVersion pure NtfClientHandshake {ntfVersion, keyHash, authPubKey} -ntfAuthPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519) +ntfAuthPubKeyP :: VersionNTF -> Parser (Maybe C.PublicKeyX25519) ntfAuthPubKeyP v = if v >= authBatchCmdsNTFVersion then Just <$> smpP else pure Nothing -encodeNtfAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString +encodeNtfAuthPubKey :: VersionNTF -> Maybe C.PublicKeyX25519 -> ByteString encodeNtfAuthPubKey v k | v >= authBatchCmdsNTFVersion = maybe "" smpEncode k | otherwise = "" -- | Notifcations server transport handshake. -ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c) ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c let sk = C.signX509 serverSignKey $ C.publicToX509 k @@ -106,7 +125,7 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do | otherwise -> throwError $ TEHandshake VERSION -- | Notifcations server client transport handshake. -ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c) ntfClientHandshake c (k, pk) keyHash ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th @@ -122,15 +141,15 @@ ntfClientHandshake c (k, pk) keyHash ntfVRange = do pure $ ntfThHandle th v pk sk_ Nothing -> throwError $ TEHandshake VERSION -ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c +ntfThHandle :: forall c. THandleNTF c -> VersionNTF -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleNTF c ntfThHandle th@THandle {params} v privKey k_ = -- TODO drop SMP v6: make thAuth non-optional let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_ v3 = v >= authBatchCmdsNTFVersion params' = params {thVersion = v, thAuth, implySessId = v3, batch = v3} - in (th :: THandle c) {params = params'} + in (th :: THandleNTF c) {params = params'} -ntfTHandle :: Transport c => c -> THandle c +ntfTHandle :: Transport c => c -> THandleNTF c ntfTHandle c = THandle {connection = c, params} where - params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = False} + params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = VersionNTF 0, thAuth = Nothing, implySessId = False, batch = False} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 2c7685ab6..8f36d8e7a 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -46,6 +46,10 @@ module Simplex.Messaging.Protocol e2eEncMessageLength, -- * SMP protocol types + SMPClientVersion, + VersionSMPC, + VersionRangeSMPC, + pattern VersionSMPC, ProtocolEncoding (..), Command (..), SubscriptionMode (..), @@ -117,6 +121,7 @@ module Simplex.Messaging.Protocol SMPMsgMeta (..), NMsgMeta (..), MsgFlags (..), + initialSMPClientVersion, userProtocol, rcvMessageMeta, noMsgFlags, @@ -179,6 +184,7 @@ import Data.Maybe (isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality +import Data.Word (Word16) import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) import Network.Socket (ServiceName) import qualified Simplex.Messaging.Crypto as C @@ -190,19 +196,34 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..)) import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>)) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal -- SMP client protocol version history: -- 1 - binary protocol encoding (1/1/2022) -- 2 - multiple server hostnames and versioned queue addresses (8/12/2022) -srvHostnamesSMPClientVersion :: Version -srvHostnamesSMPClientVersion = 2 +data SMPClientVersion -currentSMPClientVersion :: Version -currentSMPClientVersion = 2 +instance VersionScope SMPClientVersion -supportedSMPClientVRange :: VersionRange -supportedSMPClientVRange = mkVersionRange 1 currentSMPClientVersion +type VersionSMPC = Version SMPClientVersion + +type VersionRangeSMPC = VersionRange SMPClientVersion + +pattern VersionSMPC :: Word16 -> VersionSMPC +pattern VersionSMPC v = Version v + +initialSMPClientVersion :: VersionSMPC +initialSMPClientVersion = VersionSMPC 1 + +srvHostnamesSMPClientVersion :: VersionSMPC +srvHostnamesSMPClientVersion = VersionSMPC 2 + +currentSMPClientVersion :: VersionSMPC +currentSMPClientVersion = VersionSMPC 2 + +supportedSMPClientVRange :: VersionRangeSMPC +supportedSMPClientVRange = mkVersionRange initialSMPClientVersion currentSMPClientVersion maxMessageLength :: Int maxMessageLength = 16088 @@ -642,7 +663,7 @@ data ClientMsgEnvelope = ClientMsgEnvelope deriving (Show) data PubHeader = PubHeader - { phVersion :: Version, + { phVersion :: VersionSMPC, phE2ePubDhKey :: Maybe C.PublicKeyX25519 } deriving (Show) @@ -1048,7 +1069,7 @@ data CommandError deriving (Eq, Read, Show) -- | SMP transmission parser. -transmissionP :: THandleParams -> Parser RawTransmission +transmissionP :: THandleParams v -> Parser RawTransmission transmissionP THandleParams {sessionId, implySessId} = do authenticator <- smpP authorized <- A.takeByteString @@ -1062,16 +1083,16 @@ transmissionP THandleParams {sessionId, implySessId} = do command <- A.takeByteString pure RawTransmission {authenticator, authorized = authorized', sessId, corrId, entityId, command} -class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where +class (ProtocolEncoding v err msg, ProtocolEncoding v err (ProtoCommand msg), Show err, Show msg) => Protocol v err msg | msg -> v, msg -> err where type ProtoCommand msg = cmd | cmd -> msg type ProtoType msg = (sch :: ProtocolType) | sch -> msg - protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) + protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> ExceptT TransportError IO (THandle v c) protocolPing :: ProtoCommand msg protocolError :: msg -> Maybe err type ProtoServer msg = ProtocolServer (ProtoType msg) -instance Protocol ErrorType BrokerMsg where +instance Protocol SMPVersion ErrorType BrokerMsg where type ProtoCommand BrokerMsg = Cmd type ProtoType BrokerMsg = 'PSMP protocolClientHandshake = smpClientHandshake @@ -1080,14 +1101,14 @@ instance Protocol ErrorType BrokerMsg where ERR e -> Just e _ -> Nothing -class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where +class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -> v where type Tag msg - encodeProtocol :: Version -> msg -> ByteString - protocolP :: Version -> Tag msg -> Parser msg + encodeProtocol :: Version v -> msg -> ByteString + protocolP :: Version v -> Tag msg -> Parser msg fromProtocolError :: ProtocolErrorType -> err checkCredentials :: SignedRawTransmission -> msg -> Either err msg -instance PartyI p => ProtocolEncoding ErrorType (Command p) where +instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where type Tag (Command p) = CommandTag p encodeProtocol v = \case NEW rKey dhKey auth_ subMode @@ -1114,7 +1135,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag) - fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} checkCredentials (auth, _, queueId, _) cmd = case cmd of @@ -1136,7 +1157,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where | isNothing auth || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance ProtocolEncoding ErrorType Cmd where +instance ProtocolEncoding SMPVersion ErrorType Cmd where type Tag Cmd = CmdTag encodeProtocol v (Cmd _ c) = encodeProtocol v c @@ -1164,12 +1185,12 @@ instance ProtocolEncoding ErrorType Cmd where PING_ -> pure PING CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB - fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c -instance ProtocolEncoding ErrorType BrokerMsg where +instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where type Tag BrokerMsg = BrokerMsgTag encodeProtocol _v = \case IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh) @@ -1221,12 +1242,12 @@ instance ProtocolEncoding ErrorType BrokerMsg where | otherwise -> Right cmd -- | Parse SMP protocol commands and broker messages -parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg +parseProtocol :: forall v err msg. ProtocolEncoding v err msg => Version v -> ByteString -> Either err msg parseProtocol v s = let (tag, params) = B.break (== ' ') s in case decodeTag tag of - Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params - Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown + Just cmd -> parse (protocolP v cmd) (fromProtocolError @v @err @msg $ PECmdSyntax) params + Nothing -> Left $ fromProtocolError @v @err @msg $ PECmdUnknown checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p) checkParty c = case testEquality (sParty @p) (sParty @p') of @@ -1281,7 +1302,7 @@ instance Encoding CommandError where _ -> fail "bad command error type" -- | Send signed SMP transmission to TCP transport. -tPut :: Transport c => THandle c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] +tPut :: Transport c => THandle v c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (batch params) (blockSize params) where tPutBatch :: TransportBatch () -> IO [Either TransportError ()] @@ -1290,7 +1311,7 @@ tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (ba TBTransmissions s n _ -> replicate n <$> tPutLog th s TBTransmission s _ -> (: []) <$> tPutLog th s -tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutLog :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ()) tPutLog th s = do r <- tPutBlock th s case r of @@ -1352,7 +1373,7 @@ tEncodeBatch1 t = lenEncode 1 `B.cons` tEncodeForBatch t -- tForAuth is lazy to avoid computing it when there is no key to sign data TransmissionForAuth = TransmissionForAuth {tForAuth :: ~ByteString, tToSend :: ByteString} -encodeTransmissionForAuth :: ProtocolEncoding e c => THandleParams -> Transmission c -> TransmissionForAuth +encodeTransmissionForAuth :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> TransmissionForAuth encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t = TransmissionForAuth {tForAuth, tToSend = if implySessId then t' else tForAuth} where @@ -1360,24 +1381,24 @@ encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t' = encodeTransmission_ v t {-# INLINE encodeTransmissionForAuth #-} -encodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> ByteString +encodeTransmission :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> ByteString encodeTransmission THandleParams {thVersion = v, sessionId, implySessId} t = if implySessId then t' else smpEncode sessionId <> t' where t' = encodeTransmission_ v t {-# INLINE encodeTransmission #-} -encodeTransmission_ :: ProtocolEncoding e c => Version -> Transmission c -> ByteString +encodeTransmission_ :: ProtocolEncoding v e c => Version v -> Transmission c -> ByteString encodeTransmission_ v (CorrId corrId, queueId, command) = smpEncode (corrId, queueId) <> encodeProtocol v command {-# INLINE encodeTransmission_ #-} -- | Receive and parse transmission from the TCP transport (ignoring any trailing padding). -tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission)) +tGetParse :: Transport c => THandle v c -> IO (NonEmpty (Either TransportError RawTransmission)) tGetParse th@THandle {params} = eitherList (tParse params) <$> tGetBlock th {-# INLINE tGetParse #-} -tParse :: THandleParams -> ByteString -> NonEmpty (Either TransportError RawTransmission) +tParse :: THandleParams v -> ByteString -> NonEmpty (Either TransportError RawTransmission) tParse thParams@THandleParams {batch} s | batch = eitherList (L.map (\(Large t) -> tParse1 t)) ts | otherwise = [tParse1 s] @@ -1389,24 +1410,24 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b eitherList = either (\e -> [Left e]) -- | Receive client and server transmissions (determined by `cmd` type). -tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd)) +tGet :: forall v err cmd c. (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (NonEmpty (SignedTransmission err cmd)) tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th -tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => THandleParams -> Either TransportError RawTransmission -> SignedTransmission err cmd +tDecodeParseValidate :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v -> Either TransportError RawTransmission -> SignedTransmission err cmd tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case Right RawTransmission {authenticator, authorized, sessId, corrId, entityId, command} | implySessId || sessId == sessionId -> let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission - | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession)) + | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PESession)) Left _ -> tError "" where tError :: ByteString -> SignedTransmission err cmd - tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock)) + tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PEBlock)) tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd tParseValidate signed t@(sig, corrId, entityId, command) = - let cmd = parseProtocol @err @cmd v command >>= checkCredentials t + let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t in (sig, signed, (CorrId corrId, entityId, cmd)) $(J.deriveJSON defaultJSON ''MsgFlags) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index aaa42d91b..0dcde0350 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -380,7 +380,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do CPQuit -> pure () CPSkip -> pure () -runClientTransport :: Transport c => THandle c -> M () +runClientTransport :: Transport c => THandleSMP c -> M () runClientTransport th@THandle {params = THandleParams {thVersion, sessionId}} = do q <- asks $ tbqSize . config ts <- liftIO getSystemTime @@ -428,7 +428,7 @@ cancelSub sub = Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread _ -> return () -receive :: Transport c => THandle c -> Client -> M () +receive :: Transport c => THandleSMP c -> Client -> M () receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive" forever $ do @@ -449,7 +449,7 @@ receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActi VRFailed -> Left (corrId, queueId, ERR AUTH) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty -send :: Transport c => THandle c -> Client -> IO () +send :: Transport c => THandleSMP c -> Client -> IO () send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send" forever $ do @@ -464,7 +464,7 @@ send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do NMSG {} -> 0 _ -> 1 -disconnectTransport :: Transport c => THandle c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () +disconnectTransport :: Transport c => THandle v c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport" loop diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 82666a0fc..7b9fef0b3 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -33,9 +33,8 @@ import Simplex.Messaging.Server.Stats import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (ATransport) +import Simplex.Messaging.Transport (ATransport, VersionSMP, VersionRangeSMP) import Simplex.Messaging.Transport.Server (SocketState, TransportServerConfig, loadFingerprint, loadTLSServerParams, newSocketState) -import Simplex.Messaging.Version import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -73,7 +72,7 @@ data ServerConfig = ServerConfig privateKeyFile :: FilePath, certificateFile :: FilePath, -- | SMP client-server protocol version range - smpServerVRange :: VersionRange, + smpServerVRange :: VersionRangeSMP, -- | TCP transport config transportConfig :: TransportServerConfig, -- | run listener on control port @@ -128,7 +127,7 @@ data Client = Client sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), endThreads :: TVar (IntMap (Weak ThreadId)), endThreadSeq :: TVar Int, - thVersion :: Version, + thVersion :: VersionSMP, sessionId :: ByteString, connected :: TVar Bool, createdAt :: SystemTime, @@ -152,7 +151,7 @@ newServer = do savingLock <- createLock return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock} -newClient :: TVar Int -> Natural -> Version -> ByteString -> SystemTime -> STM Client +newClient :: TVar Int -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client newClient nextClientId qSize thVersion sessionId createdAt = do clientId <- stateTVar nextClientId $ \next -> (next, next + 1) subscriptions <- TM.empty diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 0d9552f9b..775400260 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -9,6 +9,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} @@ -27,10 +28,15 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a module Simplex.Messaging.Transport ( -- * SMP transport parameters + SMPVersion, + VersionSMP, + VersionRangeSMP, + THandleSMP, supportedClientSMPRelayVRange, supportedServerSMPRelayVRange, currentClientSMPRelayVersion, currentServerSMPRelayVersion, + batchCmdsSMPVersion, basicAuthSMPVersion, subModeSMPVersion, authCmdsSMPVersion, @@ -85,6 +91,7 @@ import qualified Data.ByteString.Lazy.Char8 as LB import Data.Default (def) import Data.Functor (($>)) import Data.Version (showVersion) +import Data.Word (Word16) import qualified Data.X509 as X import qualified Data.X509.Validation as XV import GHC.IO.Handle.Internals (ioe_EOF) @@ -98,6 +105,7 @@ import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON) import Simplex.Messaging.Transport.Buffer import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -116,30 +124,41 @@ smpBlockSize = 16384 -- 6 - allow creating queues without subscribing (9/10/2023) -- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (2/3/2024) -batchCmdsSMPVersion :: Version -batchCmdsSMPVersion = 4 +data SMPVersion -basicAuthSMPVersion :: Version -basicAuthSMPVersion = 5 +instance VersionScope SMPVersion -subModeSMPVersion :: Version -subModeSMPVersion = 6 +type VersionSMP = Version SMPVersion -authCmdsSMPVersion :: Version -authCmdsSMPVersion = 7 +type VersionRangeSMP = VersionRange SMPVersion -currentClientSMPRelayVersion :: Version -currentClientSMPRelayVersion = 6 +pattern VersionSMP :: Word16 -> VersionSMP +pattern VersionSMP v = Version v -currentServerSMPRelayVersion :: Version -currentServerSMPRelayVersion = 6 +batchCmdsSMPVersion :: VersionSMP +batchCmdsSMPVersion = VersionSMP 4 + +basicAuthSMPVersion :: VersionSMP +basicAuthSMPVersion = VersionSMP 5 + +subModeSMPVersion :: VersionSMP +subModeSMPVersion = VersionSMP 6 + +authCmdsSMPVersion :: VersionSMP +authCmdsSMPVersion = VersionSMP 7 + +currentClientSMPRelayVersion :: VersionSMP +currentClientSMPRelayVersion = VersionSMP 6 + +currentServerSMPRelayVersion :: VersionSMP +currentServerSMPRelayVersion = VersionSMP 6 -- minimal supported protocol version is 4 -- TODO remove code that supports sending commands without batching -supportedClientSMPRelayVRange :: VersionRange +supportedClientSMPRelayVRange :: VersionRangeSMP supportedClientSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentClientSMPRelayVersion -supportedServerSMPRelayVRange :: VersionRange +supportedServerSMPRelayVRange :: VersionRangeSMP supportedServerSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentServerSMPRelayVersion simplexMQVersion :: String @@ -287,16 +306,18 @@ instance Transport TLS where -- * SMP transport -- | The handle for SMP encrypted transport connection over Transport. -data THandle c = THandle +data THandle v c = THandle { connection :: c, - params :: THandleParams + params :: THandleParams v } -data THandleParams = THandleParams +type THandleSMP c = THandle SMPVersion c + +data THandleParams v = THandleParams { sessionId :: SessionId, blockSize :: Int, -- | agreed server protocol version - thVersion :: Version, + thVersion :: Version v, -- | peer public key for command authorization and shared secrets for entity ID encryption thAuth :: Maybe THandleAuth, -- | do NOT send session ID in transmission, but include it into signed message @@ -316,7 +337,7 @@ data THandleAuth = THandleAuth type SessionId = ByteString data ServerHandshake = ServerHandshake - { smpVersionRange :: VersionRange, + { smpVersionRange :: VersionRangeSMP, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey) @@ -324,7 +345,7 @@ data ServerHandshake = ServerHandshake data ClientHandshake = ClientHandshake { -- | agreed SMP server protocol version - smpVersion :: Version, + smpVersion :: VersionSMP, -- | server identity - CA certificate fingerprint keyHash :: C.KeyHash, -- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. @@ -358,12 +379,12 @@ instance Encoding ServerHandshake where C.SignedObject key <- smpP pure (cert, key) -encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds :: Encoding a => VersionSMP -> Maybe a -> ByteString encodeAuthEncryptCmds v k | v >= authCmdsSMPVersion = maybe "" smpEncode k | otherwise = "" -authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP :: VersionSMP -> Parser a -> Parser (Maybe a) authEncryptCmdsP v p = if v >= authCmdsSMPVersion then Just <$> p else pure Nothing -- | Error of SMP encrypted transport over TCP. @@ -412,13 +433,13 @@ serializeTransportError = \case TEHandshake e -> "HANDSHAKE " <> bshow e -- | Pad and send block to SMP transport. -tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutBlock :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ()) tPutBlock THandle {connection = c, params = THandleParams {blockSize}} block = bimapM (const $ pure TELargeMsg) (cPut c) $ C.pad block blockSize -- | Receive block from SMP transport. -tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString) +tGetBlock :: Transport c => THandle v c -> IO (Either TransportError ByteString) tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do msg <- cGet c blockSize if B.length msg == blockSize @@ -428,7 +449,7 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do -- | Server SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c) smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c sk = C.signX509 serverSignKey $ C.publicToX509 k @@ -445,7 +466,7 @@ smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do -- | Client SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c) smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th @@ -465,24 +486,24 @@ smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do pure $ smpThHandle th v pk sk_ Nothing -> throwE $ TEHandshake VERSION -smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c +smpThHandle :: forall c. THandleSMP c -> VersionSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c smpThHandle th@THandle {params} v privKey k_ = -- TODO drop SMP v6: make thAuth non-optional let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_ params' = params {thVersion = v, thAuth, implySessId = v >= authCmdsSMPVersion} - in (th :: THandle c) {params = params'} + in (th :: THandleSMP c) {params = params'} -sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO () +sendHandshake :: (Transport c, Encoding smp) => THandle v c -> smp -> ExceptT TransportError IO () sendHandshake th = ExceptT . tPutBlock th . smpEncode -- ignores tail bytes to allow future extensions -getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp +getHandshake :: (Transport c, Encoding smp) => THandle v c -> ExceptT TransportError IO smp getHandshake th = ExceptT $ (first (\_ -> TEHandshake PARSE) . A.parseOnly smpP =<<) <$> tGetBlock th -smpTHandle :: Transport c => c -> THandle c +smpTHandle :: Transport c => c -> THandleSMP c smpTHandle c = THandle {connection = c, params} where - params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = True} + params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = VersionSMP 0, thAuth = Nothing, implySessId = False, batch = True} $(J.deriveJSON (sumTypeJSON id) ''HandshakeError) diff --git a/src/Simplex/Messaging/Version.hs b/src/Simplex/Messaging/Version.hs index dc8cfff68..78d290687 100644 --- a/src/Simplex/Messaging/Version.hs +++ b/src/Simplex/Messaging/Version.hs @@ -1,13 +1,15 @@ {-# LANGUAGE ConstrainedClassMethods #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeSynonymInstances #-} module Simplex.Messaging.Version ( Version, VersionRange (minVersion, maxVersion), + VersionScope, pattern VersionRange, VersionI (..), VersionRangeI (..), @@ -24,47 +26,45 @@ module Simplex.Messaging.Version where import Control.Applicative (optional) -import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Word (Word16) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Util ((<$?>)) +import Simplex.Messaging.Version.Internal (Version (..)) -pattern VersionRange :: Word16 -> Word16 -> VersionRange +pattern VersionRange :: Version v -> Version v -> VersionRange v pattern VersionRange v1 v2 <- VRange v1 v2 {-# COMPLETE VersionRange #-} -type Version = Word16 - -data VersionRange = VRange - { minVersion :: Version, - maxVersion :: Version +data VersionRange v = VRange + { minVersion :: Version v, + maxVersion :: Version v } deriving (Eq, Show) +class VersionScope v + -- | construct valid version range, to be used in constants -mkVersionRange :: Version -> Version -> VersionRange +mkVersionRange :: Version v -> Version v -> VersionRange v mkVersionRange v1 v2 | v1 <= v2 = VRange v1 v2 | otherwise = error "invalid version range" -safeVersionRange :: Version -> Version -> Maybe VersionRange +safeVersionRange :: Version v -> Version v -> Maybe (VersionRange v) safeVersionRange v1 v2 | v1 <= v2 = Just $ VRange v1 v2 | otherwise = Nothing -versionToRange :: Version -> VersionRange +versionToRange :: Version v -> VersionRange v versionToRange v = VRange v v -instance Encoding VersionRange where +instance VersionScope v => Encoding (VersionRange v) where smpEncode (VRange v1 v2) = smpEncode (v1, v2) smpP = maybe (fail "invalid version range") pure =<< safeVersionRange <$> smpP <*> smpP -instance StrEncoding VersionRange where +instance VersionScope v => StrEncoding (VersionRange v) where strEncode (VRange v1 v2) | v1 == v2 = strEncode v1 | otherwise = strEncode v1 <> "-" <> strEncode v2 @@ -73,32 +73,23 @@ instance StrEncoding VersionRange where v2 <- maybe (pure v1) (const strP) =<< optional (A.char '-') maybe (fail "invalid version range") pure $ safeVersionRange v1 v2 -instance ToJSON VersionRange where - toJSON (VRange v1 v2) = toJSON (v1, v2) - toEncoding (VRange v1 v2) = toEncoding (v1, v2) +class VersionScope v => VersionI v a | a -> v where + type VersionRangeT v a + version :: a -> Version v + toVersionRangeT :: a -> VersionRange v -> VersionRangeT v a -instance FromJSON VersionRange where - parseJSON v = - (\(v1, v2) -> maybe (Left "bad VersionRange") Right $ safeVersionRange v1 v2) - <$?> parseJSON v +class VersionScope v => VersionRangeI v a | a -> v where + type VersionT v a + versionRange :: a -> VersionRange v + toVersionT :: a -> Version v -> VersionT v a -class VersionI a where - type VersionRangeT a - version :: a -> Version - toVersionRangeT :: a -> VersionRange -> VersionRangeT a - -class VersionRangeI a where - type VersionT a - versionRange :: a -> VersionRange - toVersionT :: a -> Version -> VersionT a - -instance VersionI Version where - type VersionRangeT Version = VersionRange +instance VersionScope v => VersionI v (Version v) where + type VersionRangeT v (Version v) = VersionRange v version = id toVersionRangeT _ vr = vr -instance VersionRangeI VersionRange where - type VersionT VersionRange = Version +instance VersionScope v => VersionRangeI v (VersionRange v) where + type VersionT v (VersionRange v) = Version v versionRange = id toVersionT _ v = v @@ -109,18 +100,18 @@ pattern Compatible a <- Compatible_ a {-# COMPLETE Compatible #-} -isCompatible :: VersionI a => a -> VersionRange -> Bool +isCompatible :: VersionI v a => a -> VersionRange v -> Bool isCompatible x (VRange v1 v2) = let v = version x in v1 <= v && v <= v2 -isCompatibleRange :: VersionRangeI a => a -> VersionRange -> Bool +isCompatibleRange :: VersionRangeI v a => a -> VersionRange v -> Bool isCompatibleRange x (VRange min2 max2) = min1 <= max2 && min2 <= max1 where VRange min1 max1 = versionRange x -proveCompatible :: VersionI a => a -> VersionRange -> Maybe (Compatible a) +proveCompatible :: VersionI v a => a -> VersionRange v -> Maybe (Compatible a) proveCompatible x vr = x `mkCompatibleIf` (x `isCompatible` vr) -compatibleVersion :: VersionRangeI a => a -> VersionRange -> Maybe (Compatible (VersionT a)) +compatibleVersion :: VersionRangeI v a => a -> VersionRange v -> Maybe (Compatible (VersionT v a)) compatibleVersion x vr = toVersionT x (min max1 max2) `mkCompatibleIf` isCompatibleRange x vr where diff --git a/src/Simplex/Messaging/Version/Internal.hs b/src/Simplex/Messaging/Version/Internal.hs new file mode 100644 index 000000000..23cab1d1f --- /dev/null +++ b/src/Simplex/Messaging/Version/Internal.hs @@ -0,0 +1,25 @@ +module Simplex.Messaging.Version.Internal where + +import Data.Aeson (FromJSON (..), ToJSON (..)) +import Data.Word (Word16) +import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String + +-- Do not use constructor of this type directry +newtype Version v = Version Word16 + deriving (Eq, Ord, Show) + +instance Encoding (Version v) where + smpEncode (Version v) = smpEncode v + smpP = Version <$> smpP + +instance StrEncoding (Version v) where + strEncode (Version v) = strEncode v + strP = Version <$> strP + +instance ToJSON (Version v) where + toEncoding (Version v) = toEncoding v + toJSON (Version v) = toJSON v + +instance FromJSON (Version v) where + parseJSON v = Version <$> parseJSON v diff --git a/src/Simplex/RemoteControl/Client.hs b/src/Simplex/RemoteControl/Client.hs index c73679439..3cf1050fa 100644 --- a/src/Simplex/RemoteControl/Client.hs +++ b/src/Simplex/RemoteControl/Client.hs @@ -68,12 +68,6 @@ import Simplex.RemoteControl.Types import UnliftIO import UnliftIO.Concurrent -currentRCVersion :: Version -currentRCVersion = 1 - -supportedRCVRange :: VersionRange -supportedRCVRange = mkVersionRange 1 currentRCVersion - xrcpBlockSize :: Int xrcpBlockSize = 16384 @@ -181,7 +175,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct { ca = certFingerprint caCert, host, port = fromIntegral portNum, - v = supportedRCVRange, + v = supportedRCPVRange, app = ctrlAppInfo, ts, skey = fst sessKeys, @@ -220,7 +214,7 @@ prepareHostSession unless (ca == tlsHostFingerprint) $ throwError RCEIdentity (kemCiphertext, kemSharedKey) <- liftIO $ sntrup761Enc drg kemPubKey let hybridKey = kemHybridSecret dhPubKey dhPrivKey kemSharedKey - unless (isCompatible v supportedRCVRange) $ throwError RCEVersion + unless (isCompatible v supportedRCPVRange) $ throwError RCEVersion let keys = HostSessKeys {hybridKey, idPrivKey, sessPrivKey} knownHost' <- updateKnownHost ca dhPubKey let ctrlHello = RCCtrlHello {} @@ -334,7 +328,7 @@ prepareHostHello RCInvitation {v, dh = dhPubKey} hostAppInfo = do logDebug "Preparing session" - case compatibleVersion v supportedRCVRange of + case compatibleVersion v supportedRCPVRange of Nothing -> throwError RCEVersion Just (Compatible v') -> do nonce <- liftIO . atomically $ C.randomCbNonce drg diff --git a/src/Simplex/RemoteControl/Invitation.hs b/src/Simplex/RemoteControl/Invitation.hs index f5deac9a8..712c41a9d 100644 --- a/src/Simplex/RemoteControl/Invitation.hs +++ b/src/Simplex/RemoteControl/Invitation.hs @@ -27,7 +27,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Version (VersionRange) +import Simplex.RemoteControl.Types (VersionRangeRCP) data RCInvitation = RCInvitation { -- | CA TLS certificate fingerprint of the controller. @@ -37,7 +37,7 @@ data RCInvitation = RCInvitation host :: TransportHost, port :: Word16, -- | Supported version range for remote control protocol - v :: VersionRange, + v :: VersionRangeRCP, -- | Application information app :: J.Value, -- | Session start time in seconds since epoch diff --git a/src/Simplex/RemoteControl/Types.hs b/src/Simplex/RemoteControl/Types.hs index e1598f25c..b8a7c1141 100644 --- a/src/Simplex/RemoteControl/Types.hs +++ b/src/Simplex/RemoteControl/Types.hs @@ -5,6 +5,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} @@ -17,6 +18,7 @@ import Data.ByteString (ByteString) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) +import Data.Word (Word16) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.SNTRUP761 import Simplex.Messaging.Crypto.SNTRUP761.Bindings @@ -26,7 +28,8 @@ import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport (TLS) import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util (safeDecodeUtf8) -import Simplex.Messaging.Version (Version, VersionRange, mkVersionRange) +import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange) +import Simplex.Messaging.Version.Internal import UnliftIO data RCErrorType @@ -92,24 +95,37 @@ instance StrEncoding RCErrorType where -- * Discovery -ipProbeVersionRange :: VersionRange -ipProbeVersionRange = mkVersionRange 1 1 +data RCPVersion + +instance VersionScope RCPVersion + +type VersionRCP = Version RCPVersion + +type VersionRangeRCP = VersionRange RCPVersion + +pattern VersionRCP :: Word16 -> VersionRCP +pattern VersionRCP v = Version v + +currentRCPVersion :: VersionRCP +currentRCPVersion = VersionRCP 1 + +supportedRCPVRange :: VersionRangeRCP +supportedRCPVRange = mkVersionRange (VersionRCP 1) currentRCPVersion data IpProbe = IpProbe - { versionRange :: VersionRange, + { versionRange :: VersionRangeRCP, randomNonce :: ByteString } deriving (Show) instance Encoding IpProbe where smpEncode IpProbe {versionRange, randomNonce} = smpEncode (versionRange, 'I', randomNonce) - smpP = IpProbe <$> (smpP <* "I") *> smpP -- * Session data RCHostHello = RCHostHello - { v :: Version, + { v :: VersionRCP, ca :: C.KeyHash, app :: J.Value, kem :: KEMPublicKey diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index eae87651e..758c2d66d 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -4,6 +4,7 @@ {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE PatternSynonyms #-} {-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -16,7 +17,7 @@ import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (ProtocolServer (..), supportedSMPClientVRange) +import Simplex.Messaging.Protocol (ProtocolServer (..), pattern VersionSMPC, supportedSMPClientVRange) import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.Version import Test.Hspec @@ -53,7 +54,7 @@ queue :: SMPQueueUri queue = SMPQueueUri supportedSMPClientVRange queueAddr queueV1 :: SMPQueueUri -queueV1 = SMPQueueUri (mkVersionRange 1 1) queueAddr +queueV1 = SMPQueueUri (mkVersionRange (VersionSMPC 1) (VersionSMPC 1)) queueAddr testDhKey :: C.PublicKeyX25519 testDhKey = "MCowBQYDK2VuAyEAjiswwI3O/NlS8Fk3HJUW870EY2bAwmttMBsvRB9eV3o=" @@ -68,7 +69,7 @@ connReqData :: ConnReqUriData connReqData = ConnReqUriData { crScheme = SSSimplex, - crAgentVRange = mkVersionRange 2 2, + crAgentVRange = mkVersionRange (VersionSMPA 2) (VersionSMPA 2), crSmpQueues = [queueV1], crClientData = Nothing } @@ -77,7 +78,7 @@ testDhPubKey :: C.PublicKeyX448 testDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U=" testE2ERatchetParams :: RcvE2ERatchetParamsUri 'C.X448 -testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange 1 1) testDhPubKey testDhPubKey Nothing +testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange (VersionE2E 1) (VersionE2E 1)) testDhPubKey testDhPubKey Nothing testE2ERatchetParams12 :: RcvE2ERatchetParamsUri 'C.X448 testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey Nothing @@ -113,7 +114,7 @@ connectionRequestTests = it "should serialize SMP queue URIs" $ do strEncode (queue :: SMPQueueUri) {queueAddress = queueAddrNoPort} `shouldBe` "smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri - strEncode queue {clientVRange = mkVersionRange 1 2} + strEncode queue {clientVRange = mkVersionRange (VersionSMPC 1) (VersionSMPC 2)} `shouldBe` "smp://1234-w==@smp.simplex.im:5223/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri it "should parse SMP queue URIs" $ do strDecode ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStr) diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index a0d2deb5f..87c9801ee 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -18,7 +18,7 @@ import Control.Monad (when) import Control.Monad.Except import Control.Monad.IO.Class import Crypto.Random (ChaChaDRG) -import Data.Aeson (FromJSON, ToJSON) +import Data.Aeson (FromJSON, ToJSON, (.=)) import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -44,6 +44,7 @@ doubleRatchetTests = do it "should encode/decode ratchet as JSON" $ do testAlgs testKeyJSON testAlgs testRatchetJSON + testVersionJSON it "should agree the same ratchet parameters" $ testAlgs testX3dh it "should agree the same ratchet parameters with version 1" $ testAlgs testX3dhV1 describe "post-quantum hybrid KEM double-ratchet algorithm" $ do @@ -87,12 +88,12 @@ testAlgs test = test C.SX25519 >> test C.SX448 paddedMsgLen :: Int paddedMsgLen = 100 -fullMsgLen :: Version -> Int +fullMsgLen :: VersionE2E -> Int fullMsgLen v = headerLenLength + fullHeaderLen + C.authTagSize + paddedMsgLen where headerLenLength = if v < pqRatchetVersion then 1 else 3 -- two bytes are added because of two Large used in new encoding -testMessageHeader :: forall a. AlgorithmI a => Version -> C.SAlgorithm a -> Expectation +testMessageHeader :: forall a. AlgorithmI a => VersionE2E -> C.SAlgorithm a -> Expectation testMessageHeader v _ = do (k, _) <- atomically . C.generateKeyPair @a =<< C.newRandom let hdr = MsgHeader {msgMaxVersion = v, msgDHRs = k, msgKEM = Nothing, msgPN = 0, msgNs = 0} @@ -335,6 +336,20 @@ testRatchetJSON _ = do testEncodeDecode alice testEncodeDecode bob +testVersionJSON :: IO () +testVersionJSON = do + testEncodeDecode $ rv 1 1 + testEncodeDecode $ rv 1 2 + -- let bad = RVersions 2 1 + -- Left err <- pure $ J.eitherDecode' @RatchetVersions (J.encode bad) + -- err `shouldContain` "bad version range" + testDecodeRV $ (1 :: Int, 2 :: Int) + testDecodeRV $ J.object ["current" .= (1 :: Int), "maxSupported" .= (2 :: Int)] + where + rv v1 v2 = ratchetVersions $ mkVersionRange (VersionE2E v1) (VersionE2E v2) + testDecodeRV :: ToJSON a => a -> Expectation + testDecodeRV a = J.eitherDecode' (J.encode a) `shouldBe` Right (rv 1 2) + testEncodeDecode :: (Eq a, Show a, ToJSON a, FromJSON a) => a -> Expectation testEncodeDecode x = do let j = J.encode x @@ -354,8 +369,8 @@ testX3dh _ = do testX3dhV1 :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dhV1 _ = do g <- C.newRandom - (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g 1 Nothing - (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g 1 PQEncOff + (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g (VersionE2E 1) Nothing + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g (VersionE2E 1) PQEncOff let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob paramsAlice `shouldBe` paramsBob @@ -544,7 +559,7 @@ encrypt_ enableKem (_, rc, _) msg = >>= either (pure . Left) checkLength where checkLength (msg', rc') = do - B.length msg' `shouldBe` fullMsgLen (maxVersion $ rcVersion rc) + B.length msg' `shouldBe` fullMsgLen (maxSupported $ rcVersion rc) pure $ Right (msg', rc', SMDNoChange) decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff)) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index a9a261711..11adc2d0a 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -58,6 +58,7 @@ import qualified Data.Set as S import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Type.Equality +import Data.Word (Word16) import qualified Database.SQLite.Simple as SQL import SMPAgentClient import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, withSmpServerV7) @@ -74,13 +75,15 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOn, pattern PQEncOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (authBatchCmdsNTFVersion) +import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF, authBatchCmdsNTFVersion) import Simplex.Messaging.Protocol (AProtocolType (..), BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), supportedSMPClientVRange) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Transport (ATransport (..), authCmdsSMPVersion, basicAuthSMPVersion, currentServerSMPRelayVersion) -import Simplex.Messaging.Version +import Simplex.Messaging.Transport (ATransport (..), SMPVersion, VersionSMP, authCmdsSMPVersion, batchCmdsSMPVersion, basicAuthSMPVersion, currentServerSMPRelayVersion) +import Simplex.Messaging.Version (VersionRange (..)) +import qualified Simplex.Messaging.Version as V +import Simplex.Messaging.Version.Internal (Version (..)) import System.Directory (copyFile, renameFile) import Test.Hspec import UnliftIO @@ -156,14 +159,14 @@ pattern MsgErr msgId err msgBody <- MSG MsgMeta {recipient = (msgId, _), integri pattern Rcvd :: AgentMsgId -> ACommand 'Agent 'AEConn pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] -smpCfgVPrev :: ProtocolClientConfig +smpCfgVPrev :: ProtocolClientConfig SMPVersion smpCfgVPrev = (smpCfg agentCfg) {serverVRange = prevRange $ serverVRange $ smpCfg agentCfg} -smpCfgV7 :: ProtocolClientConfig -smpCfgV7 = (smpCfg agentCfg) {serverVRange = mkVersionRange 4 authCmdsSMPVersion} +smpCfgV7 :: ProtocolClientConfig SMPVersion +smpCfgV7 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion} -ntfCfgV2 :: ProtocolClientConfig -ntfCfgV2 = (smpCfg agentCfg) {serverVRange = mkVersionRange 1 authBatchCmdsNTFVersion} +ntfCfgV2 :: ProtocolClientConfig NTFVersion +ntfCfgV2 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange (VersionNTF 1) authBatchCmdsNTFVersion} agentCfgVPrev :: AgentConfig agentCfgVPrev = @@ -186,8 +189,14 @@ agentCfgV7 = agentCfgRatchetVPrev :: AgentConfig agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = prevRange $ e2eEncryptVRange agentCfg} -prevRange :: VersionRange -> VersionRange -prevRange vr = vr {maxVersion = max (minVersion vr) (maxVersion vr - 1)} +prevRange :: VersionRange v -> VersionRange v +prevRange vr = vr {maxVersion = max (minVersion vr) (prevVersion $ maxVersion vr)} + +prevVersion :: Version v -> Version v +prevVersion (Version v) = Version (v - 1) + +mkVersionRange :: Word16 -> Word16 -> VersionRange v +mkVersionRange v1 v2 = V.mkVersionRange (Version v1) (Version v2) runRight_ :: (Eq e, Show e, HasCallStack) => ExceptT e IO () -> Expectation runRight_ action = runExceptT action `shouldReturn` Right () @@ -352,8 +361,8 @@ functionalAPITests t = do describe "should switch two connections simultaneously, abort one" $ testServerMatrix2 t testSwitch2ConnectionsAbort1 describe "SMP basic auth" $ do - let v4 = basicAuthSMPVersion - 1 - forM_ (nub [authCmdsSMPVersion - 1, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do + let v4 = prevVersion basicAuthSMPVersion + forM_ (nub [prevVersion authCmdsSMPVersion, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do describe ("v" <> show v <> ": with server auth") $ do -- allow NEW | server auth, v | clnt1 auth, v | clnt2 auth, v | 2 - success, 1 - JOIN fail, 0 - NEW fail it "success " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) `shouldReturn` 2 @@ -397,9 +406,9 @@ functionalAPITests t = do it "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t it "send delivery receipts concurrently with messages" $ testDeliveryReceiptsConcurrent t -testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int +testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> IO Int testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do - let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = mkVersionRange 4 srvVersion} + let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = V.mkVersionRange batchCmdsSMPVersion srvVersion} canCreate1 = canCreateQueue allowNewQueues srv clnt1 canCreate2 = canCreateQueue allowNewQueues srv clnt2 expected @@ -410,7 +419,7 @@ testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do created `shouldBe` expected pure created -canCreateQueue :: Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> Bool +canCreateQueue :: Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> Bool canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = let v = basicAuthSMPVersion in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) @@ -733,10 +742,10 @@ testIncreaseConnAgentVersion t = do disconnectAgentClient alice3 disconnectAgentClient bob3 -checkVersion :: AgentClient -> ConnId -> Version -> ExceptT AgentErrorType IO () +checkVersion :: AgentClient -> ConnId -> Word16 -> ExceptT AgentErrorType IO () checkVersion c connId v = do ConnectionStats {connAgentVersion} <- getConnectionServers c connId - liftIO $ connAgentVersion `shouldBe` v + liftIO $ connAgentVersion `shouldBe` VersionSMPA v testIncreaseConnAgentVersionMaxCompatible :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersionMaxCompatible t = do @@ -2276,7 +2285,7 @@ testSwitch2ConnectionsAbort1 servers = do withB :: (AgentClient -> IO a) -> IO a withB = withAgent 2 agentCfg servers testDB2 -testCreateQueueAuth :: HasCallStack => Version -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int +testCreateQueueAuth :: HasCallStack => VersionSMP -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> IO Int testCreateQueueAuth srvVersion clnt1 clnt2 = do a <- getClient 1 clnt1 testDB b <- getClient 2 clnt2 testDB2 @@ -2302,7 +2311,7 @@ testCreateQueueAuth srvVersion clnt1 clnt2 = do where getClient clientId (clntAuth, clntVersion) db = let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]} - smpCfg = (defaultSMPClientConfig :: ProtocolClientConfig) {serverVRange = mkVersionRange (basicAuthSMPVersion - 1) clntVersion} + smpCfg = (defaultSMPClientConfig :: ProtocolClientConfig SMPVersion) {serverVRange = V.mkVersionRange (prevVersion basicAuthSMPVersion) clntVersion} sndAuthAlg = if srvVersion >= authCmdsSMPVersion && clntVersion >= authCmdsSMPVersion then C.AuthAlg C.SX25519 else C.AuthAlg C.SEd25519 in getSMPAgentClient' clientId agentCfg {smpCfg, sndAuthAlg} servers db diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 9665e3833..af91dac42 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -49,7 +49,7 @@ import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQEncOn) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Encoding.String (StrEncoding (..)) -import Simplex.Messaging.Protocol (SubscriptionMode (..)) +import Simplex.Messaging.Protocol (SubscriptionMode (..), pattern VersionSMPC) import qualified Simplex.Messaging.Protocol as SMP import System.Random import Test.Hspec @@ -185,7 +185,7 @@ cData1 = ConnData { userId = 1, connId = "conn1", - connAgentVersion = 1, + connAgentVersion = VersionSMPA 1, enableNtfs = True, lastExternalSndId = 0, deleted = False, @@ -222,7 +222,7 @@ rcvQueue1 = primary = True, dbReplaceQueueId = Nothing, rcvSwchStatus = Nothing, - smpClientVersion = 1, + smpClientVersion = VersionSMPC 1, clientNtfCreds = Nothing, deleteErrors = 0 } @@ -243,7 +243,7 @@ sndQueue1 = primary = True, dbReplaceQueueId = Nothing, sndSwchStatus = Nothing, - smpClientVersion = 1 + smpClientVersion = VersionSMPC 1 } createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue)) @@ -387,7 +387,7 @@ testUpgradeRcvConnToDuplex = sndSwchStatus = Nothing, primary = True, dbReplaceQueueId = Nothing, - smpClientVersion = 1 + smpClientVersion = VersionSMPC 1 } upgradeRcvConnToDuplex db "conn1" anotherSndQueue `shouldReturn` Left (SEBadConnType CSnd) @@ -416,7 +416,7 @@ testUpgradeSndConnToDuplex = rcvSwchStatus = Nothing, primary = True, dbReplaceQueueId = Nothing, - smpClientVersion = 1, + smpClientVersion = VersionSMPC 1, clientNtfCreds = Nothing, deleteErrors = 0 } diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index eb9b62d3d..996d2fed1 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -15,7 +15,6 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport -import Simplex.Messaging.Version (Version) import Test.Hspec batchingTests :: Spec @@ -253,27 +252,27 @@ testClientBatchWithLargeMessageV7 = do (length rs1', length rs2') `shouldBe` (74, 136) all lenOk [s1', s2'] `shouldBe` True -testClientStub :: IO (ProtocolClient ErrorType BrokerMsg) +testClientStub :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) testClientStub = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g - atomically $ clientStub g sessId (authCmdsSMPVersion - 1) Nothing + atomically $ smpClientStub g sessId subModeSMPVersion Nothing -clientStubV7 :: IO (ProtocolClient ErrorType BrokerMsg) +clientStubV7 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) clientStubV7 = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g (rKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g thAuth_ <- testTHandleAuth authCmdsSMPVersion g rKey - atomically $ clientStub g sessId authCmdsSMPVersion thAuth_ + atomically $ smpClientStub g sessId authCmdsSMPVersion thAuth_ randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) -randomSUB = randomSUB_ C.SEd25519 (authCmdsSMPVersion - 1) +randomSUB = randomSUB_ C.SEd25519 subModeSMPVersion randomSUBv7 :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUBv7 = randomSUB_ C.SEd25519 authCmdsSMPVersion -randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUB_ a v sessId = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g @@ -284,13 +283,13 @@ randomSUB_ a v sessId = do TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, rId, Cmd SRecipient SUB) pure $ (,tToSend) <$> authTransmission thAuth_ (Just rpKey) corrId tForAuth -randomSUBCmd :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmd = randomSUBCmd_ C.SEd25519 -randomSUBCmdV7 :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmdV7 = randomSUBCmd_ C.SEd25519 -- same as v6 -randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmd_ a c = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g @@ -298,12 +297,12 @@ randomSUBCmd_ a c = do mkTransmission c (Just rpKey, rId, Cmd SRecipient SUB) randomSEND :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) -randomSEND = randomSEND_ C.SEd25519 (authCmdsSMPVersion - 1) +randomSEND = randomSEND_ C.SEd25519 subModeSMPVersion randomSENDv7 :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSENDv7 = randomSEND_ C.SX25519 authCmdsSMPVersion -randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSEND_ a v sessId len = do g <- C.newRandom sId <- atomically $ C.randomBytes 24 g @@ -315,7 +314,7 @@ randomSEND_ a v sessId len = do TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, sId, Cmd SSender $ SEND noMsgFlags msg) pure $ (,tToSend) <$> authTransmission thAuth_ (Just spKey) corrId tForAuth -testTHandleParams :: Version -> ByteString -> THandleParams +testTHandleParams :: VersionSMP -> ByteString -> THandleParams SMPVersion testTHandleParams v sessionId = THandleParams { sessionId, @@ -326,20 +325,20 @@ testTHandleParams v sessionId = batch = True } -testTHandleAuth :: Version -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth) +testTHandleAuth :: VersionSMP -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth) testTHandleAuth v g (C.APublicAuthKey a k) = case a of C.SX25519 | v >= authCmdsSMPVersion -> do (_, privKey) <- atomically $ C.generateKeyPair g pure $ Just THandleAuth {peerPubKey = k, privKey} _ -> pure Nothing -randomSENDCmd :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmd = randomSENDCmd_ C.SEd25519 -randomSENDCmdV7 :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmdV7 = randomSENDCmd_ C.SX25519 -randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmd_ a c len = do g <- C.newRandom sId <- atomically $ C.randomBytes 24 g diff --git a/tests/CoreTests/ProtocolErrorTests.hs b/tests/CoreTests/ProtocolErrorTests.hs index 6dc6f2c02..7b1a7b813 100644 --- a/tests/CoreTests/ProtocolErrorTests.hs +++ b/tests/CoreTests/ProtocolErrorTests.hs @@ -12,7 +12,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import GHC.Generics (Generic) import Generic.Random (genericArbitraryU) -import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType (..)) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 64181a179..91722228b 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -1,5 +1,6 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TypeApplications #-} module CoreTests.TRcvQueuesTests where @@ -12,7 +13,7 @@ import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId) import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..)) import qualified Simplex.Messaging.Agent.TRcvQueues as RQ import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Protocol (SMPServer) +import Simplex.Messaging.Protocol (SMPServer, pattern VersionSMPC) import Test.Hspec import UnliftIO @@ -137,7 +138,7 @@ dummyRQ userId server connId = primary = True, dbReplaceQueueId = Nothing, rcvSwchStatus = Nothing, - smpClientVersion = 123, + smpClientVersion = VersionSMPC 123, clientNtfCreds = Nothing, deleteErrors = 0 } diff --git a/tests/CoreTests/VersionRangeTests.hs b/tests/CoreTests/VersionRangeTests.hs index be02e38b7..cef556376 100644 --- a/tests/CoreTests/VersionRangeTests.hs +++ b/tests/CoreTests/VersionRangeTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} module CoreTests.VersionRangeTests where @@ -8,6 +9,7 @@ module CoreTests.VersionRangeTests where import GHC.Generics (Generic) import Generic.Random (genericArbitraryU) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Test.Hspec import Test.Hspec.QuickCheck (modifyMaxSuccess) import Test.QuickCheck @@ -16,6 +18,10 @@ data V = V1 | V2 | V3 | V4 | V5 deriving (Eq, Enum, Ord, Generic, Show) instance Arbitrary V where arbitrary = genericArbitraryU +data T + +instance VersionScope T + versionRangeTests :: Spec versionRangeTests = modifyMaxSuccess (const 1000) $ do describe "VersionRange construction" $ do @@ -25,31 +31,31 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do (pure $! vr 2 1) `shouldThrow` anyErrorCall describe "compatible version" $ do it "should choose mutually compatible max version" $ do - (vr 1 1, vr 1 1) `compatible` Just 1 - (vr 1 1, vr 1 2) `compatible` Just 1 - (vr 1 2, vr 1 2) `compatible` Just 2 - (vr 1 2, vr 2 3) `compatible` Just 2 - (vr 1 3, vr 2 3) `compatible` Just 3 - (vr 1 3, vr 2 4) `compatible` Just 3 + (vr 1 1, vr 1 1) `compatible` Just (Version 1) + (vr 1 1, vr 1 2) `compatible` Just (Version 1) + (vr 1 2, vr 1 2) `compatible` Just (Version 2) + (vr 1 2, vr 2 3) `compatible` Just (Version 2) + (vr 1 3, vr 2 3) `compatible` Just (Version 3) + (vr 1 3, vr 2 4) `compatible` Just (Version 3) (vr 1 2, vr 3 4) `compatible` Nothing it "should check if version is compatible" $ do - isCompatible (1 :: Version) (vr 1 2) `shouldBe` True - isCompatible (2 :: Version) (vr 1 2) `shouldBe` True - isCompatible (2 :: Version) (vr 1 1) `shouldBe` False - isCompatible (1 :: Version) (vr 2 2) `shouldBe` False + isCompatible @T (Version 1) (vr 1 2) `shouldBe` True + isCompatible @T (Version 2) (vr 1 2) `shouldBe` True + isCompatible @T (Version 2) (vr 1 1) `shouldBe` False + isCompatible @T (Version 1) (vr 2 2) `shouldBe` False it "compatibleVersion should pass isCompatible check" . property $ \((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) -> min1 > max1 || min2 > max2 -- one of ranges is invalid, skip testing it - || let w = fromIntegral . fromEnum - vr1 = mkVersionRange (w min1) (w max1) :: VersionRange - vr2 = mkVersionRange (w min2) (w max2) :: VersionRange + || let w = Version . fromIntegral . fromEnum + vr1 = mkVersionRange (w min1) (w max1) :: VersionRange T + vr2 = mkVersionRange (w min2) (w max2) :: VersionRange T in case compatibleVersion vr1 vr2 of Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2 _ -> True where - vr = mkVersionRange - compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation + vr v1 v2 = mkVersionRange (Version v1) (Version v2) + compatible :: (VersionRange T, VersionRange T) -> Maybe (Version T) -> Expectation (vr1, vr2) `compatible` v = do (vr1, vr2) `checkCompatible` v (vr2, vr1) `checkCompatible` v diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 43558a86c..5a2dbb8de 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -34,6 +34,7 @@ import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Notifications.Protocol (NtfResponse) import Simplex.Messaging.Notifications.Server (runNtfServerBlocking) import Simplex.Messaging.Notifications.Server.Env import Simplex.Messaging.Notifications.Server.Push.APNS @@ -70,7 +71,7 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" ntfTestStoreLogFile :: FilePath ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log" -testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a +testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a testNtfClient client = do Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do @@ -114,8 +115,8 @@ ntfServerCfg = ntfServerCfgV2 :: NtfServerConfig ntfServerCfgV2 = ntfServerCfg - { ntfServerVRange = mkVersionRange 1 authBatchCmdsNTFVersion, - smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange 4 authCmdsSMPVersion}} + { ntfServerVRange = mkVersionRange initialNTFVersion authBatchCmdsNTFVersion, + smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion}} } withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a @@ -139,7 +140,7 @@ withNtfServerOn t port' = withNtfServerThreadOn t port' . const withNtfServer :: ATransport -> IO a -> IO a withNtfServer t = withNtfServerOn t ntfTestPort -runNtfTest :: forall c a. Transport c => (THandle c -> IO a) -> IO a +runNtfTest :: forall c a. Transport c => (THandleNTF c -> IO a) -> IO a runNtfTest test = withNtfServer (transport @c) $ testNtfClient test ntfServerTest :: @@ -147,10 +148,10 @@ ntfServerTest :: (Transport c, Encoding smp) => TProxy c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) + IO (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h where - tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -159,7 +160,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) -ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation +ntfTest :: Transport c => TProxy c -> (THandleNTF c -> IO ()) -> Expectation ntfTest _ test' = runNtfTest test' `shouldReturn` () data APNSMockRequest = APNSMockRequest diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index e29a292ee..e7e2018c2 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -5,7 +5,9 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} +{-# OPTIONS_GHC -Wno-orphans #-} module NtfServerTests where @@ -37,6 +39,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS import qualified Simplex.Messaging.Notifications.Server.Push.APNS as APNS +import Simplex.Messaging.Notifications.Transport (THandleNTF) import Simplex.Messaging.Parsers (parse, parseAll) import Simplex.Messaging.Protocol hiding (notification) import Simplex.Messaging.Transport @@ -50,30 +53,32 @@ ntfServerTests t = do ntfSyntaxTests :: ATransport -> Spec ntfSyntaxTests (ATransport t) = do - it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN) + it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", NRErr $ CMD UNKNOWN) describe "NEW" $ do - it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX) - it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX) - it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH) - it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH) + it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", NRErr $ CMD SYNTAX) + it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", NRErr $ CMD SYNTAX) + it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", NRErr $ CMD NO_AUTH) + it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", NRErr $ CMD HAS_AUTH) where (>#>) :: Encoding smp => (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) -> + (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) -> Expectation command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command)) -sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) +deriving instance Eq NtfResponse + +sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) +signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 88de2c8b8..871652f53 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -26,7 +26,7 @@ import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.Server -import Simplex.Messaging.Version (VersionRange, mkVersionRange) +import Simplex.Messaging.Version (mkVersionRange) import System.Environment (lookupEnv) import System.Info (os) import Test.Hspec @@ -68,10 +68,10 @@ xit'' d t = do ci <- runIO $ lookupEnv "CI" (if ci == Just "true" then skip "skipped on CI" . it d else it d) t -testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a +testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a testSMPClient = testSMPClientVR supportedClientSMPRelayVRange -testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRange -> (THandle c -> m a) -> m a +testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a testSMPClientVR vr client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do @@ -110,7 +110,7 @@ cfg = } cfgV7 :: ServerConfig -cfgV7 = cfg {smpServerVRange = mkVersionRange 4 authCmdsSMPVersion} +cfgV7 = cfg {smpServerVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion} withSmpServerStoreMsgLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsBackupFile = Just testServerStatsBackupFile} @@ -149,16 +149,16 @@ withSmpServer t = withSmpServerOn t testPort withSmpServerV7 :: HasCallStack => ATransport -> IO a -> IO a withSmpServerV7 t = withSmpServerConfigOn t cfgV7 testPort . const -runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandle c -> IO a) -> IO a +runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandleSMP c -> IO a) -> IO a runSmpTest test = withSmpServer (transport @c) $ testSMPClient test -runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO a) -> IO a +runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a runSmpTestN = runSmpTestNCfg cfg supportedClientSMPRelayVRange -runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> Int -> (HasCallStack => [THandle c] -> IO a) -> IO a +runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a runSmpTestNCfg srvCfg clntVR nClients test = withSmpServerConfigOn (transport @c) srvCfg testPort $ \_ -> run nClients [] where - run :: Int -> [THandle c] -> IO a + run :: Int -> [THandleSMP c] -> IO a run 0 hs = test hs run n hs = testSMPClientVR clntVR $ \h -> run (n - 1) (h : hs) @@ -170,7 +170,7 @@ smpServerTest :: IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h where - tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -179,33 +179,33 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) -smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> IO ()) -> Expectation +smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> IO ()) -> Expectation smpTest _ test' = runSmpTest test' `shouldReturn` () -smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO ()) -> Expectation +smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO ()) -> Expectation smpTestN n test' = runSmpTestN n test' `shouldReturn` () -smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation +smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest2 = smpTest2Cfg cfg supportedClientSMPRelayVRange -smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation +smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest2Cfg srvCfg clntVR _ test' = runSmpTestNCfg srvCfg clntVR 2 _test `shouldReturn` () where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2] = test' h1 h2 _test _ = error "expected 2 handles" -smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest3 _ test' = smpTestN 3 _test where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2, h3] = test' h1 h2 h3 _test _ = error "expected 3 handles" -smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest4 _ test' = smpTestN 4 _test where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2, h3, h4] = test' h1 h2 h3 h4 _test _ = error "expected 4 handles" diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 4065c7e19..03935fed5 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -77,13 +78,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh) pattern Msg :: MsgId -> MsgBody -> BrokerMsg pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) @@ -97,12 +98,12 @@ signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do _sx448 -> undefined -- ghc8107 fails to the branch excluded by types #endif -tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) +tPut1 :: Transport c => THandle v c -> SentRawTransmission -> IO (Either TransportError ()) tPut1 h t = do [r] <- tPut h [Right t] pure r -tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd) +tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd) tGet1 h = do [r] <- liftIO $ tGet h pure r @@ -383,7 +384,7 @@ testSwitchSub (ATransport t) = Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3) (ok3, OK) #== "accepts ACK from the 2nd TCP connection" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection" @@ -554,12 +555,12 @@ testWithStoreLog at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 removeFile testStoreLogFile where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () logSize :: FilePath -> IO Int @@ -652,12 +653,12 @@ testRestoreMessages at@(ATransport t) = removeFile testStoreMsgsFile removeFile testServerStatsBackupFile where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () checkStats :: ServerStatsData -> [RecipientId] -> Int -> Int -> Expectation @@ -726,15 +727,15 @@ testRestoreExpireMessages at@(ATransport t) = Right ServerStatsData {_msgExpired} <- strDecode <$> B.readFile testServerStatsBackupFile _msgExpired `shouldBe` 2 where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () -createAndSecureQueue :: Transport c => THandle c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret) +createAndSecureQueue :: Transport c => THandleSMP c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret) createAndSecureQueue h sPub = do g <- C.newRandom (rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g @@ -750,7 +751,7 @@ testTiming (ATransport t) = describe "should have similar time for auth error, whether queue exists or not, for all key types" $ forM_ timingTests $ \tst -> it (testName tst) $ - smpTest2Cfg cfgV7 (mkVersionRange 4 authCmdsSMPVersion) t $ \rh sh -> + smpTest2Cfg cfgV7 (mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion) t $ \rh sh -> testSameTiming rh sh tst where testName :: (C.AuthAlg, C.AuthAlg, Int) -> String @@ -769,7 +770,7 @@ testTiming (ATransport t) = ] timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const similarTime t1 t2 = abs (t2 / t1 - 1) < 0.15 -- normally the difference between "no queue" and "wrong key" is less than 5% - testSameTiming :: forall c. Transport c => THandle c -> THandle c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation + testSameTiming :: forall c. Transport c => THandleSMP c -> THandleSMP c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation testSameTiming rh sh (C.AuthAlg goodKeyAlg, C.AuthAlg badKeyAlg, n) = do g <- C.newRandom (rPub, rKey) <- atomically $ C.generateAuthKeyPair goodKeyAlg g @@ -790,7 +791,7 @@ testTiming (ATransport t) = runTimingTest sh badKey sId $ _SEND "hello" where - runTimingTest :: PartyI p => THandle c -> C.APrivateAuthKey -> ByteString -> Command p -> IO () + runTimingTest :: PartyI p => THandleSMP c -> C.APrivateAuthKey -> ByteString -> Command p -> IO () runTimingTest h badKey qId cmd = do threadDelay 100000 _ <- timeRepeat n $ do -- "warm up" the server @@ -840,14 +841,14 @@ testMessageNotifications (ATransport t) = Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2) (dec mId2 msg2, Right "hello again") #== "delivered from queue again" Resp "" _ (NMSG _ _) <- tGet1 nh2 - 1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there") Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" - 1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" @@ -867,7 +868,7 @@ testMsgExpireOnSend t = testSMPClient @c $ \rh -> do Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" @@ -887,7 +888,7 @@ testMsgExpireOnInterval t = signSendRecv rh rKey ("2", rId, SUB) >>= \case Resp "2" _ OK -> pure () r -> unexpected r - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing should be delivered" @@ -906,7 +907,7 @@ testMsgNOTExpireOnInterval t = testSMPClient @c $ \rh -> do Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 46d3d4dd8..999746858 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -21,7 +21,8 @@ import Data.List (find, isSuffixOf) import Data.Maybe (fromJust) import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2, testDB3) import Simplex.FileTransfer.Description (FileDescription (..), FileDescriptionURI (..), ValidFileDescription, fileDescriptionURI, mb, qrSizeLimit, pattern ValidFileDescription) -import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType (AUTH)) +import Simplex.FileTransfer.Protocol (FileParty (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType (AUTH)) import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..)) import Simplex.Messaging.Agent (AgentClient, disconnectAgentClient, testProtocolServer, xftpDeleteRcvFile, xftpDeleteSndFileInternal, xftpDeleteSndFileRemote, xftpReceiveFile, xftpSendDescription, xftpSendFile, xftpStartWorkers) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) diff --git a/tests/XFTPServerTests.hs b/tests/XFTPServerTests.hs index a11ba515a..451406275 100644 --- a/tests/XFTPServerTests.hs +++ b/tests/XFTPServerTests.hs @@ -20,9 +20,9 @@ import Data.List (isInfixOf) import ServerTests (logSize) import Simplex.FileTransfer.Client import Simplex.FileTransfer.Description (kb) -import Simplex.FileTransfer.Protocol (FileInfo (..), XFTPErrorType (..)) +import Simplex.FileTransfer.Protocol (FileInfo (..)) import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..)) -import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..)) import Simplex.Messaging.Client (ProtocolClientError (..)) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC