Merge branch 'master' into pq

This commit is contained in:
Evgeny Poberezkin
2024-03-04 20:13:01 +00:00
44 changed files with 851 additions and 574 deletions
+10 -10
View File
@@ -165,11 +165,11 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth)
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, VersionSMPC, XFTPServerWithAuth)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.ServiceScheme (ServiceScheme (..))
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (THandleParams (sessionId))
import Simplex.Messaging.Transport (SMPVersion, THandleParams (sessionId))
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import Simplex.RemoteControl.Client
@@ -684,7 +684,7 @@ joinConn c userId connId enableNtfs cReq cInfo pqEnc subMode = do
_ -> getSMPServer c userId
joinConnSrv c userId connId enableNtfs cReq cInfo pqEnc subMode srv
startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> CR.PQEncryption -> m (Compatible Version, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448)
startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> CR.PQEncryption -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448)
startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqEncryption = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
case ( qUri `compatibleVersion` smpClientVRange,
@@ -1114,7 +1114,7 @@ enqueueMessageB c pqEnc_ reqs = do
let sqs' = filter isActiveSndQ sqs
pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId))
where
storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId, CR.PQEncryption))
storeSentMsg :: DB.Connection -> VersionSMPA -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId, CR.PQEncryption))
storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId
@@ -1922,7 +1922,7 @@ cleanupManager c@AgentClient {subQ} = do
-- | make sure to ACK or throw in each message processing branch
-- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m ()
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission SMPVersion BrokerMsg -> m ()
processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, rId, cmd) = do
(rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId)
processSMP rq conn $ toConnData conn
@@ -2063,7 +2063,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
_ -> pure Nothing
_ -> prohibited >> ack
_ -> prohibited >> ack
updateConnVersion :: Connection c -> ConnData -> Version -> m (Connection c)
updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> m (Connection c)
updateConnVersion conn' cData' msgAgentVersion = do
aVRange <- asks $ smpAgentVRange . config
let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion
@@ -2126,7 +2126,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
parseMessage :: Encoding a => ByteString -> m a
parseMessage = liftEither . parse smpP (AGENT A_MESSAGE)
smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> Version -> Version -> m ()
smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> m ()
smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do
logServer "<--" c srv rId $ "MSG <CONF>:" <> logSecret srvMsgId
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
@@ -2380,7 +2380,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
createRatchet db connId rc
-- compare public keys `k1` in AgentRatchetKey messages sent by self and other party
-- to determine ratchet initilization ordering
initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m ()
initRatchet :: CR.VersionRangeE2E -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m ()
initRatchet e2eEncryptVRange (pk1, pk2, pKem)
| rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do
rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 pk2 pKem e2eOtherPartyParams
@@ -2431,7 +2431,7 @@ confirmQueueAsync c cData sq srv connInfo e2eEncryption_ pqEnc subMode = do
storeConfirmation c cData sq e2eEncryption_ (Just pqEnc) =<< mkAgentConfirmation c cData sq srv connInfo subMode
submitPendingMsg c cData sq
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> SubscriptionMode -> m ()
confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> Maybe CR.PQEncryption -> SubscriptionMode -> m ()
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ pqEnc_ subMode = do
msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode
sendConfirmation c sq msg
@@ -2478,7 +2478,7 @@ enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do
submitPendingMsg c cData sq
pure $ unId msgId
where
storeRatchetKey :: Version -> m InternalId
storeRatchetKey :: VersionSMPA -> m InternalId
storeRatchetKey agentVersion = withStore c $ \db -> runExceptT $ do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId
+33 -29
View File
@@ -167,8 +167,8 @@ import Network.Socket (HostName)
import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError)
import qualified Simplex.FileTransfer.Client as X
import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb)
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse, XFTPErrorType (DIGEST))
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse)
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion)
import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..))
import Simplex.FileTransfer.Util (uniqueCombine)
import Simplex.Messaging.Agent.Env.SQLite
@@ -187,6 +187,7 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Transport (NTFVersion)
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse)
import Simplex.Messaging.Protocol
@@ -215,9 +216,12 @@ import Simplex.Messaging.Protocol
UserProtocol,
XFTPServer,
XFTPServerWithAuth,
VersionSMPC,
VersionRangeSMPC,
sameSrvAddr',
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport (SMPVersion)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -253,7 +257,7 @@ data AgentClient = AgentClient
{ active :: TVar Bool,
rcvQ :: TBQueue (ATransmission 'Client),
subQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue (ServerTransmission BrokerMsg),
msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg),
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
smpClients :: TMap SMPTransportSession SMPClientVar,
ntfServers :: TVar [NtfServer],
@@ -467,7 +471,7 @@ agentClientStore AgentClient {agentEnv = Env {store}} = store
agentDRG :: AgentClient -> TVar ChaChaDRG
agentDRG AgentClient {agentEnv = Env {random}} = random
class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err where
class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where
type Client msg = c | c -> msg
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg)
clientProtocolError :: err -> AgentErrorType
@@ -476,8 +480,8 @@ class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err wher
clientTransportHost :: Client msg -> TransportHost
clientSessionTs :: Client msg -> UTCTime
instance ProtocolServerClient ErrorType BrokerMsg where
type Client BrokerMsg = ProtocolClient ErrorType BrokerMsg
instance ProtocolServerClient SMPVersion ErrorType BrokerMsg where
type Client BrokerMsg = ProtocolClient SMPVersion ErrorType BrokerMsg
getProtocolServerClient = getSMPServerClient
clientProtocolError = SMP
closeProtocolServerClient = closeProtocolClient
@@ -485,8 +489,8 @@ instance ProtocolServerClient ErrorType BrokerMsg where
clientTransportHost = transportHost'
clientSessionTs = sessionTs
instance ProtocolServerClient ErrorType NtfResponse where
type Client NtfResponse = ProtocolClient ErrorType NtfResponse
instance ProtocolServerClient NTFVersion ErrorType NtfResponse where
type Client NtfResponse = ProtocolClient NTFVersion ErrorType NtfResponse
getProtocolServerClient = getNtfServerClient
clientProtocolError = NTF
closeProtocolServerClient = closeProtocolClient
@@ -494,7 +498,7 @@ instance ProtocolServerClient ErrorType NtfResponse where
clientTransportHost = transportHost'
clientSessionTs = sessionTs
instance ProtocolServerClient XFTPErrorType FileResponse where
instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where
type Client FileResponse = XFTPClient
getProtocolServerClient = getXFTPServerClient
clientProtocolError = XFTP
@@ -683,8 +687,8 @@ waitForProtocolClient c (_, srv, _) v = do
-- clientConnected arg is only passed for SMP server
newProtocolClient ::
forall err msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) =>
forall v err msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) =>
AgentClient ->
TransportSession msg ->
TMap (TransportSession msg) (ClientVar msg) ->
@@ -706,10 +710,10 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
putTMVar (sessionVar v) (Left e)
throwError e -- signal error to caller
hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v)
getClientConfig AgentClient {useNetworkConfig} cfgSel = do
cfg <- asks $ cfgSel . config
networkConfig <- readTVarIO useNetworkConfig
@@ -754,19 +758,19 @@ throwWhenNoDelivery c sq =
unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $
throwSTM ThreadKilled
closeProtocolServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients c clientsSel =
atomically (clientsSel c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient_ c)
reconnectServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
reconnectServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
reconnectServerClients c clientsSel =
readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c)
closeClient :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO ()
closeClient :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO ()
closeClient c clientSel tSess =
atomically (TM.lookupDelete tSess $ clientSel c) >>= mapM_ (closeClient_ c)
closeClient_ :: ProtocolServerClient err msg => AgentClient -> ClientVar msg -> IO ()
closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO ()
closeClient_ c v = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
@@ -798,7 +802,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
where
newLock = createLock >>= \l -> TM.insert key l locks $> l
withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
withClient_ c tSess@(userId, srv, _) statCmd action = do
cl <- getProtocolServerClient c tSess
(action cl <* stat cl "OK") `catchAgentError` logServerError cl
@@ -810,18 +814,18 @@ withClient_ c tSess@(userId, srv, _) statCmd action = do
stat cl $ strEncode e
throwError e
withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
logServer "-->" c srv entId cmdStr
res <- withClient_ c tSess cmdStr action
logServer "<--" c srv entId "OK"
return res
withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMPClient c q cmdStr action = do
@@ -837,7 +841,7 @@ withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityI
withNtfClient c srv = withLogClient c (0, srv, Nothing)
withXFTPClient ::
(AgentMonad m, ProtocolServerClient err msg) =>
(AgentMonad m, ProtocolServerClient v err msg) =>
AgentClient ->
(UserId, ProtoServer msg, EntityId) ->
ByteString ->
@@ -1001,7 +1005,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do
C.AuthAlg a <- asks (rcvAuthAlg . config)
g <- asks random
@@ -1151,7 +1155,7 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
tSess <- mkTransportSession c userId smpServer senderId
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
@@ -1334,7 +1338,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody}
-- add encoding as AgentInvitation'?
agentCbEncryptOnce :: AgentMonad m => Version -> C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncryptOnce clientVersion dhRcvPubKey msg = do
g <- asks random
(dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g
@@ -1518,7 +1522,7 @@ incStat AgentClient {agentStats} n k = do
Just v -> modifyTVar' v (+ n)
_ -> newTVar n >>= \v -> TM.insert k v agentStats
incClientStat :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
incClientStat c userId pc = incClientStatN c userId pc 1
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
@@ -1528,7 +1532,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do
where
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
incClientStatN :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c userId pc n cmd res = do
atomically $ incStat c n statsKey
where
+9 -8
View File
@@ -56,16 +56,17 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Client
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (supportedE2EEncryptVRange)
import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRange)
import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig)
import Simplex.Messaging.Notifications.Transport (NTFVersion)
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Protocol (NtfServer, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange)
import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange)
import Simplex.Messaging.Transport (SMPVersion)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (TLS, Transport (..))
import Simplex.Messaging.Transport.Client (defaultSMPPort)
import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors)
import Simplex.Messaging.Version
import System.Random (StdGen, newStdGen)
import UnliftIO (Async, SomeException)
import UnliftIO.STM
@@ -87,8 +88,8 @@ data AgentConfig = AgentConfig
sndAuthAlg :: C.AuthAlg,
connIdBytes :: Int,
tbqSize :: Natural,
smpCfg :: ProtocolClientConfig,
ntfCfg :: ProtocolClientConfig,
smpCfg :: ProtocolClientConfig SMPVersion,
ntfCfg :: ProtocolClientConfig NTFVersion,
xftpCfg :: XFTPClientConfig,
reconnectInterval :: RetryInterval,
messageRetryInterval :: RetryInterval2,
@@ -116,9 +117,9 @@ data AgentConfig = AgentConfig
caCertificateFile :: FilePath,
privateKeyFile :: FilePath,
certificateFile :: FilePath,
e2eEncryptVRange :: VersionRange,
smpAgentVRange :: VersionRange,
smpClientVRange :: VersionRange
e2eEncryptVRange :: VersionRangeE2E,
smpAgentVRange :: VersionRangeSMPA,
smpClientVRange :: VersionRangeSMPC
}
defaultReconnectInterval :: RetryInterval
+48 -27
View File
@@ -4,6 +4,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
@@ -33,6 +34,9 @@
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md
module Simplex.Messaging.Agent.Protocol
( -- * Protocol parameters
VersionSMPA,
VersionRangeSMPA,
pattern VersionSMPA,
ratchetSyncSMPAgentVersion,
deliveryRcptsSMPAgentVersion,
supportedSMPAgentVRange,
@@ -175,11 +179,12 @@ import Data.Time.Clock.System (SystemTime)
import Data.Time.ISO8601
import Data.Type.Equality
import Data.Typeable ()
import Data.Word (Word32)
import Data.Word (Word16, Word32)
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.ToField
import Simplex.FileTransfer.Description
import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType)
import Simplex.FileTransfer.Protocol (FileParty (..))
import Simplex.FileTransfer.Transport (XFTPErrorType)
import Simplex.Messaging.Agent.QueryString
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), pattern PQEncOff, RcvE2ERatchetParams, RcvE2ERatchetParamsUri, SndE2ERatchetParams)
@@ -200,6 +205,10 @@ import Simplex.Messaging.Protocol
SMPServerWithAuth,
SndPublicAuthKey,
SubscriptionMode,
SMPClientVersion,
VersionSMPC,
VersionRangeSMPC,
initialSMPClientVersion,
legacyEncodeServer,
legacyServerP,
legacyStrEncodeServer,
@@ -215,6 +224,7 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..))
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import Simplex.RemoteControl.Types
import Text.Read
import UnliftIO.Exception (Exception)
@@ -225,19 +235,30 @@ import UnliftIO.Exception (Exception)
-- 3 - support ratchet renegotiation (6/30/2023)
-- 4 - delivery receipts (7/13/2023)
duplexHandshakeSMPAgentVersion :: Version
duplexHandshakeSMPAgentVersion = 2
data SMPAgentVersion
ratchetSyncSMPAgentVersion :: Version
ratchetSyncSMPAgentVersion = 3
instance VersionScope SMPAgentVersion
deliveryRcptsSMPAgentVersion :: Version
deliveryRcptsSMPAgentVersion = 4
type VersionSMPA = Version SMPAgentVersion
currentSMPAgentVersion :: Version
currentSMPAgentVersion = 4
type VersionRangeSMPA = VersionRange SMPAgentVersion
supportedSMPAgentVRange :: VersionRange
pattern VersionSMPA :: Word16 -> VersionSMPA
pattern VersionSMPA v = Version v
duplexHandshakeSMPAgentVersion :: VersionSMPA
duplexHandshakeSMPAgentVersion = VersionSMPA 2
ratchetSyncSMPAgentVersion :: VersionSMPA
ratchetSyncSMPAgentVersion = VersionSMPA 3
deliveryRcptsSMPAgentVersion :: VersionSMPA
deliveryRcptsSMPAgentVersion = VersionSMPA 4
currentSMPAgentVersion :: VersionSMPA
currentSMPAgentVersion = VersionSMPA 4
supportedSMPAgentVRange :: VersionRangeSMPA
supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentSMPAgentVersion
-- it is shorter to allow all handshake headers,
@@ -651,7 +672,7 @@ instance StrEncoding SndQueueInfo where
pure SndQueueInfo {sndServer, sndSwitchStatus}
data ConnectionStats = ConnectionStats
{ connAgentVersion :: Version,
{ connAgentVersion :: VersionSMPA,
rcvQueuesInfo :: [RcvQueueInfo],
sndQueuesInfo :: [SndQueueInfo],
ratchetSyncState :: RatchetSyncState,
@@ -786,27 +807,27 @@ data SMPConfirmation = SMPConfirmation
-- | optional reply queues included in confirmation (added in agent protocol v2)
smpReplyQueues :: [SMPQueueInfo],
-- | SMP client version
smpClientVersion :: Version
smpClientVersion :: VersionSMPC
}
deriving (Show)
data AgentMsgEnvelope
= AgentConfirmation
{ agentVersion :: Version,
{ agentVersion :: VersionSMPA,
e2eEncryption_ :: Maybe (SndE2ERatchetParams 'C.X448),
encConnInfo :: ByteString
}
| AgentMsgEnvelope
{ agentVersion :: Version,
{ agentVersion :: VersionSMPA,
encAgentMessage :: ByteString
}
| AgentInvitation -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet,
{ agentVersion :: Version,
{ agentVersion :: VersionSMPA,
connReq :: ConnectionRequestUri 'CMInvitation,
connInfo :: ByteString -- this message is only encrypted with per-queue E2E, not with double ratchet,
}
| AgentRatchetKey
{ agentVersion :: Version,
{ agentVersion :: VersionSMPA,
e2eEncryption :: RcvE2ERatchetParams 'C.X448,
info :: ByteString
}
@@ -1212,16 +1233,16 @@ sameQueue :: SMPQueue q => (SMPServer, SMP.QueueId) -> q -> Bool
sameQueue addr q = sameQAddress addr (qAddress q)
{-# INLINE sameQueue #-}
data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress}
data SMPQueueInfo = SMPQueueInfo {clientVersion :: VersionSMPC, queueAddress :: SMPQueueAddress}
deriving (Eq, Show)
instance Encoding SMPQueueInfo where
smpEncode (SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey})
| clientVersion > 1 = smpEncode (clientVersion, smpServer, senderId, dhPublicKey)
| clientVersion > initialSMPClientVersion = smpEncode (clientVersion, smpServer, senderId, dhPublicKey)
| otherwise = smpEncode clientVersion <> legacyEncodeServer smpServer <> smpEncode (senderId, dhPublicKey)
smpP = do
clientVersion <- smpP
smpServer <- if clientVersion > 1 then smpP else updateSMPServerHosts <$> legacyServerP
smpServer <- if clientVersion > initialSMPClientVersion then smpP else updateSMPServerHosts <$> legacyServerP
(senderId, dhPublicKey) <- smpP
pure $ SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey}
@@ -1229,20 +1250,20 @@ instance Encoding SMPQueueInfo where
-- But this is created to allow backward and forward compatibility where SMPQueueUri
-- could have more fields to convert to different versions of SMPQueueInfo in a different way,
-- and this instance would become non-trivial.
instance VersionI SMPQueueInfo where
type VersionRangeT SMPQueueInfo = SMPQueueUri
instance VersionI SMPClientVersion SMPQueueInfo where
type VersionRangeT SMPClientVersion SMPQueueInfo = SMPQueueUri
version = clientVersion
toVersionRangeT (SMPQueueInfo _v addr) vr = SMPQueueUri vr addr
instance VersionRangeI SMPQueueUri where
type VersionT SMPQueueUri = SMPQueueInfo
instance VersionRangeI SMPClientVersion SMPQueueUri where
type VersionT SMPClientVersion SMPQueueUri = SMPQueueInfo
versionRange = clientVRange
toVersionT (SMPQueueUri _vr addr) v = SMPQueueInfo v addr
-- | SMP queue information sent out-of-band.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#out-of-band-messages
data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRange, queueAddress :: SMPQueueAddress}
data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRangeSMPC, queueAddress :: SMPQueueAddress}
deriving (Eq, Show)
data SMPQueueAddress = SMPQueueAddress
@@ -1291,7 +1312,7 @@ instance StrEncoding SMPQueueUri where
smpServer = if maxVersion vr < srvHostnamesSMPClientVersion then updateSMPServerHosts srv' else srv'
pure $ SMPQueueUri vr SMPQueueAddress {smpServer, senderId, dhPublicKey}
where
unversioned = (versionToRange 1,[],) <$> strP <* A.endOfInput
unversioned = (versionToRange initialSMPClientVersion,[],) <$> strP <* A.endOfInput
versioned = do
dhKey_ <- optional strP
query <- optional (A.char '/') *> A.char '?' *> strP
@@ -1321,7 +1342,7 @@ deriving instance Show AConnectionRequestUri
data ConnReqUriData = ConnReqUriData
{ crScheme :: ServiceScheme,
crAgentVRange :: VersionRange,
crAgentVRange :: VersionRangeSMPA,
crSmpQueues :: NonEmpty SMPQueueUri,
crClientData :: Maybe CRClientData
}
+4 -4
View File
@@ -44,10 +44,10 @@ import Simplex.Messaging.Protocol
RcvPrivateAuthKey,
SndPrivateAuthKey,
SndPublicAuthKey,
VersionSMPC,
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Util ((<$?>))
import Simplex.Messaging.Version
-- * Queue types
@@ -94,7 +94,7 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue
dbReplaceQueueId :: Maybe Int64,
rcvSwchStatus :: Maybe RcvSwitchStatus,
-- | SMP client version
smpClientVersion :: Version,
smpClientVersion :: VersionSMPC,
-- | credentials used in context of notifications
clientNtfCreds :: Maybe ClientNtfCreds,
deleteErrors :: Int
@@ -157,7 +157,7 @@ data StoredSndQueue (q :: QueueStored) = SndQueue
dbReplaceQueueId :: Maybe Int64,
sndSwchStatus :: Maybe SndSwitchStatus,
-- | SMP client version
smpClientVersion :: Version
smpClientVersion :: VersionSMPC
}
deriving (Show)
@@ -304,7 +304,7 @@ deriving instance Show SomeConn
data ConnData = ConnData
{ connId :: ConnId,
userId :: UserId,
connAgentVersion :: Version,
connAgentVersion :: VersionSMPA,
enableNtfs :: Bool,
lastExternalSndId :: PrevExternalSndId,
deleted :: Bool,
+13 -9
View File
@@ -279,7 +279,7 @@ import Simplex.Messaging.Protocol
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, ($>>=), (<$$>))
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
import System.Exit (exitFailure)
import System.FilePath (takeDirectory)
@@ -702,7 +702,7 @@ setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = d
|]
(host, port, rcvId)
setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> Version -> IO ()
setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO ()
setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion =
DB.executeNamed
db
@@ -803,7 +803,7 @@ setRcvQueueNtfCreds db connId clientNtfCreds =
Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret)
Nothing -> (Nothing, Nothing, Nothing, Nothing)
type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version)
type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC)
smpConfirmation :: SMPConfirmationRow -> SMPConfirmation
smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) =
@@ -812,7 +812,7 @@ smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersi
e2ePubKey,
connInfo,
smpReplyQueues = fromMaybe [] smpReplyQueues_,
smpClientVersion = fromMaybe 1 smpClientVersion_
smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_
}
createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId)
@@ -889,7 +889,7 @@ removeConfirmations db connId =
|]
[":conn_id" := connId]
setConnectionVersion :: DB.Connection -> ConnId -> Version -> IO ()
setConnectionVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO ()
setConnectionVersion db connId aVersion =
DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId)
@@ -1776,6 +1776,10 @@ instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEn
instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
instance ToField (Version v) where toField (Version v) = toField v
instance FromField (Version v) where fromField f = Version <$> fromField f
instance ToField CR.PQEncryption where toField (CR.PQEncryption pqEnc) = toField pqEnc
instance FromField CR.PQEncryption where fromField f = CR.PQEncryption <$> fromField f
@@ -1948,7 +1952,7 @@ setConnDeleted db waitDelivery connId
| otherwise =
DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
setConnAgentVersion :: DB.Connection -> ConnId -> Version -> IO ()
setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO ()
setConnAgentVersion db connId aVersion =
DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId)
@@ -2007,12 +2011,12 @@ rcvQueueQuery =
toRcvQueue ::
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe Version, Int)
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int)
:. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
RcvQueue
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
let server = SMPServer host port keyHash
smpClientVersion = fromMaybe 1 smpClientVersion_
smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
@@ -2048,7 +2052,7 @@ sndQueueQuery =
toSndQueue ::
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId)
:. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus)
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, Version) ->
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) ->
SndQueue
toSndQueue
( (userId, keyHash, connId, host, port, sndId)
+36 -36
View File
@@ -76,7 +76,7 @@ module Simplex.Messaging.Client
PCTransmission,
mkTransmission,
authTransmission,
clientStub,
smpClientStub,
)
where
@@ -117,14 +117,14 @@ import System.Timeout (timeout)
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
--
-- Use 'getSMPClient' to connect to an SMP server and create a client handle.
data ProtocolClient err msg = ProtocolClient
data ProtocolClient v err msg = ProtocolClient
{ action :: Maybe (Async ()),
thParams :: THandleParams,
thParams :: THandleParams v,
sessionTs :: UTCTime,
client_ :: PClient err msg
client_ :: PClient v err msg
}
data PClient err msg = PClient
data PClient v err msg = PClient
{ connected :: TVar Bool,
transportSession :: TransportSession msg,
transportHost :: TransportHost,
@@ -135,11 +135,11 @@ data PClient err msg = PClient
sentCommands :: TMap CorrId (Request err msg),
sndQ :: TBQueue ByteString,
rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)),
msgQ :: Maybe (TBQueue (ServerTransmission msg))
msgQ :: Maybe (TBQueue (ServerTransmission v msg))
}
clientStub :: TVar ChaChaDRG -> ByteString -> Version -> Maybe THandleAuth -> STM (ProtocolClient err msg)
clientStub g sessionId thVersion thAuth = do
smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg)
smpClientStub g sessionId thVersion thAuth = do
connected <- newTVar False
clientCorrId <- C.newRandomDRG g
sentCommands <- TM.empty
@@ -174,13 +174,13 @@ clientStub g sessionId thVersion thAuth = do
}
}
type SMPClient = ProtocolClient ErrorType BrokerMsg
type SMPClient = ProtocolClient SMPVersion ErrorType BrokerMsg
-- | Type for client command data
type ClientCommand msg = (Maybe C.APrivateAuthKey, EntityId, ProtoCommand msg)
-- | Type synonym for transmission from some SPM server queue.
type ServerTransmission msg = (TransportSession msg, Version, SessionId, EntityId, msg)
type ServerTransmission v msg = (TransportSession msg, Version v, SessionId, EntityId, msg)
data HostMode
= -- | prefer (or require) onion hosts when connecting via SOCKS proxy
@@ -241,7 +241,7 @@ transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} =
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing}
-- | protocol client configuration.
data ProtocolClientConfig = ProtocolClientConfig
data ProtocolClientConfig v = ProtocolClientConfig
{ -- | size of TBQueue to use for server commands and responses
qSize :: Natural,
-- | default server port if port is not specified in ProtocolServer
@@ -249,13 +249,13 @@ data ProtocolClientConfig = ProtocolClientConfig
-- | network configuration
networkConfig :: NetworkConfig,
-- | client-server protocol version range
serverVRange :: VersionRange,
serverVRange :: VersionRange v,
-- | delay between sending batches of commands (microseconds)
batchDelay :: Maybe Int
}
-- | Default protocol client configuration.
defaultClientConfig :: VersionRange -> ProtocolClientConfig
defaultClientConfig :: VersionRange v -> ProtocolClientConfig v
defaultClientConfig serverVRange =
ProtocolClientConfig
{ qSize = 64,
@@ -265,7 +265,7 @@ defaultClientConfig serverVRange =
batchDelay = Nothing
}
defaultSMPClientConfig :: ProtocolClientConfig
defaultSMPClientConfig :: ProtocolClientConfig SMPVersion
defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange
data Request err msg = Request
@@ -292,15 +292,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
onionHost = find isOnionHost hosts
publicHost = find (not . isOnionHost) hosts
protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String
protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err msg -> String
protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_
where
snd3 (_, s, _) = s
transportHost' :: ProtocolClient err msg -> TransportHost
transportHost' :: ProtocolClient v err msg -> TransportHost
transportHost' = transportHost . client_
transportSession' :: ProtocolClient err msg -> TransportSession msg
transportSession' :: ProtocolClient v err msg -> TransportSession msg
transportSession' = transportSession . client_
type UserId = Int64
@@ -313,7 +313,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
--
-- A single queue can be used for multiple 'SMPClient' instances,
-- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall err msg. Protocol err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg))
getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, serverVRange, batchDelay} msgQ disconnected = do
case chooseTransportHost networkConfig (host srv) of
Right useHost ->
@@ -322,7 +322,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
Left e -> pure $ Left e
where
NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig
mkProtocolClient :: TransportHost -> STM (PClient err msg)
mkProtocolClient :: TransportHost -> STM (PClient v err msg)
mkProtocolClient transportHost = do
connected <- newTVar False
pingErrorCount <- newTVar 0
@@ -345,7 +345,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
msgQ
}
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient v err msg -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg))
runClient (port', ATransport t) useHost c = do
cVar <- newEmptyTMVarIO
let tcConfig = transportClientConfig networkConfig
@@ -366,10 +366,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
"80" -> ("80", transport @WS)
p -> (p, transport @TLS)
client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO ()
client :: forall c. Transport c => TProxy c -> PClient v err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient v err msg)) -> c -> IO ()
client _ c cVar h = do
ks <- atomically $ C.generateKeyPair g
runExceptT (protocolClientHandshake @err @msg h ks (keyHash srv) serverVRange) >>= \case
runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange) >>= \case
Left e -> atomically . putTMVar cVar . Left $ PCETransportError e
Right th@THandle {params} -> do
sessionTs <- getCurrentTime
@@ -380,16 +380,16 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0])
`finally` disconnected c'
send :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
send :: Transport c => ProtocolClient v err msg -> THandle v c -> IO ()
send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPutLog h
receive :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
receive :: Transport c => ProtocolClient v err msg -> THandle v c -> IO ()
receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
ping :: ProtocolClient err msg -> IO ()
ping :: ProtocolClient v err msg -> IO ()
ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do
threadDelay' smpPingInterval
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @v @err @msg) >>= \case
Left PCEResponseTimeout -> do
cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1)
when (maxCnt == 0 || cnt < maxCnt) $ ping c
@@ -397,10 +397,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
where
maxCnt = smpPingCount networkConfig
process :: ProtocolClient err msg -> IO ()
process :: ProtocolClient v err msg -> IO ()
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO ()
processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO ()
processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) =
if B.null $ bs corrId
then sendMsg respOrErr
@@ -428,7 +428,7 @@ proxyUsername :: TransportSession msg -> ByteString
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
-- | Disconnects client from the server and terminates client threads.
closeProtocolClient :: ProtocolClient err msg -> IO ()
closeProtocolClient :: ProtocolClient v err msg -> IO ()
closeProtocolClient = mapM_ uninterruptibleCancel . action
-- | SMP client error type.
@@ -517,7 +517,7 @@ processSUBResponse c (Response rId r) = case r of
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg
serverTransmission :: ProtocolClient v err msg -> RecipientId -> msg -> ServerTransmission v msg
serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} entityId message =
(transportSession, thVersion, sessionId, entityId, message)
@@ -635,7 +635,7 @@ sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg)
-- | Send multiple commands with batching and collect responses
sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg))
sendProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg))
sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do
bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs
validate . concat =<< mapM (sendBatch c) bs
@@ -652,12 +652,12 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz
where
diff = L.length cs - length rs
streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
streamProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do
bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs
mapM_ (cb <=< sendBatch c) bs
sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg]
sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg]
sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
case b of
TBError e Request {entityId} -> do
@@ -673,7 +673,7 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
(: []) <$> getResponse c r
-- | Send Protocol command
sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
sendProtocolCommand :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} pKey entId cmd =
ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd)
where
@@ -690,7 +690,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THand
| otherwise = tEncode t
-- TODO switch to timeout or TimeManager that supports Int64
getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg)
getResponse :: ProtocolClient v err msg -> Request err msg -> IO (Response err msg)
getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Request {entityId, responseVar} = do
response <-
timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case
@@ -698,7 +698,7 @@ getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Requ
Nothing -> pure $ Left PCEResponseTimeout
pure Response {entityId, response}
mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> IO (PCTransmission err msg)
mkTransmission :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg)
mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entId, cmd) = do
corrId <- atomically getNextCorrId
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entId, cmd)
+2 -2
View File
@@ -65,7 +65,7 @@ type SMPSub = (SMPSubParty, QueueId)
-- type SMPServerSub = (SMPServer, SMPSub)
data SMPClientAgentConfig = SMPClientAgentConfig
{ smpCfg :: ProtocolClientConfig,
{ smpCfg :: ProtocolClientConfig SMPVersion,
reconnectInterval :: RetryInterval,
msgQSize :: Natural,
agentQSize :: Natural,
@@ -91,7 +91,7 @@ defaultSMPClientAgentConfig =
data SMPClientAgent = SMPClientAgent
{ agentCfg :: SMPClientAgentConfig,
msgQ :: TBQueue (ServerTransmission BrokerMsg),
msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg),
agentQ :: TBQueue SMPClientAgentEvent,
randomDrg :: TVar ChaChaDRG,
smpClients :: TMap SMPServer SMPClientVar,
+133 -48
View File
@@ -3,6 +3,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
@@ -14,10 +15,64 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
-- {-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Simplex.Messaging.Crypto.Ratchet where
module Simplex.Messaging.Crypto.Ratchet
( Ratchet (..),
RatchetX448,
SkippedMsgDiff (..),
SkippedMsgKeys,
InitialKeys (..),
PQEncryption (..),
pattern PQEncOn,
pattern PQEncOff,
AUseKEM (..),
RatchetKEMState (..),
SRatchetKEMState (..),
RcvPrivRKEMParams,
APrivRKEMParams (..),
RcvE2ERatchetParamsUri,
RcvE2ERatchetParams,
SndE2ERatchetParams,
AE2ERatchetParams (..),
E2ERatchetParamsUri (..),
E2ERatchetParams (..),
VersionE2E,
VersionRangeE2E,
pattern VersionE2E,
kdfX3DHE2EEncryptVersion,
pqRatchetVersion,
currentE2EEncryptVersion,
supportedE2EEncryptVRange,
generateRcvE2EParams,
generateSndE2EParams,
initialPQEncryption,
connPQEncryption,
joinContactInitialKeys,
replyKEM_,
pqX3dhSnd,
pqX3dhRcv,
initSndRatchet,
initRcvRatchet,
rcEncrypt,
rcDecrypt,
-- used in tests
MsgHeader (..),
RatchetVersions (..),
RatchetInitParams (..),
UseKEM (..),
RKEMParams (..),
ARKEMParams (..),
SndRatchet (..),
RcvRatchet (..),
RatchetKEM (..),
RatchetKEMAccepted (..),
RatchetKey (..),
ratchetVersions,
fullHeaderLen,
applySMDiff,
)
where
import Control.Applicative ((<|>))
import Control.Monad.Except
@@ -44,7 +99,7 @@ import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, isJust)
import Data.Type.Equality
import Data.Typeable (Typeable)
import Data.Word (Word32)
import Data.Word (Word16, Word32)
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Messaging.Agent.QueryString
@@ -55,22 +110,34 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE')
import Simplex.Messaging.Util ((<$?>), ($>>=))
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import UnliftIO.STM
-- e2e encryption headers version history:
-- 1 - binary protocol encoding (1/1/2022)
-- 2 - use KDF in x3dh (10/20/2022)
kdfX3DHE2EEncryptVersion :: Version
kdfX3DHE2EEncryptVersion = 2
data E2EVersion
pqRatchetVersion :: Version
pqRatchetVersion = 3
instance VersionScope E2EVersion
currentE2EEncryptVersion :: Version
currentE2EEncryptVersion = 3
type VersionE2E = Version E2EVersion
supportedE2EEncryptVRange :: VersionRange
type VersionRangeE2E = VersionRange E2EVersion
pattern VersionE2E :: Word16 -> VersionE2E
pattern VersionE2E v = Version v
kdfX3DHE2EEncryptVersion :: VersionE2E
kdfX3DHE2EEncryptVersion = VersionE2E 2
pqRatchetVersion :: VersionE2E
pqRatchetVersion = VersionE2E 3
currentE2EEncryptVersion :: VersionE2E
currentE2EEncryptVersion = VersionE2E 3
supportedE2EEncryptVRange :: VersionRangeE2E
supportedE2EEncryptVRange = mkVersionRange kdfX3DHE2EEncryptVersion currentE2EEncryptVersion
data RatchetKEMState
@@ -120,7 +187,7 @@ instance RatchetKEMStateI s => Encoding (RKEMParams s) where
RKParamsAccepted ct k -> smpEncode ('A', ct, k)
smpP = (\(ARKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP
instance Encoding (ARKEMParams) where
instance Encoding ARKEMParams where
smpEncode (ARKP _ ps) = smpEncode ps
smpP =
smpP >>= \case
@@ -129,7 +196,7 @@ instance Encoding (ARKEMParams) where
_ -> fail "bad ratchet KEM params"
data E2ERatchetParams (s :: RatchetKEMState) (a :: Algorithm)
= E2ERatchetParams Version (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
= E2ERatchetParams VersionE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
deriving (Show)
data AE2ERatchetParams (a :: Algorithm)
@@ -159,12 +226,12 @@ instance (RatchetKEMStateI s, AlgorithmI a) => Encoding (E2ERatchetParams s a) w
instance AlgorithmI a => Encoding (AE2ERatchetParams a) where
smpEncode (AE2ERatchetParams _ ps) = smpEncode ps
smpP = (\(AnyE2ERatchetParams s _ ps) -> (AE2ERatchetParams s) <$> checkAlgorithm ps) <$?> smpP
smpP = (\(AnyE2ERatchetParams s _ ps) -> AE2ERatchetParams s <$> checkAlgorithm ps) <$?> smpP
instance Encoding AnyE2ERatchetParams where
smpEncode (AnyE2ERatchetParams _ _ ps) = smpEncode ps
smpP = do
v :: Version <- smpP
v :: VersionE2E <- smpP
APublicDhKey a k1 <- smpP
APublicDhKey a' k2 <- smpP
case testEquality a a' of
@@ -174,25 +241,25 @@ instance Encoding AnyE2ERatchetParams where
Just (ARKP s kem) -> pure $ AnyE2ERatchetParams s a $ E2ERatchetParams v k1 k2 (Just kem)
Nothing -> pure $ AnyE2ERatchetParams SRKSProposed a $ E2ERatchetParams v k1 k2 Nothing
where
kemP :: Version -> Parser (Maybe (ARKEMParams))
kemP :: VersionE2E -> Parser (Maybe ARKEMParams)
kemP v
| v >= pqRatchetVersion = smpP
| otherwise = pure Nothing
instance VersionI (E2ERatchetParams s a) where
type VersionRangeT (E2ERatchetParams s a) = E2ERatchetParamsUri s a
instance VersionI E2EVersion (E2ERatchetParams s a) where
type VersionRangeT E2EVersion (E2ERatchetParams s a) = E2ERatchetParamsUri s a
version (E2ERatchetParams v _ _ _) = v
toVersionRangeT (E2ERatchetParams _ k1 k2 kem_) vr = E2ERatchetParamsUri vr k1 k2 kem_
instance VersionRangeI (E2ERatchetParamsUri s a) where
type VersionT (E2ERatchetParamsUri s a) = (E2ERatchetParams s a)
instance VersionRangeI E2EVersion (E2ERatchetParamsUri s a) where
type VersionT E2EVersion (E2ERatchetParamsUri s a) = (E2ERatchetParams s a)
versionRange (E2ERatchetParamsUri vr _ _ _) = vr
toVersionT (E2ERatchetParamsUri _ k1 k2 kem_) v = E2ERatchetParams v k1 k2 kem_
type RcvE2ERatchetParamsUri a = E2ERatchetParamsUri 'RKSProposed a
data E2ERatchetParamsUri (s :: RatchetKEMState) (a :: Algorithm)
= E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
= E2ERatchetParamsUri VersionRangeE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
deriving (Show)
data AE2ERatchetParamsUri (a :: Algorithm)
@@ -228,13 +295,13 @@ instance (RatchetKEMStateI s, AlgorithmI a) => StrEncoding (E2ERatchetParamsUri
instance AlgorithmI a => StrEncoding (AE2ERatchetParamsUri a) where
strEncode (AE2ERatchetParamsUri _ ps) = strEncode ps
strP = (\(AnyE2ERatchetParamsUri s _ ps) -> (AE2ERatchetParamsUri s) <$> checkAlgorithm ps) <$?> strP
strP = (\(AnyE2ERatchetParamsUri s _ ps) -> AE2ERatchetParamsUri s <$> checkAlgorithm ps) <$?> strP
instance StrEncoding AnyE2ERatchetParamsUri where
strEncode (AnyE2ERatchetParamsUri _ _ ps) = strEncode ps
strP = do
query <- strP
vr :: VersionRange <- queryParam "v" query
vr :: VersionRangeE2E <- queryParam "v" query
keys <- L.toList <$> queryParam "x3dh" query
case keys of
[APublicDhKey a k1, APublicDhKey a' k2] -> case testEquality a a' of
@@ -248,7 +315,7 @@ instance StrEncoding AnyE2ERatchetParamsUri where
kemP vr query
| maxVersion vr >= pqRatchetVersion =
queryParam_ "kem_key" query
$>>= \k -> (Just . kemParams k <$> queryParam_ "kem_ct" query)
$>>= \k -> Just . kemParams k <$> queryParam_ "kem_ct" query
| otherwise = pure Nothing
kemParams k = \case
Nothing -> ARKP SRKSProposed $ RKParamsProposed k
@@ -272,7 +339,7 @@ instance RatchetKEMStateI s => Encoding (PrivRKEMParams s) where
PrivateRKParamsAccepted ct shared k -> smpEncode ('A', ct, shared, k)
smpP = (\(APRKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP
instance Encoding (APrivRKEMParams) where
instance Encoding APrivRKEMParams where
smpEncode (APRKP _ ps) = smpEncode ps
smpP =
smpP >>= \case
@@ -290,7 +357,7 @@ data UseKEM (s :: RatchetKEMState) where
data AUseKEM = forall s. RatchetKEMStateI s => AUseKEM (SRatchetKEMState s) (UseKEM s)
generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a)
generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a)
generateE2EParams g v useKEM_ = do
(k1, pk1) <- atomically $ generateKeyPair g
(k2, pk2) <- atomically $ generateKeyPair g
@@ -309,7 +376,7 @@ generateE2EParams g v useKEM_ = do
_ -> pure Nothing
-- used by party initiating connection, Bob in double-ratchet spec
generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a)
generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a)
generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_
where
proposeKEM_ :: PQEncryption -> Maybe (UseKEM 'RKSProposed)
@@ -319,7 +386,7 @@ generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_
-- used by party accepting connection, Alice in double-ratchet spec
generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a)
generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a)
generateSndE2EParams g v = \case
Nothing -> do
(pk1, pk2, _, e2eParams) <- generateE2EParams g v Nothing
@@ -385,7 +452,7 @@ type RatchetX448 = Ratchet 'X448
data Ratchet a = Ratchet
{ -- ratchet version range sent in messages (current .. max supported ratchet version)
rcVersion :: VersionRange,
rcVersion :: RatchetVersions,
-- associated data - must be the same in both parties ratchets
rcAD :: Str,
rcDHRs :: PrivateKey a,
@@ -405,6 +472,29 @@ data Ratchet a = Ratchet
}
deriving (Show)
data RatchetVersions = RVersions
{ current :: VersionE2E,
maxSupported :: VersionE2E
}
deriving (Eq, Show)
instance ToJSON RatchetVersions where
-- TODO v5.7 or v5.8 change to the default record encoding
toJSON (RVersions v1 v2) = toJSON (v1, v2)
toEncoding (RVersions v1 v2) = toEncoding (v1, v2)
instance FromJSON RatchetVersions where
-- TODO v6.0 replace with the default record parser
-- this parser supports JSON record encoding for forward compatibility
parseJSON v = (tupleP <|> recordP v) >>= toRV
where
tupleP = parseJSON v
recordP = J.withObject "RatchetVersions" $ \o -> (,) <$> o J..: "current" <*> o J..: "maxSupported"
toRV (v1, v2) = maybe (fail "bad version range") (pure . ratchetVersions) $ safeVersionRange v1 v2
ratchetVersions :: VersionRangeE2E -> RatchetVersions
ratchetVersions (VersionRange v1 v2) = RVersions {current = v1, maxSupported = v2}
data SndRatchet a = SndRatchet
{ rcDHRr :: PublicKey a,
rcCKs :: RatchetKey,
@@ -494,12 +584,12 @@ instance FromField MessageKey where fromField = blobFieldDecoder smpDecode
-- // above added for KEM
-- @
initSndRatchet ::
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a
initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a
initSndRatchet v rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do
-- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss)
let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted)
in Ratchet
{ rcVersion,
{ rcVersion = ratchetVersions v,
rcAD = assocData,
rcDHRs,
rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_,
@@ -523,10 +613,10 @@ initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey
-- Please note that the public part of rcDHRs was sent to the sender
-- as part of the connection request and random salt was received from the sender.
initRcvRatchet ::
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a
initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM =
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a
initRcvRatchet v rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM =
Ratchet
{ rcVersion,
{ rcVersion = ratchetVersions v,
rcAD = assocData,
rcDHRs,
-- rcKEM:
@@ -552,7 +642,7 @@ initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK
-- ct = state.PQRct // added for KEM #1
data MsgHeader a = MsgHeader
{ -- | max supported ratchet version
msgMaxVersion :: Version,
msgMaxVersion :: VersionE2E,
msgDHRs :: PublicKey a,
msgKEM :: Maybe ARKEMParams,
msgPN :: Word32,
@@ -560,11 +650,6 @@ data MsgHeader a = MsgHeader
}
deriving (Show)
data AMsgHeader
= forall a.
(AlgorithmI a, DhAlgorithm a) =>
AMsgHeader (SAlgorithm a) (MsgHeader a)
-- to allow extension without increasing the size, the actual header length is:
-- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4
-- TODO PQ this must be version-dependent
@@ -590,7 +675,7 @@ instance AlgorithmI a => Encoding (MsgHeader a) where
pure MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs}
data EncMessageHeader = EncMessageHeader
{ ehVersion :: Version,
{ ehVersion :: VersionE2E,
ehIV :: IV,
ehAuthTag :: AuthTag,
ehBody :: ByteString
@@ -611,12 +696,12 @@ data EncRatchetMessage = EncRatchetMessage
emBody :: ByteString
}
encodeEncRatchetMessage :: Version -> EncRatchetMessage -> ByteString
encodeEncRatchetMessage :: VersionE2E -> EncRatchetMessage -> ByteString
encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag}
| v >= pqRatchetVersion = smpEncode (Large emHeader, emAuthTag, Tail emBody)
| otherwise = smpEncode (emHeader, emAuthTag, Tail emBody)
encRatchetMessageP :: Version -> Parser EncRatchetMessage
encRatchetMessageP :: VersionE2E -> Parser EncRatchetMessage
encRatchetMessageP v = do
emHeader <- if v >= pqRatchetVersion then unLarge <$> smpP else smpP
(emAuthTag, Tail emBody) <- smpP
@@ -694,10 +779,10 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM,
-- enc_header = HENCRYPT(state.HKs, header)
(ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV paddedHeaderLen rcAD msgHeader
-- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header))
-- TODO PQ versioning in Ratchet should change somehow
let emHeader = smpEncode EncMessageHeader {ehVersion = maxVersion rcVersion, ehBody, ehAuthTag, ehIV}
-- TODO PQ versioning in Ratchet should change: we should use "current" version here
let emHeader = smpEncode EncMessageHeader {ehVersion = maxSupported rcVersion, ehBody, ehAuthTag, ehIV}
(emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg
let msg' = encodeEncRatchetMessage (maxVersion rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag}
let msg' = encodeEncRatchetMessage (maxSupported rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag}
-- state.Ns += 1
rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1}
rc'' = case pqMode_ of
@@ -719,7 +804,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM,
msgHeader =
smpEncode
MsgHeader
{ msgMaxVersion = maxVersion rcVersion,
{ msgMaxVersion = maxSupported rcVersion,
msgDHRs = publicKey rcDHRs,
msgKEM = msgKEMParams <$> rcKEM,
msgPN = rcPN,
@@ -752,7 +837,7 @@ rcDecrypt ::
ExceptT CryptoError IO (DecryptResult a)
rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do
-- TODO PQ versioning should change
encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxVersion rcVersion) msg'
encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxSupported rcVersion) msg'
encHdr <- parseE CryptoHeaderError smpP emHeader
-- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD)
decryptSkipped encHdr encMsg >>= \case
@@ -10,15 +10,15 @@ import Data.Word (Word16)
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Transport (supportedClientNTFVRange)
import Simplex.Messaging.Notifications.Transport (NTFVersion, supportedClientNTFVRange)
import Simplex.Messaging.Protocol (ErrorType)
import Simplex.Messaging.Util (bshow)
type NtfClient = ProtocolClient ErrorType NtfResponse
type NtfClient = ProtocolClient NTFVersion ErrorType NtfResponse
type NtfClientError = ProtocolClientError ErrorType
defaultNTFClientConfig :: ProtocolClientConfig
defaultNTFClientConfig :: ProtocolClientConfig NTFVersion
defaultNTFClientConfig = defaultClientConfig supportedClientNTFVRange
ntfRegisterToken :: NtfClient -> C.APrivateAuthKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519)
@@ -28,7 +28,7 @@ import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake)
import Simplex.Messaging.Parsers (fromTextField_)
import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
import Simplex.Messaging.Util (eitherToMaybe, (<$?>))
@@ -147,7 +147,7 @@ instance Encoding ANewNtfEntity where
'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP)
_ -> fail "bad ANewNtfEntity"
instance Protocol ErrorType NtfResponse where
instance Protocol NTFVersion ErrorType NtfResponse where
type ProtoCommand NtfResponse = NtfCmd
type ProtoType NtfResponse = 'PNTF
protocolClientHandshake = ntfClientHandshake
@@ -184,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
deriving instance Show NtfCmd
instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) where
type Tag (NtfCommand e) = NtfCommandTag e
encodeProtocol _v = \case
TNEW newTkn -> e (TNEW_, ' ', newTkn)
@@ -203,7 +203,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag)
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, entityId, _) cmd = case cmd of
@@ -223,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
| not (B.null entityId) = Left $ CMD HAS_AUTH
| otherwise = Right cmd
instance ProtocolEncoding ErrorType NtfCmd where
instance ProtocolEncoding NTFVersion ErrorType NtfCmd where
type Tag NtfCmd = NtfCmdTag
encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c
@@ -243,7 +243,7 @@ instance ProtocolEncoding ErrorType NtfCmd where
SDEL_ -> pure SDEL
PING_ -> pure PING
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
@@ -290,7 +290,7 @@ data NtfResponse
| NRPong
deriving (Show)
instance ProtocolEncoding ErrorType NtfResponse where
instance ProtocolEncoding NTFVersion ErrorType NtfResponse where
type Tag NtfResponse = NtfResponseTag
encodeProtocol _v = \case
NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey)
@@ -338,7 +338,7 @@ updateTknStatus NtfTknData {ntfTknId, tknStatus} status = do
old <- atomically $ stateTVar tknStatus (,status)
when (old /= status) $ withNtfLog $ \sl -> logTokenStatus sl ntfTknId status
runNtfClientTransport :: Transport c => THandle c -> M ()
runNtfClientTransport :: Transport c => THandleNTF c -> M ()
runNtfClientTransport th@THandle {params} = do
qSize <- asks $ clientQSize . config
ts <- liftIO getSystemTime
@@ -355,7 +355,7 @@ runNtfClientTransport th@THandle {params} = do
clientDisconnected :: NtfServerClient -> IO ()
clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False
receive :: Transport c => THandle c -> NtfServerClient -> M ()
receive :: Transport c => THandleNTF c -> NtfServerClient -> M ()
receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do
ts <- liftIO $ tGet th
forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do
@@ -370,7 +370,7 @@ receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ
where
write q t = atomically $ writeTBQueue q t
send :: Transport c => THandle c -> NtfServerClient -> IO ()
send :: Transport c => THandleNTF c -> NtfServerClient -> IO ()
send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do
t <- atomically $ readTBQueue sndQ
void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)]
@@ -24,6 +24,7 @@ import Numeric.Natural
import Simplex.Messaging.Client.Agent
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF)
import Simplex.Messaging.Notifications.Server.Push.APNS
import Simplex.Messaging.Notifications.Server.Stats
import Simplex.Messaging.Notifications.Server.Store
@@ -34,7 +35,6 @@ import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport, THandleParams)
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Version (VersionRange)
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
import UnliftIO.STM
@@ -60,7 +60,7 @@ data NtfServerConfig = NtfServerConfig
logStatsStartTime :: Int64,
serverStatsLogFile :: FilePath,
serverStatsBackupFile :: Maybe FilePath,
ntfServerVRange :: VersionRange,
ntfServerVRange :: VersionRangeNTF,
transportConfig :: TransportServerConfig
}
@@ -161,13 +161,13 @@ data NtfRequest
data NtfServerClient = NtfServerClient
{ rcvQ :: TBQueue NtfRequest,
sndQ :: TBQueue (Transmission NtfResponse),
ntfThParams :: THandleParams,
ntfThParams :: THandleParams NTFVersion,
connected :: TVar Bool,
rcvActiveAt :: TVar SystemTime,
sndActiveAt :: TVar SystemTime
}
newNtfServerClient :: Natural -> THandleParams -> SystemTime -> STM NtfServerClient
newNtfServerClient :: Natural -> THandleParams NTFVersion -> SystemTime -> STM NtfServerClient
newNtfServerClient qSize ntfThParams ts = do
rcvQ <- newTBQueue qSize
sndQ <- newTBQueue qSize
@@ -3,6 +3,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Notifications.Transport where
@@ -12,33 +13,51 @@ import Control.Monad.Except
import Data.Attoparsec.ByteString.Char8 (Parser)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Word (Word16)
import qualified Data.X509 as X
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Transport
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import Simplex.Messaging.Util (liftEitherWith)
ntfBlockSize :: Int
ntfBlockSize = 512
authBatchCmdsNTFVersion :: Version
authBatchCmdsNTFVersion = 2
data NTFVersion
currentClientNTFVersion :: Version
currentClientNTFVersion = 1
instance VersionScope NTFVersion
currentServerNTFVersion :: Version
currentServerNTFVersion = 1
type VersionNTF = Version NTFVersion
supportedClientNTFVRange :: VersionRange
supportedClientNTFVRange = mkVersionRange 1 currentClientNTFVersion
type VersionRangeNTF = VersionRange NTFVersion
supportedServerNTFVRange :: VersionRange
supportedServerNTFVRange = mkVersionRange 1 currentServerNTFVersion
pattern VersionNTF :: Word16 -> VersionNTF
pattern VersionNTF v = Version v
initialNTFVersion :: VersionNTF
initialNTFVersion = VersionNTF 1
authBatchCmdsNTFVersion :: VersionNTF
authBatchCmdsNTFVersion = VersionNTF 2
currentClientNTFVersion :: VersionNTF
currentClientNTFVersion = VersionNTF 1
currentServerNTFVersion :: VersionNTF
currentServerNTFVersion = VersionNTF 1
supportedClientNTFVRange :: VersionRangeNTF
supportedClientNTFVRange = mkVersionRange initialNTFVersion currentClientNTFVersion
supportedServerNTFVRange :: VersionRangeNTF
supportedServerNTFVRange = mkVersionRange initialNTFVersion currentServerNTFVersion
type THandleNTF c = THandle NTFVersion c
data NtfServerHandshake = NtfServerHandshake
{ ntfVersionRange :: VersionRange,
{ ntfVersionRange :: VersionRangeNTF,
sessionId :: SessionId,
-- pub key to agree shared secrets for command authorization and entity ID encryption.
authPubKey :: Maybe (X.SignedExact X.PubKey)
@@ -46,7 +65,7 @@ data NtfServerHandshake = NtfServerHandshake
data NtfClientHandshake = NtfClientHandshake
{ -- | agreed SMP notifications server protocol version
ntfVersion :: Version,
ntfVersion :: VersionNTF,
-- | server identity - CA certificate fingerprint
keyHash :: C.KeyHash,
-- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
@@ -66,12 +85,12 @@ instance Encoding NtfServerHandshake where
authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP
pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey}
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
encodeAuthEncryptCmds :: Encoding a => VersionNTF -> Maybe a -> ByteString
encodeAuthEncryptCmds v k
| v >= authBatchCmdsNTFVersion = maybe "" smpEncode k
| otherwise = ""
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
authEncryptCmdsP :: VersionNTF -> Parser a -> Parser (Maybe a)
authEncryptCmdsP v p = if v >= authBatchCmdsNTFVersion then Just <$> p else pure Nothing
instance Encoding NtfClientHandshake where
@@ -83,16 +102,16 @@ instance Encoding NtfClientHandshake where
authPubKey <- ntfAuthPubKeyP ntfVersion
pure NtfClientHandshake {ntfVersion, keyHash, authPubKey}
ntfAuthPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519)
ntfAuthPubKeyP :: VersionNTF -> Parser (Maybe C.PublicKeyX25519)
ntfAuthPubKeyP v = if v >= authBatchCmdsNTFVersion then Just <$> smpP else pure Nothing
encodeNtfAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString
encodeNtfAuthPubKey :: VersionNTF -> Maybe C.PublicKeyX25519 -> ByteString
encodeNtfAuthPubKey v k
| v >= authBatchCmdsNTFVersion = maybe "" smpEncode k
| otherwise = ""
-- | Notifcations server transport handshake.
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c)
ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
let sk = C.signX509 serverSignKey $ C.publicToX509 k
@@ -106,7 +125,7 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
| otherwise -> throwError $ TEHandshake VERSION
-- | Notifcations server client transport handshake.
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c)
ntfClientHandshake c (k, pk) keyHash ntfVRange = do
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th
@@ -122,15 +141,15 @@ ntfClientHandshake c (k, pk) keyHash ntfVRange = do
pure $ ntfThHandle th v pk sk_
Nothing -> throwError $ TEHandshake VERSION
ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
ntfThHandle :: forall c. THandleNTF c -> VersionNTF -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleNTF c
ntfThHandle th@THandle {params} v privKey k_ =
-- TODO drop SMP v6: make thAuth non-optional
let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_
v3 = v >= authBatchCmdsNTFVersion
params' = params {thVersion = v, thAuth, implySessId = v3, batch = v3}
in (th :: THandle c) {params = params'}
in (th :: THandleNTF c) {params = params'}
ntfTHandle :: Transport c => c -> THandle c
ntfTHandle :: Transport c => c -> THandleNTF c
ntfTHandle c = THandle {connection = c, params}
where
params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = False}
params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = VersionNTF 0, thAuth = Nothing, implySessId = False, batch = False}
+55 -34
View File
@@ -46,6 +46,10 @@ module Simplex.Messaging.Protocol
e2eEncMessageLength,
-- * SMP protocol types
SMPClientVersion,
VersionSMPC,
VersionRangeSMPC,
pattern VersionSMPC,
ProtocolEncoding (..),
Command (..),
SubscriptionMode (..),
@@ -117,6 +121,7 @@ module Simplex.Messaging.Protocol
SMPMsgMeta (..),
NMsgMeta (..),
MsgFlags (..),
initialSMPClientVersion,
userProtocol,
rcvMessageMeta,
noMsgFlags,
@@ -179,6 +184,7 @@ import Data.Maybe (isNothing)
import Data.String
import Data.Time.Clock.System (SystemTime (..))
import Data.Type.Equality
import Data.Word (Word16)
import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
import Network.Socket (ServiceName)
import qualified Simplex.Messaging.Crypto as C
@@ -190,19 +196,34 @@ import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..))
import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>))
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
-- SMP client protocol version history:
-- 1 - binary protocol encoding (1/1/2022)
-- 2 - multiple server hostnames and versioned queue addresses (8/12/2022)
srvHostnamesSMPClientVersion :: Version
srvHostnamesSMPClientVersion = 2
data SMPClientVersion
currentSMPClientVersion :: Version
currentSMPClientVersion = 2
instance VersionScope SMPClientVersion
supportedSMPClientVRange :: VersionRange
supportedSMPClientVRange = mkVersionRange 1 currentSMPClientVersion
type VersionSMPC = Version SMPClientVersion
type VersionRangeSMPC = VersionRange SMPClientVersion
pattern VersionSMPC :: Word16 -> VersionSMPC
pattern VersionSMPC v = Version v
initialSMPClientVersion :: VersionSMPC
initialSMPClientVersion = VersionSMPC 1
srvHostnamesSMPClientVersion :: VersionSMPC
srvHostnamesSMPClientVersion = VersionSMPC 2
currentSMPClientVersion :: VersionSMPC
currentSMPClientVersion = VersionSMPC 2
supportedSMPClientVRange :: VersionRangeSMPC
supportedSMPClientVRange = mkVersionRange initialSMPClientVersion currentSMPClientVersion
maxMessageLength :: Int
maxMessageLength = 16088
@@ -642,7 +663,7 @@ data ClientMsgEnvelope = ClientMsgEnvelope
deriving (Show)
data PubHeader = PubHeader
{ phVersion :: Version,
{ phVersion :: VersionSMPC,
phE2ePubDhKey :: Maybe C.PublicKeyX25519
}
deriving (Show)
@@ -1048,7 +1069,7 @@ data CommandError
deriving (Eq, Read, Show)
-- | SMP transmission parser.
transmissionP :: THandleParams -> Parser RawTransmission
transmissionP :: THandleParams v -> Parser RawTransmission
transmissionP THandleParams {sessionId, implySessId} = do
authenticator <- smpP
authorized <- A.takeByteString
@@ -1062,16 +1083,16 @@ transmissionP THandleParams {sessionId, implySessId} = do
command <- A.takeByteString
pure RawTransmission {authenticator, authorized = authorized', sessId, corrId, entityId, command}
class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where
class (ProtocolEncoding v err msg, ProtocolEncoding v err (ProtoCommand msg), Show err, Show msg) => Protocol v err msg | msg -> v, msg -> err where
type ProtoCommand msg = cmd | cmd -> msg
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> ExceptT TransportError IO (THandle v c)
protocolPing :: ProtoCommand msg
protocolError :: msg -> Maybe err
type ProtoServer msg = ProtocolServer (ProtoType msg)
instance Protocol ErrorType BrokerMsg where
instance Protocol SMPVersion ErrorType BrokerMsg where
type ProtoCommand BrokerMsg = Cmd
type ProtoType BrokerMsg = 'PSMP
protocolClientHandshake = smpClientHandshake
@@ -1080,14 +1101,14 @@ instance Protocol ErrorType BrokerMsg where
ERR e -> Just e
_ -> Nothing
class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where
class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -> v where
type Tag msg
encodeProtocol :: Version -> msg -> ByteString
protocolP :: Version -> Tag msg -> Parser msg
encodeProtocol :: Version v -> msg -> ByteString
protocolP :: Version v -> Tag msg -> Parser msg
fromProtocolError :: ProtocolErrorType -> err
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
instance PartyI p => ProtocolEncoding ErrorType (Command p) where
instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
type Tag (Command p) = CommandTag p
encodeProtocol v = \case
NEW rKey dhKey auth_ subMode
@@ -1114,7 +1135,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, queueId, _) cmd = case cmd of
@@ -1136,7 +1157,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where
| isNothing auth || B.null queueId -> Left $ CMD NO_AUTH
| otherwise -> Right cmd
instance ProtocolEncoding ErrorType Cmd where
instance ProtocolEncoding SMPVersion ErrorType Cmd where
type Tag Cmd = CmdTag
encodeProtocol v (Cmd _ c) = encodeProtocol v c
@@ -1164,12 +1185,12 @@ instance ProtocolEncoding ErrorType Cmd where
PING_ -> pure PING
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
instance ProtocolEncoding ErrorType BrokerMsg where
instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
type Tag BrokerMsg = BrokerMsgTag
encodeProtocol _v = \case
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
@@ -1221,12 +1242,12 @@ instance ProtocolEncoding ErrorType BrokerMsg where
| otherwise -> Right cmd
-- | Parse SMP protocol commands and broker messages
parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg
parseProtocol :: forall v err msg. ProtocolEncoding v err msg => Version v -> ByteString -> Either err msg
parseProtocol v s =
let (tag, params) = B.break (== ' ') s
in case decodeTag tag of
Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params
Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown
Just cmd -> parse (protocolP v cmd) (fromProtocolError @v @err @msg $ PECmdSyntax) params
Nothing -> Left $ fromProtocolError @v @err @msg $ PECmdUnknown
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
checkParty c = case testEquality (sParty @p) (sParty @p') of
@@ -1281,7 +1302,7 @@ instance Encoding CommandError where
_ -> fail "bad command error type"
-- | Send signed SMP transmission to TCP transport.
tPut :: Transport c => THandle c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()]
tPut :: Transport c => THandle v c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()]
tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (batch params) (blockSize params)
where
tPutBatch :: TransportBatch () -> IO [Either TransportError ()]
@@ -1290,7 +1311,7 @@ tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (ba
TBTransmissions s n _ -> replicate n <$> tPutLog th s
TBTransmission s _ -> (: []) <$> tPutLog th s
tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
tPutLog :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ())
tPutLog th s = do
r <- tPutBlock th s
case r of
@@ -1352,7 +1373,7 @@ tEncodeBatch1 t = lenEncode 1 `B.cons` tEncodeForBatch t
-- tForAuth is lazy to avoid computing it when there is no key to sign
data TransmissionForAuth = TransmissionForAuth {tForAuth :: ~ByteString, tToSend :: ByteString}
encodeTransmissionForAuth :: ProtocolEncoding e c => THandleParams -> Transmission c -> TransmissionForAuth
encodeTransmissionForAuth :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> TransmissionForAuth
encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t =
TransmissionForAuth {tForAuth, tToSend = if implySessId then t' else tForAuth}
where
@@ -1360,24 +1381,24 @@ encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId}
t' = encodeTransmission_ v t
{-# INLINE encodeTransmissionForAuth #-}
encodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> ByteString
encodeTransmission :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> ByteString
encodeTransmission THandleParams {thVersion = v, sessionId, implySessId} t =
if implySessId then t' else smpEncode sessionId <> t'
where
t' = encodeTransmission_ v t
{-# INLINE encodeTransmission #-}
encodeTransmission_ :: ProtocolEncoding e c => Version -> Transmission c -> ByteString
encodeTransmission_ :: ProtocolEncoding v e c => Version v -> Transmission c -> ByteString
encodeTransmission_ v (CorrId corrId, queueId, command) =
smpEncode (corrId, queueId) <> encodeProtocol v command
{-# INLINE encodeTransmission_ #-}
-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding).
tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission))
tGetParse :: Transport c => THandle v c -> IO (NonEmpty (Either TransportError RawTransmission))
tGetParse th@THandle {params} = eitherList (tParse params) <$> tGetBlock th
{-# INLINE tGetParse #-}
tParse :: THandleParams -> ByteString -> NonEmpty (Either TransportError RawTransmission)
tParse :: THandleParams v -> ByteString -> NonEmpty (Either TransportError RawTransmission)
tParse thParams@THandleParams {batch} s
| batch = eitherList (L.map (\(Large t) -> tParse1 t)) ts
| otherwise = [tParse1 s]
@@ -1389,24 +1410,24 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b
eitherList = either (\e -> [Left e])
-- | Receive client and server transmissions (determined by `cmd` type).
tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd))
tGet :: forall v err cmd c. (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (NonEmpty (SignedTransmission err cmd))
tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th
tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => THandleParams -> Either TransportError RawTransmission -> SignedTransmission err cmd
tDecodeParseValidate :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v -> Either TransportError RawTransmission -> SignedTransmission err cmd
tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case
Right RawTransmission {authenticator, authorized, sessId, corrId, entityId, command}
| implySessId || sessId == sessionId ->
let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator
in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PESession))
Left _ -> tError ""
where
tError :: ByteString -> SignedTransmission err cmd
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock))
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PEBlock))
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
tParseValidate signed t@(sig, corrId, entityId, command) =
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t
in (sig, signed, (CorrId corrId, entityId, cmd))
$(J.deriveJSON defaultJSON ''MsgFlags)
+4 -4
View File
@@ -380,7 +380,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
CPQuit -> pure ()
CPSkip -> pure ()
runClientTransport :: Transport c => THandle c -> M ()
runClientTransport :: Transport c => THandleSMP c -> M ()
runClientTransport th@THandle {params = THandleParams {thVersion, sessionId}} = do
q <- asks $ tbqSize . config
ts <- liftIO getSystemTime
@@ -428,7 +428,7 @@ cancelSub sub =
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
_ -> return ()
receive :: Transport c => THandle c -> Client -> M ()
receive :: Transport c => THandleSMP c -> Client -> M ()
receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
forever $ do
@@ -449,7 +449,7 @@ receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActi
VRFailed -> Left (corrId, queueId, ERR AUTH)
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
send :: Transport c => THandle c -> Client -> IO ()
send :: Transport c => THandleSMP c -> Client -> IO ()
send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send"
forever $ do
@@ -464,7 +464,7 @@ send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do
NMSG {} -> 0
_ -> 1
disconnectTransport :: Transport c => THandle c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
disconnectTransport :: Transport c => THandle v c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport"
loop
+4 -5
View File
@@ -33,9 +33,8 @@ import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Transport (ATransport, VersionSMP, VersionRangeSMP)
import Simplex.Messaging.Transport.Server (SocketState, TransportServerConfig, loadFingerprint, loadTLSServerParams, newSocketState)
import Simplex.Messaging.Version
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
import UnliftIO.STM
@@ -73,7 +72,7 @@ data ServerConfig = ServerConfig
privateKeyFile :: FilePath,
certificateFile :: FilePath,
-- | SMP client-server protocol version range
smpServerVRange :: VersionRange,
smpServerVRange :: VersionRangeSMP,
-- | TCP transport config
transportConfig :: TransportServerConfig,
-- | run listener on control port
@@ -128,7 +127,7 @@ data Client = Client
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
endThreads :: TVar (IntMap (Weak ThreadId)),
endThreadSeq :: TVar Int,
thVersion :: Version,
thVersion :: VersionSMP,
sessionId :: ByteString,
connected :: TVar Bool,
createdAt :: SystemTime,
@@ -152,7 +151,7 @@ newServer = do
savingLock <- createLock
return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock}
newClient :: TVar Int -> Natural -> Version -> ByteString -> SystemTime -> STM Client
newClient :: TVar Int -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client
newClient nextClientId qSize thVersion sessionId createdAt = do
clientId <- stateTVar nextClientId $ \next -> (next, next + 1)
subscriptions <- TM.empty
+53 -32
View File
@@ -9,6 +9,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
@@ -27,10 +28,15 @@
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
module Simplex.Messaging.Transport
( -- * SMP transport parameters
SMPVersion,
VersionSMP,
VersionRangeSMP,
THandleSMP,
supportedClientSMPRelayVRange,
supportedServerSMPRelayVRange,
currentClientSMPRelayVersion,
currentServerSMPRelayVersion,
batchCmdsSMPVersion,
basicAuthSMPVersion,
subModeSMPVersion,
authCmdsSMPVersion,
@@ -85,6 +91,7 @@ import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Default (def)
import Data.Functor (($>))
import Data.Version (showVersion)
import Data.Word (Word16)
import qualified Data.X509 as X
import qualified Data.X509.Validation as XV
import GHC.IO.Handle.Internals (ioe_EOF)
@@ -98,6 +105,7 @@ import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON)
import Simplex.Messaging.Transport.Buffer
import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith)
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -116,30 +124,41 @@ smpBlockSize = 16384
-- 6 - allow creating queues without subscribing (9/10/2023)
-- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (2/3/2024)
batchCmdsSMPVersion :: Version
batchCmdsSMPVersion = 4
data SMPVersion
basicAuthSMPVersion :: Version
basicAuthSMPVersion = 5
instance VersionScope SMPVersion
subModeSMPVersion :: Version
subModeSMPVersion = 6
type VersionSMP = Version SMPVersion
authCmdsSMPVersion :: Version
authCmdsSMPVersion = 7
type VersionRangeSMP = VersionRange SMPVersion
currentClientSMPRelayVersion :: Version
currentClientSMPRelayVersion = 6
pattern VersionSMP :: Word16 -> VersionSMP
pattern VersionSMP v = Version v
currentServerSMPRelayVersion :: Version
currentServerSMPRelayVersion = 6
batchCmdsSMPVersion :: VersionSMP
batchCmdsSMPVersion = VersionSMP 4
basicAuthSMPVersion :: VersionSMP
basicAuthSMPVersion = VersionSMP 5
subModeSMPVersion :: VersionSMP
subModeSMPVersion = VersionSMP 6
authCmdsSMPVersion :: VersionSMP
authCmdsSMPVersion = VersionSMP 7
currentClientSMPRelayVersion :: VersionSMP
currentClientSMPRelayVersion = VersionSMP 6
currentServerSMPRelayVersion :: VersionSMP
currentServerSMPRelayVersion = VersionSMP 6
-- minimal supported protocol version is 4
-- TODO remove code that supports sending commands without batching
supportedClientSMPRelayVRange :: VersionRange
supportedClientSMPRelayVRange :: VersionRangeSMP
supportedClientSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentClientSMPRelayVersion
supportedServerSMPRelayVRange :: VersionRange
supportedServerSMPRelayVRange :: VersionRangeSMP
supportedServerSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentServerSMPRelayVersion
simplexMQVersion :: String
@@ -287,16 +306,18 @@ instance Transport TLS where
-- * SMP transport
-- | The handle for SMP encrypted transport connection over Transport.
data THandle c = THandle
data THandle v c = THandle
{ connection :: c,
params :: THandleParams
params :: THandleParams v
}
data THandleParams = THandleParams
type THandleSMP c = THandle SMPVersion c
data THandleParams v = THandleParams
{ sessionId :: SessionId,
blockSize :: Int,
-- | agreed server protocol version
thVersion :: Version,
thVersion :: Version v,
-- | peer public key for command authorization and shared secrets for entity ID encryption
thAuth :: Maybe THandleAuth,
-- | do NOT send session ID in transmission, but include it into signed message
@@ -316,7 +337,7 @@ data THandleAuth = THandleAuth
type SessionId = ByteString
data ServerHandshake = ServerHandshake
{ smpVersionRange :: VersionRange,
{ smpVersionRange :: VersionRangeSMP,
sessionId :: SessionId,
-- pub key to agree shared secrets for command authorization and entity ID encryption.
authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey)
@@ -324,7 +345,7 @@ data ServerHandshake = ServerHandshake
data ClientHandshake = ClientHandshake
{ -- | agreed SMP server protocol version
smpVersion :: Version,
smpVersion :: VersionSMP,
-- | server identity - CA certificate fingerprint
keyHash :: C.KeyHash,
-- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
@@ -358,12 +379,12 @@ instance Encoding ServerHandshake where
C.SignedObject key <- smpP
pure (cert, key)
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
encodeAuthEncryptCmds :: Encoding a => VersionSMP -> Maybe a -> ByteString
encodeAuthEncryptCmds v k
| v >= authCmdsSMPVersion = maybe "" smpEncode k
| otherwise = ""
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
authEncryptCmdsP :: VersionSMP -> Parser a -> Parser (Maybe a)
authEncryptCmdsP v p = if v >= authCmdsSMPVersion then Just <$> p else pure Nothing
-- | Error of SMP encrypted transport over TCP.
@@ -412,13 +433,13 @@ serializeTransportError = \case
TEHandshake e -> "HANDSHAKE " <> bshow e
-- | Pad and send block to SMP transport.
tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
tPutBlock :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ())
tPutBlock THandle {connection = c, params = THandleParams {blockSize}} block =
bimapM (const $ pure TELargeMsg) (cPut c) $
C.pad block blockSize
-- | Receive block from SMP transport.
tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString)
tGetBlock :: Transport c => THandle v c -> IO (Either TransportError ByteString)
tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
msg <- cGet c blockSize
if B.length msg == blockSize
@@ -428,7 +449,7 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
-- | Server SMP transport handshake.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c)
smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
sk = C.signX509 serverSignKey $ C.publicToX509 k
@@ -445,7 +466,7 @@ smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
-- | Client SMP transport handshake.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c)
smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th
@@ -465,24 +486,24 @@ smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
pure $ smpThHandle th v pk sk_
Nothing -> throwE $ TEHandshake VERSION
smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
smpThHandle :: forall c. THandleSMP c -> VersionSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c
smpThHandle th@THandle {params} v privKey k_ =
-- TODO drop SMP v6: make thAuth non-optional
let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_
params' = params {thVersion = v, thAuth, implySessId = v >= authCmdsSMPVersion}
in (th :: THandle c) {params = params'}
in (th :: THandleSMP c) {params = params'}
sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO ()
sendHandshake :: (Transport c, Encoding smp) => THandle v c -> smp -> ExceptT TransportError IO ()
sendHandshake th = ExceptT . tPutBlock th . smpEncode
-- ignores tail bytes to allow future extensions
getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp
getHandshake :: (Transport c, Encoding smp) => THandle v c -> ExceptT TransportError IO smp
getHandshake th = ExceptT $ (first (\_ -> TEHandshake PARSE) . A.parseOnly smpP =<<) <$> tGetBlock th
smpTHandle :: Transport c => c -> THandle c
smpTHandle :: Transport c => c -> THandleSMP c
smpTHandle c = THandle {connection = c, params}
where
params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = True}
params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = VersionSMP 0, thAuth = Nothing, implySessId = False, batch = True}
$(J.deriveJSON (sumTypeJSON id) ''HandshakeError)
+31 -40
View File
@@ -1,13 +1,15 @@
{-# LANGUAGE ConstrainedClassMethods #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Simplex.Messaging.Version
( Version,
VersionRange (minVersion, maxVersion),
VersionScope,
pattern VersionRange,
VersionI (..),
VersionRangeI (..),
@@ -24,47 +26,45 @@ module Simplex.Messaging.Version
where
import Control.Applicative (optional)
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Word (Word16)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Util ((<$?>))
import Simplex.Messaging.Version.Internal (Version (..))
pattern VersionRange :: Word16 -> Word16 -> VersionRange
pattern VersionRange :: Version v -> Version v -> VersionRange v
pattern VersionRange v1 v2 <- VRange v1 v2
{-# COMPLETE VersionRange #-}
type Version = Word16
data VersionRange = VRange
{ minVersion :: Version,
maxVersion :: Version
data VersionRange v = VRange
{ minVersion :: Version v,
maxVersion :: Version v
}
deriving (Eq, Show)
class VersionScope v
-- | construct valid version range, to be used in constants
mkVersionRange :: Version -> Version -> VersionRange
mkVersionRange :: Version v -> Version v -> VersionRange v
mkVersionRange v1 v2
| v1 <= v2 = VRange v1 v2
| otherwise = error "invalid version range"
safeVersionRange :: Version -> Version -> Maybe VersionRange
safeVersionRange :: Version v -> Version v -> Maybe (VersionRange v)
safeVersionRange v1 v2
| v1 <= v2 = Just $ VRange v1 v2
| otherwise = Nothing
versionToRange :: Version -> VersionRange
versionToRange :: Version v -> VersionRange v
versionToRange v = VRange v v
instance Encoding VersionRange where
instance VersionScope v => Encoding (VersionRange v) where
smpEncode (VRange v1 v2) = smpEncode (v1, v2)
smpP =
maybe (fail "invalid version range") pure
=<< safeVersionRange <$> smpP <*> smpP
instance StrEncoding VersionRange where
instance VersionScope v => StrEncoding (VersionRange v) where
strEncode (VRange v1 v2)
| v1 == v2 = strEncode v1
| otherwise = strEncode v1 <> "-" <> strEncode v2
@@ -73,32 +73,23 @@ instance StrEncoding VersionRange where
v2 <- maybe (pure v1) (const strP) =<< optional (A.char '-')
maybe (fail "invalid version range") pure $ safeVersionRange v1 v2
instance ToJSON VersionRange where
toJSON (VRange v1 v2) = toJSON (v1, v2)
toEncoding (VRange v1 v2) = toEncoding (v1, v2)
class VersionScope v => VersionI v a | a -> v where
type VersionRangeT v a
version :: a -> Version v
toVersionRangeT :: a -> VersionRange v -> VersionRangeT v a
instance FromJSON VersionRange where
parseJSON v =
(\(v1, v2) -> maybe (Left "bad VersionRange") Right $ safeVersionRange v1 v2)
<$?> parseJSON v
class VersionScope v => VersionRangeI v a | a -> v where
type VersionT v a
versionRange :: a -> VersionRange v
toVersionT :: a -> Version v -> VersionT v a
class VersionI a where
type VersionRangeT a
version :: a -> Version
toVersionRangeT :: a -> VersionRange -> VersionRangeT a
class VersionRangeI a where
type VersionT a
versionRange :: a -> VersionRange
toVersionT :: a -> Version -> VersionT a
instance VersionI Version where
type VersionRangeT Version = VersionRange
instance VersionScope v => VersionI v (Version v) where
type VersionRangeT v (Version v) = VersionRange v
version = id
toVersionRangeT _ vr = vr
instance VersionRangeI VersionRange where
type VersionT VersionRange = Version
instance VersionScope v => VersionRangeI v (VersionRange v) where
type VersionT v (VersionRange v) = Version v
versionRange = id
toVersionT _ v = v
@@ -109,18 +100,18 @@ pattern Compatible a <- Compatible_ a
{-# COMPLETE Compatible #-}
isCompatible :: VersionI a => a -> VersionRange -> Bool
isCompatible :: VersionI v a => a -> VersionRange v -> Bool
isCompatible x (VRange v1 v2) = let v = version x in v1 <= v && v <= v2
isCompatibleRange :: VersionRangeI a => a -> VersionRange -> Bool
isCompatibleRange :: VersionRangeI v a => a -> VersionRange v -> Bool
isCompatibleRange x (VRange min2 max2) = min1 <= max2 && min2 <= max1
where
VRange min1 max1 = versionRange x
proveCompatible :: VersionI a => a -> VersionRange -> Maybe (Compatible a)
proveCompatible :: VersionI v a => a -> VersionRange v -> Maybe (Compatible a)
proveCompatible x vr = x `mkCompatibleIf` (x `isCompatible` vr)
compatibleVersion :: VersionRangeI a => a -> VersionRange -> Maybe (Compatible (VersionT a))
compatibleVersion :: VersionRangeI v a => a -> VersionRange v -> Maybe (Compatible (VersionT v a))
compatibleVersion x vr =
toVersionT x (min max1 max2) `mkCompatibleIf` isCompatibleRange x vr
where
+25
View File
@@ -0,0 +1,25 @@
module Simplex.Messaging.Version.Internal where
import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Word (Word16)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
-- Do not use constructor of this type directry
newtype Version v = Version Word16
deriving (Eq, Ord, Show)
instance Encoding (Version v) where
smpEncode (Version v) = smpEncode v
smpP = Version <$> smpP
instance StrEncoding (Version v) where
strEncode (Version v) = strEncode v
strP = Version <$> strP
instance ToJSON (Version v) where
toEncoding (Version v) = toEncoding v
toJSON (Version v) = toJSON v
instance FromJSON (Version v) where
parseJSON v = Version <$> parseJSON v