diff --git a/.gitignore b/.gitignore index 965b1e528..49d685e39 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.db *.db.bak *.session.sql +tests/tmp diff --git a/migrations/20220101_initial.sql b/migrations/20220101_initial.sql index 95c7bacf4..f1ea5f67e 100644 --- a/migrations/20220101_initial.sql +++ b/migrations/20220101_initial.sql @@ -14,8 +14,7 @@ CREATE TABLE connections ( last_external_snd_msg_id INTEGER NOT NULL DEFAULT 0, last_rcv_msg_hash BLOB NOT NULL DEFAULT x'', last_snd_msg_hash BLOB NOT NULL DEFAULT x'', - smp_agent_version INTEGER NOT NULL DEFAULT 1, - e2e_version INTEGER NOT NULL DEFAULT 1 + smp_agent_version INTEGER NOT NULL DEFAULT 1 ) WITHOUT ROWID; CREATE TABLE rcv_queues ( @@ -98,8 +97,9 @@ CREATE TABLE snd_messages ( CREATE TABLE conn_confirmations ( confirmation_id BLOB NOT NULL PRIMARY KEY, conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE, - e2e_snd_pub_key BLOB NOT NULL, - sender_key BLOB NOT NULL, + e2e_snd_pub_key BLOB NOT NULL, -- TODO per-queue key. Split? + sender_key BLOB NOT NULL, -- TODO per-queue key. Split? + ratchet_state BLOB NOT NULL, sender_conn_info BLOB NOT NULL, accepted INTEGER NOT NULL, own_conn_info BLOB, @@ -115,3 +115,23 @@ CREATE TABLE conn_invitations ( own_conn_info BLOB, created_at TEXT NOT NULL DEFAULT (datetime('now')) ) WITHOUT ROWID; + +CREATE TABLE ratchets ( + conn_id BLOB NOT NULL PRIMARY KEY REFERENCES connections + ON DELETE CASCADE, + -- x3dh keys are not saved on the sending side (the side accepting the connection) + x3dh_priv_key_1 BLOB, + x3dh_priv_key_2 BLOB, + -- ratchet is initially empty on the receiving side (the side offering the connection) + ratchet_state BLOB, + e2e_version INTEGER NOT NULL DEFAULT 1 +) WITHOUT ROWID; + +CREATE TABLE skipped_messages ( + skipped_message_id INTEGER PRIMARY KEY, + conn_id BLOB NOT NULL REFERENCES ratchets + ON DELETE CASCADE, + header_key BLOB NOT NULL, + msg_n INTEGER NOT NULL, + msg_key BLOB NOT NULL +); diff --git a/package.yaml b/package.yaml index 3e62cc01b..34564b939 100644 --- a/package.yaml +++ b/package.yaml @@ -23,6 +23,7 @@ extra-source-files: - migrations/*.* dependencies: + - aeson == 1.5.* - ansi-terminal >= 0.10 && < 0.12 - asn1-encoding == 0.9.* - asn1-types == 0.3.* diff --git a/simplexmq.cabal b/simplexmq.cabal index 34e980f70..6a39cc028 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -34,7 +34,9 @@ library Simplex.Messaging.Agent Simplex.Messaging.Agent.Client Simplex.Messaging.Agent.Env.SQLite + Simplex.Messaging.Agent.ExceptT Simplex.Messaging.Agent.Protocol + Simplex.Messaging.Agent.QueryString Simplex.Messaging.Agent.RetryInterval Simplex.Messaging.Agent.Store Simplex.Messaging.Agent.Store.SQLite @@ -64,6 +66,7 @@ library ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns build-depends: QuickCheck ==2.14.* + , aeson ==1.5.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -115,6 +118,7 @@ executable smp-agent ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded build-depends: QuickCheck ==2.14.* + , aeson ==1.5.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -167,6 +171,7 @@ executable smp-server ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded build-depends: QuickCheck ==2.14.* + , aeson ==1.5.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -235,6 +240,7 @@ test-suite smp-server-test build-depends: HUnit ==1.6.* , QuickCheck ==2.14.* + , aeson ==1.5.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 05674b670..954317241 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -60,7 +60,7 @@ import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader import Crypto.Random (MonadRandom) -import Data.Bifunctor (second) +import Data.Bifunctor (first, second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.)) @@ -82,12 +82,13 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore) import Simplex.Messaging.Client (SMPServerTransmission) import qualified Simplex.Messaging.Crypto as C +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parse) import Simplex.Messaging.Protocol (MsgBody) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), loadTLSServerParams, runTransportServer, simplexMQVersion) -import Simplex.Messaging.Util (bshow, tryError, unlessM) +import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM) import Simplex.Messaging.Version import System.Random (randomR) import UnliftIO.Async (async, race_) @@ -253,8 +254,8 @@ withStore action = do -- | execute any SMP agent command processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent) processCommand c (connId, cmd) = case cmd of - NEW (ACM cMode) -> second (INV . ACRU cMode) <$> newConn c connId cMode - JOIN (ACRU _ cReq) connInfo -> (,OK) <$> joinConn c connId cReq connInfo + NEW (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId cMode + JOIN (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId cReq connInfo LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK) ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId invId ownCInfo RJCT invId -> rejectContact' c connId invId $> (connId, OK) @@ -273,35 +274,52 @@ newConn c connId cMode = do connId' <- withStore $ \st -> createRcvConn st g cData rq cMode addSubscription c rq connId' let crData = ConnReqUriData simplexChat smpAgentVRange [qUri] - pure . (connId',) $ case cMode of - SCMInvitation -> CRInvitationUri crData connEncStubUri - SCMContact -> CRContactUri crData + case cMode of + SCMContact -> pure (connId', CRContactUri crData) + SCMInvitation -> do + (pk1, pk2, e2eRcvParams) <- liftIO $ CR.generateE2EParams CR.e2eEncryptVersion + withStore $ \st -> createRatchetX3dhKeys st connId' pk1 pk2 + pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams CR.e2eEncryptVRange) joinConn :: AgentMonad m => AgentClient -> ConnId -> ConnectionRequestUri c -> ConnInfo -> m ConnId -joinConn c connId (CRInvitationUri (ConnReqUriData _ _ (qUri :| _)) _e2eEnc) cInfo = do - -- TODO check all versions in connection request are compatible with supported - -- (add agent and e2e) - case qUri `compatibleVersion` SMP.smpClientVersion of - Nothing -> throwError $ AGENT A_VERSION - Just qInfo -> do +joinConn c connId (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo = + case ( qUri `compatibleVersion` SMP.smpClientVRange, + e2eRcvParamsUri `compatibleVersion` CR.e2eEncryptVRange, + agentVRange `compatibleVersion` smpAgentVRange + ) of + (Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams _ _ rcDHRr)), Just _) -> do + -- TODO in agent v2 - use found compatible version rather than current + (pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams + (_, rcDHRs) <- liftIO C.generateKeyPair' + let rc = CR.initSndRatchet rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams (sq, smpConf) <- newSndQueue qInfo cInfo g <- asks idsDrg let cData = ConnData {connId} - connId' <- withStore $ \st -> createSndConn st g cData sq - confirmQueue c sq smpConf + connId' <- withStore $ \st -> do + connId' <- createSndConn st g cData sq + createRatchet st connId' rc + pure connId' + confirmQueue c connId' sq smpConf $ Just e2eSndParams void $ enqueueMessage c connId' sq HELLO pure connId' -joinConn c connId (CRContactUri (ConnReqUriData _ _ (qUri :| _))) cInfo = do - (connId', cReq) <- newConn c connId SCMInvitation - sendInvitation c qUri cReq cInfo - pure connId' + _ -> throwError $ AGENT A_VERSION +joinConn c connId (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo = + case ( qUri `compatibleVersion` SMP.smpClientVRange, + agentVRange `compatibleVersion` smpAgentVRange + ) of + (Just qInfo, Just _) -> do + -- TODO in agent v2 - use found compatible version rather than current + (connId', cReq) <- newConn c connId SCMInvitation + sendInvitation c qInfo cReq cInfo + pure connId' + _ -> throwError $ AGENT A_VERSION createReplyQueue :: AgentMonad m => AgentClient -> ConnId -> SndQueue -> m () createReplyQueue c connId sq = do srv <- getSMPServer (rq, qUri) <- newRcvQueue c srv -- TODO reply queue version should be the same as send queue, ignoring it in v1 - let qInfo = toVersionT qUri (maxVersion SMP.smpClientVersion) + let qInfo = toVersionT qUri SMP.smpClientVersion addSubscription c rq connId withStore $ \st -> upgradeSndConnToDuplex st connId rq void . enqueueMessage c connId sq $ REPLY [qInfo] @@ -311,7 +329,8 @@ allowConnection' :: AgentMonad m => AgentClient -> ConnId -> ConfirmationId -> C allowConnection' c connId confId ownConnInfo = do withStore (`getConn` connId) >>= \case SomeConn _ (RcvConnection _ rq) -> do - AcceptedConfirmation {senderConf} <- withStore $ \st -> acceptConfirmation st confId ownConnInfo + AcceptedConfirmation {senderConf, ratchetState} <- withStore $ \st -> acceptConfirmation st confId ownConnInfo + withStore $ \st -> createRatchet st connId ratchetState processConfirmation c rq senderConf _ -> throwError $ CMD PROHIBITED @@ -386,10 +405,10 @@ enqueueMessage c connId sq aMessage = do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- withStore (`updateSndIds` connId) let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash - agentMessage = smpEncode $ AgentMessage' privHeader aMessage + agentMessage = smpEncode $ AgentMessage privHeader aMessage internalHash = C.sha256Hash agentMessage - encAgentMessage <- agentRatchetEncrypt agentMessage + encAgentMessage <- agentRatchetEncrypt connId agentMessage e2eEncUserMsgLength let msgBody = smpEncode $ AgentMsgEnvelope {agentVersion = smpAgentVersion, encAgentMessage} msgType = aMessageType aMessage msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, internalHash, prevMsgHash} @@ -536,30 +555,24 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do SMP.MSG srvMsgId srvTs msgBody' -> handleNotifyAck $ do -- TODO deduplicate with previously received msgBody <- agentCbDecrypt rcvDhSecret (C.cbNonce srvMsgId) msgBody' - clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader v e2ePubKey_} <- + clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <- parseMessage msgBody + unless (phVer `isCompatible` SMP.smpClientVRange) . throwError $ AGENT A_VERSION case (e2eDhSecret, e2ePubKey_) of (Nothing, Just e2ePubKey) -> do let e2eDh = C.dh' e2ePubKey e2ePrivKey decryptClientMessage e2eDh clientMsg >>= \case - (SMP.PHConfirmation senderKey, AgentConfirmation {agentVersion = _v, e2eEncryption, encConnInfo}) -> do - agentMsgBody <- agentRatchetDecrypt encConnInfo - agentMessage <- parseMessage agentMsgBody - case agentMessage of - AgentConnInfo connInfo -> do - smpConfirmation SMPConfirmation {senderKey, e2ePubKey, connInfo} - ack - _ -> prohibited >> ack - (SMP.PHEmpty, AgentInvitation' {agentVersion = _v, connReq, connInfo}) -> + (SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo}) -> + smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo >> ack + (SMP.PHEmpty, AgentInvitation {connReq, connInfo}) -> smpInvitation connReq connInfo >> ack _ -> prohibited >> ack (Just e2eDh, Nothing) -> do decryptClientMessage e2eDh clientMsg >>= \case - (SMP.PHEmpty, AgentMsgEnvelope _v encAgentMsg) -> do - agentMsgBody <- agentRatchetDecrypt encAgentMsg - agentMessage <- parseMessage agentMsgBody - case agentMessage of - AgentMessage' APrivHeader {sndMsgId, prevMsgHash} aMessage -> do + (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do + agentMsgBody <- agentRatchetDecrypt connId encAgentMsg + parseMessage agentMsgBody >>= \case + AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage -> do (msgId, msgMeta) <- agentClientMsg prevMsgHash sndMsgId (srvMsgId, systemToUTCTime srvTs) agentMsgBody aMessage case aMessage of HELLO -> helloMsg >> ack >> withStore (\st -> deleteMsg st connId msgId) @@ -594,24 +607,39 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do clientMsg <- agentCbDecrypt e2eDh cmNonce cmEncBody SMP.ClientMessage privHeader clientBody <- parseMessage clientMsg agentEnvelope <- parseMessage clientBody - pure (privHeader, agentEnvelope) + if agentVersion agentEnvelope `isCompatible` smpAgentVRange + then pure (privHeader, agentEnvelope) + else throwError $ AGENT A_VERSION parseMessage :: Encoding a => ByteString -> m a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) - smpConfirmation :: SMPConfirmation -> m () - smpConfirmation senderConf@SMPConfirmation {connInfo} = do - logServer "<--" c srv rId "MSG " + smpConfirmation :: C.APublicVerifyKey -> C.PublicKeyX25519 -> Maybe (CR.E2ERatchetParams 'C.X448) -> ByteString -> m () + smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo = do + logServer "<--" c srv rId "MSG " case status of - New -> case cType of - SCRcv -> do - g <- asks idsDrg - let newConfirmation = NewConfirmation {connId, senderConf} - confId <- withStore $ \st -> createConfirmation st g newConfirmation - notify $ CONF confId connInfo - SCDuplex -> do - notify $ INFO connInfo - processConfirmation c rq senderConf + New -> case (cType, e2eEncryption) of + (SCRcv, Just e2eSndParams) -> do + (pk1, rcDHRs) <- withStore $ \st -> getRatchetX3dhKeys st connId + let rc = CR.initRcvRatchet rcDHRs $ CR.x3dhRcv pk1 rcDHRs e2eSndParams + (agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt rc M.empty encConnInfo + case (agentMsgBody_, skipped) of + (Right agentMsgBody, CR.SMDNoChange) -> + parseMessage agentMsgBody >>= \case + AgentConnInfo connInfo -> do + g <- asks idsDrg + let senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo} + newConfirmation = NewConfirmation {connId, senderConf, ratchetState = rc'} + confId <- withStore $ \st -> createConfirmation st g newConfirmation + notify $ CONF confId connInfo + _ -> prohibited + _ -> prohibited + (SCDuplex, Nothing) -> do + agentRatchetDecrypt connId encConnInfo >>= parseMessage >>= \case + AgentConnInfo connInfo -> do + notify $ INFO connInfo + processConfirmation c rq $ SMPConfirmation {senderKey, e2ePubKey, connInfo} + _ -> prohibited _ -> prohibited _ -> prohibited @@ -632,12 +660,12 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do case cType of SCRcv -> do AcceptedConfirmation {ownConnInfo} <- withStore (`getAcceptedConfirmation` connId) - case qInfo `proveCompatible` SMP.smpClientVersion of - Nothing -> notify (ERR $ AGENT A_VERSION) >> ack + case qInfo `proveCompatible` SMP.smpClientVRange of + Nothing -> notify . ERR $ AGENT A_VERSION Just qInfo' -> do (sq, smpConf) <- newSndQueue qInfo' ownConnInfo withStore $ \st -> upgradeRcvConnToDuplex st connId sq - confirmQueue c sq smpConf + confirmQueue c connId sq smpConf Nothing withStore (`removeConfirmations` connId) void $ enqueueMessage c connId sq HELLO _ -> prohibited @@ -676,10 +704,35 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do | internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash | otherwise = MsgError MsgDuplicate -- this case is not possible -confirmQueue :: AgentMonad m => AgentClient -> SndQueue -> SMPConfirmation -> m () -confirmQueue c sq smpConf = do - sendConfirmation c sq smpConf connEncStub +confirmQueue :: forall m. AgentMonad m => AgentClient -> ConnId -> SndQueue -> SMPConfirmation -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () +confirmQueue c connId sq SMPConfirmation {senderKey, e2ePubKey, connInfo} e2eEncryption = do + msg <- mkConfirmation + sendConfirmation c sq msg withStore $ \st -> setSndQueueStatus st sq Confirmed + where + mkConfirmation :: m MsgBody + mkConfirmation = do + encConnInfo <- agentRatchetEncrypt connId (smpEncode $ AgentConnInfo connInfo) e2eEncConnInfoLength + let agentEnvelope = AgentConfirmation {agentVersion = smpAgentVersion, e2eEncryption, encConnInfo} + agentCbEncrypt sq (Just e2ePubKey) . smpEncode $ + SMP.ClientMessage (SMP.PHConfirmation senderKey) $ smpEncode agentEnvelope + +-- encoded AgentMessage -> encoded EncAgentMessage +agentRatchetEncrypt :: AgentMonad m => ConnId -> ByteString -> Int -> m ByteString +agentRatchetEncrypt connId msg paddedLen = do + rc <- withStore $ \st -> getRatchet st connId + (encMsg, rc') <- liftError cryptoError $ CR.rcEncrypt rc paddedLen msg + withStore $ \st -> updateRatchet st connId rc' CR.SMDNoChange + pure encMsg + +-- encoded EncAgentMessage -> encoded AgentMessage +agentRatchetDecrypt :: AgentMonad m => ConnId -> ByteString -> m ByteString +agentRatchetDecrypt connId encAgentMsg = do + (rc, skipped) <- withStore $ \st -> + (,) <$> getRatchet st connId <*> getSkippedMsgKeys st connId + (agentMsgBody_, rc', skippedDiff) <- liftError cryptoError $ CR.rcDecrypt rc skipped encAgentMsg + withStore $ \st -> updateRatchet st connId rc' skippedDiff + liftEither $ first cryptoError agentMsgBody_ notifyConnected :: AgentMonad m => AgentClient -> ConnId -> m () notifyConnected c connId = atomically $ writeTBQueue (subQ c) ("", connId, CON) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 34268b5ea..b9aecd401 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -22,18 +22,14 @@ module Simplex.Messaging.Agent.Client RetryInterval (..), secureQueue, sendAgentMessage, - agentRatchetEncrypt, - agentRatchetDecrypt, agentCbEncrypt, agentCbDecrypt, + cryptoError, sendAck, suspendQueue, deleteQueue, logServer, removeSubscription, - addActivation, - getActivation, - removeActivation, ) where @@ -54,13 +50,14 @@ import Data.Set (Set) import qualified Data.Set as S import Data.Text.Encoding import Simplex.Messaging.Agent.Env.SQLite +import Simplex.Messaging.Agent.ExceptT () import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding -import Simplex.Messaging.Protocol (MsgBody, QueueId, QueueIdsKeys (..), SndPublicVerifyKey) +import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, liftEitherError, liftError) import Simplex.Messaging.Version @@ -75,7 +72,6 @@ data AgentClient = AgentClient smpClients :: TVar (Map SMPServer SMPClient), subscrSrvrs :: TVar (Map SMPServer (Map ConnId RcvQueue)), subscrConns :: TVar (Map ConnId SMPServer), - activations :: TVar (Map ConnId (Async ())), -- activations of send queues in progress connMsgsQueued :: TVar (Map ConnId Bool), smpQueueMsgQueues :: TVar (Map (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId)), smpQueueMsgDeliveries :: TVar (Map (ConnId, SMPServer, SMP.SenderId) (Async ())), @@ -95,14 +91,13 @@ newAgentClient agentEnv = do smpClients <- newTVar M.empty subscrSrvrs <- newTVar M.empty subscrConns <- newTVar M.empty - activations <- newTVar M.empty connMsgsQueued <- newTVar M.empty smpQueueMsgQueues <- newTVar M.empty smpQueueMsgDeliveries <- newTVar M.empty reconnections <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1) lock <- newTMVar () - return AgentClient {rcvQ, subQ, msgQ, smpClients, subscrSrvrs, subscrConns, activations, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, clientId, agentEnv, smpSubscriber = undefined, lock} + return AgentClient {rcvQ, subQ, msgQ, smpClients, subscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, clientId, agentEnv, smpSubscriber = undefined, lock} -- | Agent monad with MonadReader Env and MonadError AgentErrorType type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) @@ -176,7 +171,6 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = closeAgentClient :: MonadUnliftIO m => AgentClient -> m () closeAgentClient c = liftIO $ do closeSMPServerClients c - cancelActions $ activations c cancelActions $ reconnections c cancelActions $ smpQueueMsgDeliveries c @@ -257,7 +251,7 @@ newRcvQueue_ a c srv = do sndId = Just sndId, status = New } - pure (rq, SMPQueueUri srv sndId SMP.smpClientVersion e2eDhKey) + pure (rq, SMPQueueUri srv sndId SMP.smpClientVRange e2eDhKey) subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do @@ -287,15 +281,6 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connId = atomically $ let cs' = M.delete connId cs in if M.null cs' then Nothing else Just cs' -addActivation :: MonadUnliftIO m => AgentClient -> ConnId -> Async () -> m () -addActivation c connId a = atomically . modifyTVar (activations c) $ M.insert connId a - -getActivation :: MonadUnliftIO m => AgentClient -> ConnId -> m (Maybe (Async ())) -getActivation c connId = M.lookup connId <$> readTVarIO (activations c) - -removeActivation :: MonadUnliftIO m => AgentClient -> ConnId -> m () -removeActivation c connId = atomically . modifyTVar (activations c) $ M.delete connId - logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] @@ -307,26 +292,13 @@ logSecret :: ByteString -> ByteString logSecret bs = encode $ B.take 3 bs -- TODO maybe package E2ERatchetParams into SMPConfirmation -sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> SMPConfirmation -> E2ERatchetParams -> m () -sendConfirmation c sq@SndQueue {server, sndId} SMPConfirmation {senderKey, e2ePubKey, connInfo} e2eEncryption = - withLogSMP_ c server sndId "SEND " $ \smp -> do - msg <- mkConfirmation - liftSMP $ sendSMPMessage smp Nothing sndId msg - where - mkConfirmation :: m MsgBody - mkConfirmation = do - encConnInfo <- agentRatchetEncrypt . smpEncode $ AgentConnInfo connInfo - let agentEnvelope = - AgentConfirmation - { agentVersion = smpAgentVersion, - e2eEncryption, - encConnInfo - } - agentCbEncrypt sq (Just e2ePubKey) . smpEncode $ - SMP.ClientMessage (SMP.PHConfirmation senderKey) $ smpEncode agentEnvelope +sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () +sendConfirmation c SndQueue {server, sndId} encConfirmation = + withLogSMP_ c server sndId "SEND " $ \smp -> + liftSMP $ sendSMPMessage smp Nothing sndId encConfirmation -sendInvitation :: forall m. AgentMonad m => AgentClient -> SMPQueueUri -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () -sendInvitation c SMPQueueUri {smpServer, senderId, dhPublicKey} connReq connInfo = do +sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () +sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) connReq connInfo = withLogSMP_ c smpServer senderId "SEND " $ \smp -> do msg <- mkInvitation liftSMP $ sendSMPMessage smp Nothing senderId msg @@ -334,7 +306,7 @@ sendInvitation c SMPQueueUri {smpServer, senderId, dhPublicKey} connReq connInfo mkInvitation :: m ByteString -- this is only encrypted with per-queue E2E, not with double ratchet mkInvitation = do - let agentEnvelope = AgentInvitation' {agentVersion = smpAgentVersion, connReq, connInfo} + let agentEnvelope = AgentInvitation {agentVersion = smpAgentVersion, connReq, connInfo} agentCbEncryptOnce dhPublicKey . smpEncode $ SMP.ClientMessage SMP.PHEmpty $ smpEncode agentEnvelope @@ -366,14 +338,6 @@ sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg = msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg --- encoded AgentMessage' -> encoded EncAgentMessage -agentRatchetEncrypt :: AgentMonad m => ByteString -> m ByteString -agentRatchetEncrypt = pure - --- encoded EncAgentMessage -> encoded AgentMessage' -agentRatchetDecrypt :: AgentMonad m => ByteString -> m ByteString -agentRatchetDecrypt = pure - agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do cmNonce <- liftIO C.randomCbNonce @@ -381,7 +345,7 @@ agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do liftEither . first cryptoError $ C.cbEncrypt e2eDhSecret cmNonce msg SMP.e2eEncMessageLength -- TODO per-queue client version - let cmHeader = SMP.PubHeader (maxVersion SMP.smpClientVersion) e2ePubKey + let cmHeader = SMP.PubHeader (maxVersion SMP.smpClientVRange) e2ePubKey pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} -- add encoding as AgentInvitation'? @@ -394,7 +358,7 @@ agentCbEncryptOnce dhRcvPubKey msg = do liftEither . first cryptoError $ C.cbEncrypt e2eDhSecret cmNonce msg SMP.e2eEncMessageLength -- TODO per-queue client version - let cmHeader = SMP.PubHeader (maxVersion SMP.smpClientVersion) (Just dhSndPubKey) + let cmHeader = SMP.PubHeader (maxVersion SMP.smpClientVRange) (Just dhSndPubKey) pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} -- | NaCl crypto-box decrypt - both for messages received from the server diff --git a/src/Simplex/Messaging/Agent/ExceptT.hs b/src/Simplex/Messaging/Agent/ExceptT.hs new file mode 100644 index 000000000..9fccaaadd --- /dev/null +++ b/src/Simplex/Messaging/Agent/ExceptT.hs @@ -0,0 +1,22 @@ +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Simplex.Messaging.Agent.ExceptT where + +import Control.Monad.Except +import Control.Monad.IO.Unlift +import UnliftIO.Exception (Exception) +import qualified UnliftIO.Exception as E + +newtype InternalException e = InternalException {unInternalException :: e} + deriving (Eq, Show) + +instance Exception e => Exception (InternalException e) + +instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where + withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b + withRunInIO exceptToIO = + withExceptT unInternalException . ExceptT . E.try $ + withRunInIO $ \run -> + exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 44809b1a2..a5759ba80 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -33,6 +33,8 @@ module Simplex.Messaging.Agent.Protocol ( -- * Protocol parameters smpAgentVersion, smpAgentVRange, + e2eEncConnInfoLength, + e2eEncUserMsgLength, -- * SMP agent protocol types ConnInfo, @@ -43,7 +45,7 @@ module Simplex.Messaging.Agent.Protocol MsgMeta (..), SMPConfirmation (..), AgentMsgEnvelope (..), - AgentMessage' (..), + AgentMessage (..), APrivHeader (..), AMessage (..), AMsgType (..), @@ -61,8 +63,6 @@ module Simplex.Messaging.Agent.Protocol AConnectionRequestUri (..), ConnReqUriData (..), ConnReqScheme (..), - E2ERatchetParams (..), - E2ERatchetParamsUri (..), simplexChat, AgentErrorType (..), CommandErrorType (..), @@ -80,9 +80,6 @@ module Simplex.Messaging.Agent.Protocol QueueStatus (..), ACorrId, AgentMsgId, - -- TODO remove - connEncStubUri, - connEncStub, -- * Encode/decode serializeCommand, @@ -119,7 +116,6 @@ import Data.Composition ((.:)) import Data.Functor (($>)) import Data.Int (Int64) import Data.Kind (Type) -import Data.List (find) import qualified Data.List.NonEmpty as L import Data.Maybe (isJust) import Data.Text (Text) @@ -129,11 +125,12 @@ import Data.Type.Equality import Data.Typeable () import GHC.Generics (Generic) import Generic.Random (genericArbitraryU) -import qualified Network.HTTP.Types as Q +import Simplex.Messaging.Agent.QueryString import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.Ratchet (E2ERatchetParams, E2ERatchetParamsUri) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers +import Simplex.Messaging.Parsers (base64P, parse, parseRead, parseRead1, parseRead2, tsISO8601P) import Simplex.Messaging.Protocol ( ErrorType, MsgBody, @@ -156,6 +153,12 @@ smpAgentVersion = 1 smpAgentVRange :: VersionRange smpAgentVRange = mkVersionRange 1 smpAgentVersion +e2eEncConnInfoLength :: Int +e2eEncConnInfoLength = 14336 + +e2eEncUserMsgLength :: Int +e2eEncUserMsgLength = 15488 + -- | Raw (unparsed) SMP agent protocol transmission. type ARawTransmission = (ByteString, ByteString, ByteString) @@ -289,14 +292,14 @@ data SMPConfirmation = SMPConfirmation data AgentMsgEnvelope = AgentConfirmation { agentVersion :: Version, - e2eEncryption :: E2ERatchetParams, + e2eEncryption :: Maybe (E2ERatchetParams 'C.X448), encConnInfo :: ByteString } | AgentMsgEnvelope { agentVersion :: Version, encAgentMessage :: ByteString } - | AgentInvitation' -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet, + | AgentInvitation -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet, { agentVersion :: !Version, connReq :: !(ConnectionRequestUri 'CMInvitation), connInfo :: !ByteString -- this message is only encrypted with per-queue E2E, not with double ratchet, @@ -309,7 +312,7 @@ instance Encoding AgentMsgEnvelope where smpEncode (agentVersion, 'C', e2eEncryption, Tail encConnInfo) AgentMsgEnvelope {agentVersion, encAgentMessage} -> smpEncode (agentVersion, 'M', Tail encAgentMessage) - AgentInvitation' {agentVersion, connReq, connInfo} -> + AgentInvitation {agentVersion, connReq, connInfo} -> smpEncode (agentVersion, 'I', Large $ strEncode connReq, Tail connInfo) smpP = do agentVersion <- smpP @@ -323,64 +326,22 @@ instance Encoding AgentMsgEnvelope where 'I' -> do connReq <- strDecode . unLarge <$?> smpP Tail connInfo <- smpP - pure AgentInvitation' {agentVersion, connReq, connInfo} + pure AgentInvitation {agentVersion, connReq, connInfo} _ -> fail "bad AgentMsgEnvelope" -data E2ERatchetParams - = E2ERatchetParams Version C.PublicKeyX448 C.PublicKeyX448 - deriving (Eq, Show) - -instance Encoding E2ERatchetParams where - smpEncode (E2ERatchetParams v k1 k2) = smpEncode (v, k1, k2) - smpP = E2ERatchetParams <$> smpP <*> smpP <*> smpP - -instance VersionI E2ERatchetParams where - type VersionRangeT E2ERatchetParams = E2ERatchetParamsUri - version (E2ERatchetParams v _ _) = v - toVersionRangeT (E2ERatchetParams _ k1 k2) vr = E2ERatchetParamsUri vr k1 k2 - -instance VersionRangeI E2ERatchetParamsUri where - type VersionT E2ERatchetParamsUri = E2ERatchetParams - versionRange (E2ERatchetParamsUri vr _ _) = vr - toVersionT (E2ERatchetParamsUri _ k1 k2) v = E2ERatchetParams v k1 k2 - -data E2ERatchetParamsUri - = E2ERatchetParamsUri VersionRange C.PublicKeyX448 C.PublicKeyX448 - deriving (Eq, Show) - -connEncStubUri :: E2ERatchetParamsUri -connEncStubUri = E2ERatchetParamsUri smpAgentVRange stubDhPubKey stubDhPubKey - -connEncStub :: E2ERatchetParams -connEncStub = E2ERatchetParams smpAgentVersion stubDhPubKey stubDhPubKey - -stubDhPubKey :: C.PublicKeyX448 -stubDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U=" - -instance StrEncoding E2ERatchetParamsUri where - strEncode (E2ERatchetParamsUri vs key1 key2) = - strEncode $ - QSP QNoEscaping [("v", strEncode vs), ("x3dh", strEncode [key1, key2])] - strP = do - query <- strP - vs <- queryParam "v" query - keys <- queryParam "x3dh" query - case keys of - [key1, key2] -> pure $ E2ERatchetParamsUri vs key1 key2 - _ -> fail "bad e2e params" - -- SMP agent message formats (after double ratchet decryption, -- or in case of AgentInvitation - in plain text body) -data AgentMessage' = AgentConnInfo ConnInfo | AgentMessage' APrivHeader AMessage +data AgentMessage = AgentConnInfo ConnInfo | AgentMessage APrivHeader AMessage + deriving (Show) -instance Encoding AgentMessage' where +instance Encoding AgentMessage where smpEncode = \case - AgentConnInfo cInfo -> smpEncode ('I', cInfo) - AgentMessage' hdr aMsg -> smpEncode ('M', hdr, aMsg) + AgentConnInfo cInfo -> smpEncode ('I', Tail cInfo) + AgentMessage hdr aMsg -> smpEncode ('M', hdr, aMsg) smpP = smpP >>= \case - 'I' -> AgentConnInfo <$> smpP - 'M' -> AgentMessage' <$> smpP <*> smpP + 'I' -> AgentConnInfo . unTail <$> smpP + 'M' -> AgentMessage <$> smpP <*> smpP _ -> fail "bad AgentMessage" data APrivHeader = APrivHeader @@ -389,6 +350,7 @@ data APrivHeader = APrivHeader -- | digest of the previous message prevMsgHash :: MsgHash } + deriving (Show) instance Encoding APrivHeader where smpEncode APrivHeader {sndMsgId, prevMsgHash} = @@ -440,32 +402,12 @@ instance Encoding AMessage where REPLY_ -> REPLY <$> smpP A_MSG_ -> A_MSG . unTail <$> smpP -data QueryStringParams = QSP QSPEscaping Q.SimpleQuery - deriving (Show) - -data QSPEscaping = QEscape | QNoEscaping - deriving (Show) - -instance StrEncoding QueryStringParams where - strEncode (QSP esc q) = case esc of - QEscape -> Q.renderSimpleQuery False q - QNoEscaping -> - Q.renderQueryPartialEscape False $ - map (\(n, v) -> (n, [Q.QN v])) q - strP = QSP QEscape . Q.parseSimpleQuery <$> A.takeTill (\c -> c == ' ' || c == '\n') - -queryParam :: StrEncoding a => ByteString -> QueryStringParams -> Parser a -queryParam name (QSP _ q) = - case find ((== name) . fst) q of - Just (_, p) -> either fail pure $ parseAll strP p - _ -> fail $ "no qs param " <> B.unpack name - instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) where strEncode = \case CRInvitationUri crData e2eParams -> crEncode "invitation" crData (Just e2eParams) CRContactUri crData -> crEncode "contact" crData Nothing where - crEncode :: ByteString -> ConnReqUriData -> Maybe E2ERatchetParamsUri -> ByteString + crEncode :: ByteString -> ConnReqUriData -> Maybe (E2ERatchetParamsUri 'C.X448) -> ByteString crEncode crMode ConnReqUriData {crScheme, crAgentVRange, crSmpQueues} e2eParams = strEncode crScheme <> "/" <> crMode <> "#/?" <> queryStr where @@ -474,13 +416,13 @@ instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) whe [("v", strEncode crAgentVRange), ("smp", strEncode crSmpQueues)] <> maybe [] (\e2e -> [("e2e", strEncode e2e)]) e2eParams strP = do - ACRU m cr <- strP + ACR m cr <- strP case testEquality m $ sConnectionMode @m of Just Refl -> pure cr _ -> fail "bad connection request mode" instance StrEncoding AConnectionRequestUri where - strEncode (ACRU _ cr) = strEncode cr + strEncode (ACR _ cr) = strEncode cr strP = do crScheme <- strP crMode <- A.char '/' *> crModeP <* optional (A.char '/') <* "#/?" @@ -491,8 +433,8 @@ instance StrEncoding AConnectionRequestUri where case crMode of CMInvitation -> do crE2eParams <- queryParam "e2e" query - pure . ACRU SCMInvitation $ CRInvitationUri crData crE2eParams - CMContact -> pure . ACRU SCMContact $ CRContactUri crData + pure . ACR SCMInvitation $ CRInvitationUri crData crE2eParams + CMContact -> pure . ACR SCMContact $ CRContactUri crData where crModeP = "invitation" $> CMInvitation <|> "contact" $> CMContact @@ -546,11 +488,11 @@ instance VersionI SMPQueueInfo where type VersionRangeT SMPQueueInfo = SMPQueueUri version = clientVersion toVersionRangeT SMPQueueInfo {smpServer, senderId, dhPublicKey} vr = - SMPQueueUri {clientVersionRange = vr, smpServer, senderId, dhPublicKey} + SMPQueueUri {clientVRange = vr, smpServer, senderId, dhPublicKey} instance VersionRangeI SMPQueueUri where type VersionT SMPQueueUri = SMPQueueInfo - versionRange = clientVersionRange + versionRange = clientVRange toVersionT SMPQueueUri {smpServer, senderId, dhPublicKey} v = SMPQueueInfo {clientVersion = v, smpServer, senderId, dhPublicKey} @@ -560,7 +502,7 @@ instance VersionRangeI SMPQueueUri where data SMPQueueUri = SMPQueueUri { smpServer :: SMPServer, senderId :: SMP.SenderId, - clientVersionRange :: VersionRange, + clientVRange :: VersionRange, dhPublicKey :: C.PublicKeyX25519 } deriving (Eq, Show) @@ -568,15 +510,15 @@ data SMPQueueUri = SMPQueueUri -- TODO change SMP queue URI format to include version range and allow unknown parameters instance StrEncoding SMPQueueUri where -- v1 uses short SMP queue URI format - strEncode SMPQueueUri {smpServer = srv, senderId = qId, clientVersionRange = _vr, dhPublicKey = k} = + strEncode SMPQueueUri {smpServer = srv, senderId = qId, clientVRange = _vr, dhPublicKey = k} = strEncode srv <> "/" <> strEncode qId <> "#" <> strEncode k strP = do smpServer <- strP <* A.char '/' senderId <- strP <* optional (A.char '/') <* A.char '#' (vr, dhPublicKey) <- unversioned <|> versioned - pure SMPQueueUri {smpServer, senderId, clientVersionRange = vr, dhPublicKey} + pure SMPQueueUri {smpServer, senderId, clientVRange = vr, dhPublicKey} where - unversioned = (SMP.smpClientVersion,) <$> strP <* A.endOfInput + unversioned = (SMP.smpClientVRange,) <$> strP <* A.endOfInput versioned = do dhKey_ <- optional strP query <- optional (A.char '/') *> A.char '?' *> strP @@ -585,7 +527,7 @@ instance StrEncoding SMPQueueUri where pure (vr, dhKey) data ConnectionRequestUri (m :: ConnectionMode) where - CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri -> ConnectionRequestUri CMInvitation + CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation -- contact connection request does NOT contain E2E encryption parameters - -- they are passed in AgentInvitation message CRContactUri :: ConnReqUriData -> ConnectionRequestUri CMContact @@ -594,10 +536,10 @@ deriving instance Eq (ConnectionRequestUri m) deriving instance Show (ConnectionRequestUri m) -data AConnectionRequestUri = forall m. ConnectionModeI m => ACRU (SConnectionMode m) (ConnectionRequestUri m) +data AConnectionRequestUri = forall m. ConnectionModeI m => ACR (SConnectionMode m) (ConnectionRequestUri m) instance Eq AConnectionRequestUri where - ACRU m cr == ACRU m' cr' = case testEquality m m' of + ACR m cr == ACR m' cr' = case testEquality m m' of Just Refl -> cr == cr' _ -> False diff --git a/src/Simplex/Messaging/Agent/QueryString.hs b/src/Simplex/Messaging/Agent/QueryString.hs new file mode 100644 index 000000000..f2c362e09 --- /dev/null +++ b/src/Simplex/Messaging/Agent/QueryString.hs @@ -0,0 +1,30 @@ +module Simplex.Messaging.Agent.QueryString where + +import Data.Attoparsec.ByteString.Char8 (Parser) +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import Data.List (find) +import qualified Network.HTTP.Types as Q +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Parsers (parseAll) + +data QueryStringParams = QSP QSPEscaping Q.SimpleQuery + deriving (Show) + +data QSPEscaping = QEscape | QNoEscaping + deriving (Show) + +instance StrEncoding QueryStringParams where + strEncode (QSP esc q) = case esc of + QEscape -> Q.renderSimpleQuery False q + QNoEscaping -> + Q.renderQueryPartialEscape False $ + map (\(n, v) -> (n, [Q.QN v])) q + strP = QSP QEscape . Q.parseSimpleQuery <$> A.takeTill (\c -> c == ' ' || c == '\n') + +queryParam :: StrEncoding a => ByteString -> QueryStringParams -> Parser a +queryParam name (QSP _ q) = + case find ((== name) . fst) q of + Just (_, p) -> either fail pure $ parseAll strP p + _ -> fail $ "no qs param " <> B.unpack name diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 8dea9eea1..a664355f2 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -19,6 +19,7 @@ import Data.Time (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff, SkippedMsgKeys) import Simplex.Messaging.Protocol ( MsgBody, MsgId, @@ -66,6 +67,14 @@ class Monad m => MonadAgentStore s m where checkRcvMsg :: s -> ConnId -> InternalId -> m () deleteMsg :: s -> ConnId -> InternalId -> m () + -- Double ratchet persistence + createRatchetX3dhKeys :: s -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> m () + getRatchetX3dhKeys :: s -> ConnId -> m (C.PrivateKeyX448, C.PrivateKeyX448) + createRatchet :: s -> ConnId -> RatchetX448 -> m () + getRatchet :: s -> ConnId -> m RatchetX448 + getSkippedMsgKeys :: s -> ConnId -> m SkippedMsgKeys + updateRatchet :: s -> ConnId -> RatchetX448 -> SkippedMsgDiff -> m () + -- * Queue types -- | A receive queue. SMP queue through which the agent receives messages from a sender. @@ -168,13 +177,15 @@ newtype ConnData = ConnData {connId :: ConnId} data NewConfirmation = NewConfirmation { connId :: ConnId, - senderConf :: SMPConfirmation + senderConf :: SMPConfirmation, + ratchetState :: RatchetX448 } data AcceptedConfirmation = AcceptedConfirmation { confirmationId :: ConfirmationId, connId :: ConnId, senderConf :: SMPConfirmation, + ratchetState :: RatchetX448, ownConnInfo :: ConnInfo } @@ -291,6 +302,10 @@ data StoreError -- as we always know what it should be at any stage of the protocol, -- and in case it does not match use this error. SEBadQueueStatus + | -- | connection does not have associated double-ratchet state + SERatchetNotFound + | -- | connection does not have associated x3dh keys + SEX3dhKeysNotFound | -- | Used in `getMsg` that is not implemented/used. TODO remove. SENotImplemented deriving (Eq, Show, Exception) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 3dc5bbf2a..7dc35bbb4 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -23,6 +23,7 @@ module Simplex.Messaging.Agent.Store.SQLite withConnection, withTransaction, fromTextField_, + firstRow, ) where @@ -32,11 +33,13 @@ import Control.Exception (bracket) import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) import Crypto.Random (ChaChaDRG, randomBytesGenerate) +import Data.Bifunctor (second) import Data.ByteString (ByteString) -import Data.ByteString.Base64 (encode) +import qualified Data.ByteString.Base64.URL as U import Data.Char (toLower) import Data.Functor (($>)) -import Data.List (find) +import Data.List (find, foldl') +import qualified Data.Map.Strict as M import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) @@ -52,6 +55,7 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Migrations (Migration) import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldParser) @@ -274,16 +278,16 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto [":status" := status, ":host" := host, ":port" := port, ":snd_id" := sndId] createConfirmation :: SQLiteStore -> TVar ChaChaDRG -> NewConfirmation -> m ConfirmationId - createConfirmation st gVar NewConfirmation {connId, senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo}} = + createConfirmation st gVar NewConfirmation {connId, senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo}, ratchetState} = liftIOEither . withTransaction st $ \db -> createWithRandomId gVar $ \confirmationId -> DB.execute db [sql| INSERT INTO conn_confirmations - (confirmation_id, conn_id, sender_key, e2e_snd_pub_key, sender_conn_info, accepted) VALUES (?, ?, ?, ?, ?, 0); + (confirmation_id, conn_id, sender_key, e2e_snd_pub_key, ratchet_state, sender_conn_info, accepted) VALUES (?, ?, ?, ?, ?, ?, 0); |] - (confirmationId, connId, senderKey, e2ePubKey, connInfo) + (confirmationId, connId, senderKey, e2ePubKey, ratchetState, connInfo) acceptConfirmation :: SQLiteStore -> ConfirmationId -> ConnInfo -> m AcceptedConfirmation acceptConfirmation st confirmationId ownConnInfo = @@ -299,48 +303,46 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto [ ":own_conn_info" := ownConnInfo, ":confirmation_id" := confirmationId ] - confirmation - <$> DB.query + firstRow confirmation SEConfirmationNotFound $ + DB.query db [sql| - SELECT conn_id, sender_key, e2e_snd_pub_key, sender_conn_info + SELECT conn_id, sender_key, e2e_snd_pub_key, ratchet_state, sender_conn_info FROM conn_confirmations WHERE confirmation_id = ?; |] (Only confirmationId) where - confirmation [(connId, senderKey, e2ePubKey, connInfo)] = - Right - AcceptedConfirmation - { confirmationId, - connId, - senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo}, - ownConnInfo - } - confirmation _ = Left SEConfirmationNotFound + confirmation (connId, senderKey, e2ePubKey, ratchetState, connInfo) = + AcceptedConfirmation + { confirmationId, + connId, + senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo}, + ratchetState, + ownConnInfo + } getAcceptedConfirmation :: SQLiteStore -> ConnId -> m AcceptedConfirmation getAcceptedConfirmation st connId = liftIOEither . withTransaction st $ \db -> - confirmation - <$> DB.query + firstRow confirmation SEConfirmationNotFound $ + DB.query db [sql| - SELECT confirmation_id, sender_key, e2e_snd_pub_key, sender_conn_info, own_conn_info + SELECT confirmation_id, sender_key, e2e_snd_pub_key, ratchet_state, sender_conn_info, own_conn_info FROM conn_confirmations WHERE conn_id = ? AND accepted = 1; |] (Only connId) where - confirmation [(confirmationId, senderKey, e2ePubKey, connInfo, ownConnInfo)] = - Right - AcceptedConfirmation - { confirmationId, - connId, - senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo}, - ownConnInfo - } - confirmation _ = Left SEConfirmationNotFound + confirmation (confirmationId, senderKey, e2ePubKey, ratchetState, connInfo, ownConnInfo) = + AcceptedConfirmation + { confirmationId, + connId, + senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo}, + ratchetState, + ownConnInfo + } removeConfirmations :: SQLiteStore -> ConnId -> m () removeConfirmations st connId = @@ -368,8 +370,8 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto getInvitation :: SQLiteStore -> InvitationId -> m Invitation getInvitation st invitationId = liftIOEither . withTransaction st $ \db -> - invitation - <$> DB.query + firstRow invitation SEInvitationNotFound $ + DB.query db [sql| SELECT contact_conn_id, cr_invitation, recipient_conn_info, own_conn_info, accepted @@ -379,9 +381,8 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto |] (Only invitationId) where - invitation [(contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted)] = - Right Invitation {invitationId, contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted} - invitation _ = Left SEInvitationNotFound + invitation (contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted) = + Invitation {invitationId, contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted} acceptInvitation :: SQLiteStore -> InvitationId -> ConnInfo -> m () acceptInvitation st invitationId ownConnInfo = @@ -444,22 +445,17 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto liftIOEither . withTransaction st $ \db -> runExceptT $ do rq_ <- liftIO $ getRcvQueueByConnId_ db connId msgData <- - ExceptT $ - sndMsgData - <$> DB.query - db - [sql| + ExceptT . firstRow id SEMsgNotFound $ + DB.query + db + [sql| SELECT m.msg_type, m.msg_body FROM messages m JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id WHERE m.conn_id = ? AND m.internal_id = ? |] - (connId, msgId) + (connId, msgId) pure (rq_, msgData) - where - sndMsgData :: [(AMsgType, MsgBody)] -> Either StoreError (AMsgType, MsgBody) - sndMsgData [msgData] = Right msgData - sndMsgData _ = Left SEMsgNotFound getPendingMsgs :: SQLiteStore -> ConnId -> m [InternalId] getPendingMsgs st connId = @@ -488,6 +484,69 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto liftIO . withTransaction st $ \db -> DB.execute db "DELETE FROM messages WHERE conn_id = ? AND internal_id = ?;" (connId, msgId) + createRatchetX3dhKeys :: SQLiteStore -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> m () + createRatchetX3dhKeys st connId x3dhPrivKey1 x3dhPrivKey2 = + liftIO . withTransaction st $ \db -> + DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2) + + getRatchetX3dhKeys :: SQLiteStore -> ConnId -> m (C.PrivateKeyX448, C.PrivateKeyX448) + getRatchetX3dhKeys st connId = + liftIOEither . withTransaction st $ \db -> + fmap hasKeys $ + firstRow id SEX3dhKeysNotFound $ + DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2 FROM ratchets WHERE conn_id = ?" (Only connId) + where + hasKeys = \case + Right (Just k1, Just k2) -> Right (k1, k2) + _ -> Left SEX3dhKeysNotFound + + createRatchet :: SQLiteStore -> ConnId -> RatchetX448 -> m () + createRatchet st connId rc = + liftIO . withTransaction st $ \db -> do + DB.executeNamed + db + [sql| + INSERT INTO ratchets (conn_id, ratchet_state) + VALUES (:conn_id, :ratchet_state) + ON CONFLICT (conn_id) DO UPDATE SET + ratchet_state = :ratchet_state, + x3dh_priv_key_1 = NULL, + x3dh_priv_key_2 = NULL + |] + [":conn_id" := connId, ":ratchet_state" := rc] + + getRatchet :: SQLiteStore -> ConnId -> m RatchetX448 + getRatchet st connId = + liftIOEither . withTransaction st $ \db -> + ratchet + <$> DB.query db "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" (Only connId) + where + ratchet (Only (Just rc) : _) = Right rc + ratchet _ = Left SERatchetNotFound + + getSkippedMsgKeys :: SQLiteStore -> ConnId -> m SkippedMsgKeys + getSkippedMsgKeys st connId = + liftIO . withTransaction st $ \db -> + skipped <$> DB.query db "SELECT header_key, msg_n, msg_key FROM skipped_messages WHERE conn_id = ?" (Only connId) + where + skipped ms = foldl' addSkippedKey M.empty ms + addSkippedKey smks (hk, msgN, mk) = M.alter (Just . addMsgKey) hk smks + where + addMsgKey = maybe (M.singleton msgN mk) (M.insert msgN mk) + + updateRatchet :: SQLiteStore -> ConnId -> RatchetX448 -> SkippedMsgDiff -> m () + updateRatchet st connId rc skipped = + liftIO . withTransaction st $ \db -> do + DB.execute db "UPDATE ratchets SET ratchet_state = ? WHERE conn_id = ?" (rc, connId) + case skipped of + SMDNoChange -> pure () + SMDRemove hk msgN -> + DB.execute db "DELETE FROM skipped_messages WHERE conn_id = ? AND header_key = ? AND msg_n = ?" (connId, hk, msgN) + SMDAdd smks -> + forM_ (M.assocs smks) $ \(hk, mks) -> + forM_ (M.assocs mks) $ \(msgN, mk) -> + DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk) + -- * Auxiliary helpers instance ToField QueueStatus where toField = toField . serializeQueueStatus @@ -542,6 +601,13 @@ fromTextField_ fromText = \case _ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t) f -> returnError ConversionFailed f "expecting SQLText column type" +listToEither :: e -> [a] -> Either e a +listToEither _ (x : _) = Right x +listToEither e _ = Left e + +firstRow :: (a -> b) -> e -> IO [a] -> IO (Either e b) +firstRow f e a = second f . listToEither e <$> a + {- ORMOLU_DISABLE -} -- SQLite.Simple only has these up to 10 fields, which is insufficient for some of our queries instance (FromField a, FromField b, FromField c, FromField d, FromField e, @@ -884,4 +950,4 @@ createWithRandomId gVar create = tryCreate 3 | otherwise -> pure . Left . SEInternal $ bshow e randomId :: TVar ChaChaDRG -> Int -> IO ByteString -randomId gVar n = encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n) +randomId gVar n = U.encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n) diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index bf94f8931..669792476 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -15,6 +15,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} -- | @@ -60,10 +61,12 @@ module Simplex.Messaging.Crypto generateSignatureKeyPair, generateDhKeyPair, privateToX509, + publicKey, -- * key encoding/decoding encodePubKey, encodePrivKey, + pubKeyBytes, -- * sign/verify Signature (..), @@ -79,7 +82,7 @@ module Simplex.Messaging.Crypto -- * DH derivation dh', - dhSecretBytes', + dhBytes', -- * AES256 AEAD-GCM scheme Key (..), @@ -130,10 +133,12 @@ import Crypto.Random (getRandomBytes) import Data.ASN1.BinaryEncoding import Data.ASN1.Encoding import Data.ASN1.Types +import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (bimap, first) import qualified Data.ByteArray as BA import Data.ByteString.Base64 (decode, encode) +import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.ByteString.Internal (c2w, w2c) @@ -151,6 +156,7 @@ import Network.Transport.Internal (decodeWord16, encodeWord16) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString) +import Simplex.Messaging.Util ((<$?>)) -- | Cryptographic algorithms. data Algorithm = Ed25519 | Ed448 | X25519 | X448 @@ -231,8 +237,8 @@ type PublicKeyX448 = PublicKey X448 data PrivateKey (a :: Algorithm) where PrivateKeyEd25519 :: Ed25519.SecretKey -> Ed25519.PublicKey -> PrivateKey Ed25519 PrivateKeyEd448 :: Ed448.SecretKey -> Ed448.PublicKey -> PrivateKey Ed448 - PrivateKeyX25519 :: X25519.SecretKey -> PrivateKey X25519 - PrivateKeyX448 :: X448.SecretKey -> PrivateKey X448 + PrivateKeyX25519 :: X25519.SecretKey -> X25519.PublicKey -> PrivateKey X25519 + PrivateKeyX448 :: X448.SecretKey -> X448.PublicKey -> PrivateKey X448 deriving instance Eq (PrivateKey a) @@ -341,17 +347,17 @@ dhAlgorithm = \case SX448 -> Just Dict _ -> Nothing -dhSecretBytes' :: DhSecret a -> ByteString -dhSecretBytes' = \case +dhBytes' :: DhSecret a -> ByteString +dhBytes' = \case DhSecretX25519 s -> BA.convert s DhSecretX448 s -> BA.convert s instance AlgorithmI a => StrEncoding (DhSecret a) where - strEncode = strEncode . dhSecretBytes' + strEncode = strEncode . dhBytes' strDecode = (\(ADhSecret _ s) -> checkAlgorithm s) <=< strDecode instance StrEncoding ADhSecret where - strEncode (ADhSecret _ s) = strEncode $ dhSecretBytes' s + strEncode (ADhSecret _ s) = strEncode $ dhBytes' s strDecode = cryptoPassed . secret where secret bs @@ -426,34 +432,60 @@ instance AlgorithmI a => StrEncoding (PublicKey a) where strDecode = decodePubKey {-# INLINE strDecode #-} -encodePubKey :: CryptoPublicKey pk => pk -> ByteString +instance AlgorithmI a => ToJSON (PublicKey a) where + toJSON = strToJSON + toEncoding = strToJEncoding + +instance AlgorithmI a => FromJSON (PublicKey a) where + parseJSON = strParseJSON "PublicKey" + +encodePubKey :: CryptoPublicKey k => k -> ByteString encodePubKey = toPubKey $ encodeASNObj . publicToX509 {-# INLINE encodePubKey #-} +pubKeyBytes :: PublicKey a -> ByteString +pubKeyBytes = \case + PublicKeyEd25519 k -> BA.convert k + PublicKeyEd448 k -> BA.convert k + PublicKeyX25519 k -> BA.convert k + PublicKeyX448 k -> BA.convert k + class CryptoPrivateKey pk where + type PublicKeyType pk toPrivKey :: (forall a. AlgorithmI a => PrivateKey a -> b) -> pk -> b privKey :: APrivateKey -> Either String pk instance CryptoPrivateKey APrivateKey where + type PublicKeyType APrivateKey = APublicKey toPrivKey f (APrivateKey _ k) = f k privKey = Right instance CryptoPrivateKey APrivateSignKey where + type PublicKeyType APrivateSignKey = APublicVerifyKey toPrivKey f (APrivateSignKey _ k) = f k privKey (APrivateKey a k) = case signatureAlgorithm a of Just Dict -> Right $ APrivateSignKey a k _ -> Left "key does not support signature algorithms" instance CryptoPrivateKey APrivateDhKey where + type PublicKeyType APrivateDhKey = APublicDhKey toPrivKey f (APrivateDhKey _ k) = f k privKey (APrivateKey a k) = case dhAlgorithm a of Just Dict -> Right $ APrivateDhKey a k _ -> Left "key does not support DH algorithm" instance AlgorithmI a => CryptoPrivateKey (PrivateKey a) where + type PublicKeyType (PrivateKey a) = PublicKey a toPrivKey = id privKey (APrivateKey _ k) = checkAlgorithm k +publicKey :: PrivateKey a -> PublicKey a +publicKey = \case + PrivateKeyEd25519 _ k -> PublicKeyEd25519 k + PrivateKeyEd448 _ k -> PublicKeyEd448 k + PrivateKeyX25519 _ k -> PublicKeyX25519 k + PrivateKeyX448 _ k -> PublicKeyX448 k + encodePrivKey :: CryptoPrivateKey pk => pk -> ByteString encodePrivKey = toPrivKey $ encodeASNObj . privateToX509 @@ -463,14 +495,22 @@ instance AlgorithmI a => IsString (PrivateKey a) where instance AlgorithmI a => IsString (PublicKey a) where fromString = parseString $ decode >=> decodePubKey --- | Tuple of RSA 'PublicKey' and 'PrivateKey'. -type KeyPair a = (PublicKey a, PrivateKey a) +instance AlgorithmI a => ToJSON (PrivateKey a) where + toJSON = strToJSON . strEncode . encodePrivKey + toEncoding = strToJEncoding . strEncode . encodePrivKey -type AKeyPair = (APublicKey, APrivateKey) +instance AlgorithmI a => FromJSON (PrivateKey a) where + parseJSON v = (decodePrivKey <=< U.decode) <$?> strParseJSON "PrivateKey" v -type ASignatureKeyPair = (APublicVerifyKey, APrivateSignKey) +type KeyPairType pk = (PublicKeyType pk, pk) -type ADhKeyPair = (APublicDhKey, APrivateDhKey) +type KeyPair a = KeyPairType (PrivateKey a) + +type AKeyPair = KeyPairType APrivateKey + +type ASignatureKeyPair = KeyPairType APrivateSignKey + +type ADhKeyPair = KeyPairType APrivateDhKey generateKeyPair :: AlgorithmI a => SAlgorithm a -> IO AKeyPair generateKeyPair a = bimap (APublicKey a) (APrivateKey a) <$> generateKeyPair' @@ -494,11 +534,11 @@ generateKeyPair' = case sAlgorithm @a of SX25519 -> X25519.generateSecretKey >>= \pk -> let k = X25519.toPublic pk - in pure (PublicKeyX25519 k, PrivateKeyX25519 pk) + in pure (PublicKeyX25519 k, PrivateKeyX25519 pk k) SX448 -> X448.generateSecretKey >>= \pk -> let k = X448.toPublic pk - in pure (PublicKeyX448 k, PrivateKeyX448 pk) + in pure (PublicKeyX448 k, PrivateKeyX448 pk k) instance ToField APrivateSignKey where toField = toField . encodePrivKey @@ -512,7 +552,7 @@ instance AlgorithmI a => ToField (PrivateKey a) where toField = toField . encode instance AlgorithmI a => ToField (PublicKey a) where toField = toField . encodePubKey -instance ToField (DhSecret a) where toField = toField . dhSecretBytes' +instance ToField (DhSecret a) where toField = toField . dhBytes' instance FromField APrivateSignKey where fromField = blobFieldDecoder decodePrivKey @@ -648,7 +688,18 @@ validSignatureSize n = -- | AES key newtype. newtype Key = Key {unKey :: ByteString} - deriving (Eq, Ord) + deriving (Eq, Ord, Show) + +instance ToField Key where toField = toField . unKey + +instance FromField Key where fromField f = Key <$> fromField f + +instance ToJSON Key where + toJSON = strToJSON . unKey + toEncoding = strToJEncoding . unKey + +instance FromJSON Key where + parseJSON = fmap Key . strParseJSON "Key" -- | IV bytes newtype. newtype IV = IV {unIV :: ByteString} @@ -782,14 +833,14 @@ verify (APublicVerifyKey a k) (ASignature a' sig) msg = case testEquality a a' o _ -> False dh' :: DhAlgorithm a => PublicKey a -> PrivateKey a -> DhSecret a -dh' (PublicKeyX25519 k) (PrivateKeyX25519 pk) = DhSecretX25519 $ X25519.dh k pk -dh' (PublicKeyX448 k) (PrivateKeyX448 pk) = DhSecretX448 $ X448.dh k pk +dh' (PublicKeyX25519 k) (PrivateKeyX25519 pk _) = DhSecretX25519 $ X25519.dh k pk +dh' (PublicKeyX448 k) (PrivateKeyX448 pk _) = DhSecretX448 $ X448.dh k pk -- | NaCl @crypto_box@ encrypt with a shared DH secret and 192-bit nonce. cbEncrypt :: DhSecret X25519 -> CbNonce -> ByteString -> Int -> Either CryptoError ByteString cbEncrypt secret (CbNonce nonce) msg paddedLen = cryptoBox <$> pad msg paddedLen where - cryptoBox s = BA.convert tag `B.append` c + cryptoBox s = BA.convert tag <> c where (rs, c) = xSalsa20 secret nonce s tag = Poly1305.auth rs c @@ -844,8 +895,8 @@ privateToX509 :: PrivateKey a -> PrivKey privateToX509 = \case PrivateKeyEd25519 k _ -> PrivKeyEd25519 k PrivateKeyEd448 k _ -> PrivKeyEd448 k - PrivateKeyX25519 k -> PrivKeyX25519 k - PrivateKeyX448 k -> PrivKeyX448 k + PrivateKeyX25519 k _ -> PrivKeyX25519 k + PrivateKeyX448 k _ -> PrivKeyX448 k encodeASNObj :: ASN1Object a => a -> ByteString encodeASNObj k = toStrict . encodeASN1 DER $ toASN1 k [] @@ -870,8 +921,8 @@ x509ToPrivate :: (PrivKey, [ASN1]) -> Either String APrivateKey x509ToPrivate = \case (PrivKeyEd25519 k, []) -> Right . APrivateKey SEd25519 . PrivateKeyEd25519 k $ Ed25519.toPublic k (PrivKeyEd448 k, []) -> Right . APrivateKey SEd448 . PrivateKeyEd448 k $ Ed448.toPublic k - (PrivKeyX25519 k, []) -> Right . APrivateKey SX25519 $ PrivateKeyX25519 k - (PrivKeyX448 k, []) -> Right . APrivateKey SX448 $ PrivateKeyX448 k + (PrivKeyX25519 k, []) -> Right . APrivateKey SX25519 . PrivateKeyX25519 k $ X25519.toPublic k + (PrivKeyX448 k, []) -> Right . APrivateKey SX448 . PrivateKeyX448 k $ X448.toPublic k r -> keyError r decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 26c1ebd25..35540337f 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -1,3 +1,6 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} @@ -6,6 +9,8 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} module Simplex.Messaging.Crypto.Ratchet where @@ -14,15 +19,24 @@ import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import Crypto.Hash (SHA512) import qualified Crypto.KDF.HKDF as H +import Data.Aeson (FromJSON, ToJSON) +import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy as LB import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe) +import Data.Typeable (Typeable) import Data.Word (Word32) +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +import GHC.Generics +import Simplex.Messaging.Agent.QueryString import Simplex.Messaging.Crypto import Simplex.Messaging.Encoding -import Simplex.Messaging.Parsers (parseE, parseE') +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Parsers (blobFieldDecoder, parseE, parseE') import Simplex.Messaging.Util (tryE) import Simplex.Messaging.Version @@ -32,98 +46,207 @@ e2eEncryptVersion = 1 e2eEncryptVRange :: VersionRange e2eEncryptVRange = mkVersionRange 1 e2eEncryptVersion +data E2ERatchetParams (a :: Algorithm) + = E2ERatchetParams Version (PublicKey a) (PublicKey a) + deriving (Eq, Show) + +instance AlgorithmI a => Encoding (E2ERatchetParams a) where + smpEncode (E2ERatchetParams v k1 k2) = smpEncode (v, k1, k2) + smpP = E2ERatchetParams <$> smpP <*> smpP <*> smpP + +instance VersionI (E2ERatchetParams a) where + type VersionRangeT (E2ERatchetParams a) = E2ERatchetParamsUri a + version (E2ERatchetParams v _ _) = v + toVersionRangeT (E2ERatchetParams _ k1 k2) vr = E2ERatchetParamsUri vr k1 k2 + +instance VersionRangeI (E2ERatchetParamsUri a) where + type VersionT (E2ERatchetParamsUri a) = (E2ERatchetParams a) + versionRange (E2ERatchetParamsUri vr _ _) = vr + toVersionT (E2ERatchetParamsUri _ k1 k2) v = E2ERatchetParams v k1 k2 + +data E2ERatchetParamsUri (a :: Algorithm) + = E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) + deriving (Eq, Show) + +instance AlgorithmI a => StrEncoding (E2ERatchetParamsUri a) where + strEncode (E2ERatchetParamsUri vs key1 key2) = + strEncode $ + QSP QNoEscaping [("v", strEncode vs), ("x3dh", strEncode [key1, key2])] + strP = do + query <- strP + vs <- queryParam "v" query + keys <- queryParam "x3dh" query + case keys of + [key1, key2] -> pure $ E2ERatchetParamsUri vs key1 key2 + _ -> fail "bad e2e params" + +generateE2EParams :: (AlgorithmI a, DhAlgorithm a) => Version -> IO (PrivateKey a, PrivateKey a, E2ERatchetParams a) +generateE2EParams v = do + (k1, pk1) <- generateKeyPair' + (k2, pk2) <- generateKeyPair' + pure (pk1, pk2, E2ERatchetParams v k1 k2) + +data RatchetInitParams = RatchetInitParams + { assocData :: Str, + ratchetKey :: RatchetKey, + sndHK :: HeaderKey, + rcvNextHK :: HeaderKey + } + deriving (Eq, Show) + +x3dhSnd :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> E2ERatchetParams a -> RatchetInitParams +x3dhSnd spk1 spk2 (E2ERatchetParams _ rk1 rk2) = + x3dh (publicKey spk1, rk1) (dh' rk1 spk2) (dh' rk2 spk1) (dh' rk2 spk2) + +x3dhRcv :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> E2ERatchetParams a -> RatchetInitParams +x3dhRcv rpk1 rpk2 (E2ERatchetParams _ sk1 sk2) = + x3dh (sk1, publicKey rpk1) (dh' sk2 rpk1) (dh' sk1 rpk2) (dh' sk2 rpk2) + +x3dh :: DhAlgorithm a => (PublicKey a, PublicKey a) -> DhSecret a -> DhSecret a -> DhSecret a -> RatchetInitParams +x3dh (sk1, rk1) dh1 dh2 dh3 = + RatchetInitParams {assocData, ratchetKey = RatchetKey sk, sndHK = Key hk, rcvNextHK = Key nhk} + where + assocData = Str $ pubKeyBytes sk1 <> pubKeyBytes rk1 + (hk, rest) = B.splitAt 32 $ dhBytes' dh1 <> dhBytes' dh2 <> dhBytes' dh3 + (nhk, sk) = B.splitAt 32 rest + +type RatchetX448 = Ratchet 'X448 + data Ratchet a = Ratchet { -- ratchet version range sent in messages (current .. max supported ratchet version) rcVersion :: VersionRange, -- associated data - must be the same in both parties ratchets - rcAD :: ByteString, - rcDHRs :: KeyPair a, + rcAD :: Str, + rcDHRs :: PrivateKey a, rcRK :: RatchetKey, rcSnd :: Maybe (SndRatchet a), rcRcv :: Maybe RcvRatchet, - rcMKSkipped :: Map HeaderKey SkippedMsgKeys, rcNs :: Word32, rcNr :: Word32, rcPN :: Word32, rcNHKs :: HeaderKey, rcNHKr :: HeaderKey } + deriving (Eq, Show, Generic, FromJSON) + +instance AlgorithmI a => ToJSON (Ratchet a) where + toEncoding = J.genericToEncoding J.defaultOptions data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, rcCKs :: RatchetKey, rcHKs :: HeaderKey } + deriving (Eq, Show, Generic, FromJSON) + +instance AlgorithmI a => ToJSON (SndRatchet a) where + toEncoding = J.genericToEncoding J.defaultOptions data RcvRatchet = RcvRatchet { rcCKr :: RatchetKey, rcHKr :: HeaderKey } + deriving (Eq, Show, Generic, FromJSON) -type SkippedMsgKeys = Map Word32 MessageKey +instance ToJSON RcvRatchet where + toEncoding = J.genericToEncoding J.defaultOptions + +type SkippedMsgKeys = Map HeaderKey SkippedHdrMsgKeys + +type SkippedHdrMsgKeys = Map Word32 MessageKey + +data SkippedMsgDiff + = SMDNoChange + | SMDRemove HeaderKey Word32 + | SMDAdd SkippedMsgKeys + +-- | this function is only used in tests to apply changes in skipped messages, +-- in the agent the diff is persisted, and the whole state is loaded for the next message. +applySMDiff :: SkippedMsgKeys -> SkippedMsgDiff -> SkippedMsgKeys +applySMDiff smks = \case + SMDNoChange -> smks + SMDRemove hk msgN -> fromMaybe smks $ do + mks <- M.lookup hk smks + _ <- M.lookup msgN mks + let mks' = M.delete msgN mks + pure $ + if M.null mks' + then M.delete hk smks + else M.insert hk mks' smks + SMDAdd smks' -> + let merge hk mks = M.alter (Just . maybe mks (M.union mks)) hk + in M.foldrWithKey merge smks smks' type HeaderKey = Key data MessageKey = MessageKey Key IV -data ARatchet - = forall a. - (AlgorithmI a, DhAlgorithm a) => - ARatchet (SAlgorithm a) (Ratchet a) +instance Encoding MessageKey where + smpEncode (MessageKey (Key key) (IV iv)) = smpEncode (key, iv) + smpP = MessageKey <$> (Key <$> smpP) <*> (IV <$> smpP) -- | Input key material for double ratchet HKDF functions newtype RatchetKey = RatchetKey ByteString + deriving (Eq, Show) + +instance ToJSON RatchetKey where + toJSON (RatchetKey k) = strToJSON k + toEncoding (RatchetKey k) = strToJEncoding k + +instance FromJSON RatchetKey where + parseJSON = fmap RatchetKey . strParseJSON "Key" + +instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode + +instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder $ J.eitherDecode' . LB.fromStrict + +instance ToField MessageKey where toField = toField . smpEncode + +instance FromField MessageKey where fromField = blobFieldDecoder smpDecode -- | Sending ratchet initialization, equivalent to RatchetInitAliceHE in double ratchet spec -- -- Please note that sPKey is not stored, and its public part together with random salt -- is sent to the recipient. -initSndRatchet' :: - forall a. (AlgorithmI a, DhAlgorithm a) => PublicKey a -> PrivateKey a -> ByteString -> ByteString -> IO (Ratchet a) -initSndRatchet' rcDHRr sPKey salt rcAD = do - rcDHRs@(_, pk) <- generateKeyPair' @a - let (sk, rcHKs, rcNHKr) = initKdf salt rcDHRr sPKey - -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr)) - (rcRK, rcCKs, rcNHKs) = rootKdf sk rcDHRr pk - pure - Ratchet - { rcVersion = e2eEncryptVRange, - rcAD, - rcDHRs, - rcRK, - rcSnd = Just SndRatchet {rcDHRr, rcCKs, rcHKs}, - rcRcv = Nothing, - rcMKSkipped = M.empty, - rcPN = 0, - rcNs = 0, - rcNr = 0, - rcNHKs, - rcNHKr - } +initSndRatchet :: + forall a. (AlgorithmI a, DhAlgorithm a) => PublicKey a -> PrivateKey a -> RatchetInitParams -> Ratchet a +initSndRatchet rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = do + -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr)) + let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs + in Ratchet + { rcVersion = e2eEncryptVRange, + rcAD = assocData, + rcDHRs, + rcRK, + rcSnd = Just SndRatchet {rcDHRr, rcCKs, rcHKs = sndHK}, + rcRcv = Nothing, + rcPN = 0, + rcNs = 0, + rcNr = 0, + rcNHKs, + rcNHKr = rcvNextHK + } -- | Receiving ratchet initialization, equivalent to RatchetInitBobHE in double ratchet spec -- -- Please note that the public part of rcDHRs was sent to the sender -- as part of the connection request and random salt was received from the sender. -initRcvRatchet' :: - forall a. (AlgorithmI a, DhAlgorithm a) => PublicKey a -> KeyPair a -> ByteString -> ByteString -> IO (Ratchet a) -initRcvRatchet' sKey rcDHRs@(_, pk) salt rcAD = do - let (sk, rcNHKr, rcNHKs) = initKdf salt sKey pk - pure - Ratchet - { rcVersion = e2eEncryptVRange, - rcAD, - rcDHRs, - rcRK = sk, - rcSnd = Nothing, - rcRcv = Nothing, - rcMKSkipped = M.empty, - rcPN = 0, - rcNs = 0, - rcNr = 0, - rcNHKs, - rcNHKr - } +initRcvRatchet :: + forall a. (AlgorithmI a, DhAlgorithm a) => PrivateKey a -> RatchetInitParams -> Ratchet a +initRcvRatchet rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = + Ratchet + { rcVersion = e2eEncryptVRange, + rcAD = assocData, + rcDHRs, + rcRK = ratchetKey, + rcSnd = Nothing, + rcRcv = Nothing, + rcPN = 0, + rcNs = 0, + rcNr = 0, + rcNHKs = rcvNextHK, + rcNHKr = sndHK + } data MsgHeader a = MsgHeader { -- | max supported ratchet version @@ -170,10 +293,7 @@ instance Encoding EncMessageHeader where smpEncode EncMessageHeader {ehVersion, ehBody, ehAuthTag, ehIV} = smpEncode (ehVersion, ehBody, ehAuthTag, ehIV) smpP = do - ehVersion <- smpP - ehBody <- smpP - ehAuthTag <- smpP - ehIV <- smpP + (ehVersion, ehBody, ehAuthTag, ehIV) <- smpP pure EncMessageHeader {ehVersion, ehBody, ehAuthTag, ehIV} data EncRatchetMessage = EncRatchetMessage @@ -184,16 +304,14 @@ data EncRatchetMessage = EncRatchetMessage instance Encoding EncRatchetMessage where smpEncode EncRatchetMessage {emHeader, emBody, emAuthTag} = - smpEncode (emHeader, emBody, emAuthTag) + smpEncode (emHeader, emAuthTag, Tail emBody) smpP = do - emHeader <- smpP - emBody <- smpP - emAuthTag <- smpP + (emHeader, emAuthTag, Tail emBody) <- smpP pure EncRatchetMessage {emHeader, emBody, emAuthTag} -rcEncrypt' :: AlgorithmI a => Ratchet a -> Int -> ByteString -> ExceptT CryptoError IO (ByteString, Ratchet a) -rcEncrypt' Ratchet {rcSnd = Nothing} _ _ = throwE CERatchetState -rcEncrypt' rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcNs, rcAD, rcVersion} paddedMsgLen msg = do +rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> ExceptT CryptoError IO (ByteString, Ratchet a) +rcEncrypt Ratchet {rcSnd = Nothing} _ _ = throwE CERatchetState +rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcNs, rcPN, rcAD = Str rcAD, rcVersion} paddedMsgLen msg = do -- state.CKs, mk = KDF_CK(state.CKs) let (ck', mk, iv, ehIV) = chainKdf rcCKs -- enc_header = HENCRYPT(state.HKs, header) @@ -211,34 +329,35 @@ rcEncrypt' rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcNs, rcAD, rc smpEncode MsgHeader { msgMaxVersion = maxVersion rcVersion, - msgDHRs = fst $ rcDHRs rc, - msgPN = rcPN rc, + msgDHRs = publicKey rcDHRs, + msgPN = rcPN, msgNs = rcNs } data SkippedMessage a - = SMMessage (Either CryptoError ByteString) (Ratchet a) + = SMMessage (DecryptResult a) | SMHeader (Maybe RatchetStep) (MsgHeader a) | SMNone data RatchetStep = AdvanceRatchet | SameRatchet deriving (Eq) -type DecryptResult a = (Either CryptoError ByteString, Ratchet a) +type DecryptResult a = (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff) maxSkip :: Word32 maxSkip = 512 -rcDecrypt' :: +rcDecrypt :: forall a. (AlgorithmI a, DhAlgorithm a) => Ratchet a -> + SkippedMsgKeys -> ByteString -> ExceptT CryptoError IO (DecryptResult a) -rcDecrypt' rc@Ratchet {rcRcv, rcMKSkipped, rcAD} msg' = do +rcDecrypt rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError smpP msg' encHdr <- parseE CryptoHeaderError smpP emHeader - -- plaintext = TrySkippedMessageKeysHE(state, enc_header, ciphertext, AD) + -- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD) decryptSkipped encHdr encMsg >>= \case SMNone -> do (rcStep, hdr) <- decryptRcHeader rcRcv encHdr @@ -247,64 +366,63 @@ rcDecrypt' rc@Ratchet {rcRcv, rcMKSkipped, rcAD} msg' = do case rcStep_ of Just rcStep -> decryptRcMessage rcStep hdr encMsg Nothing -> throwE CERatchetHeader - SMMessage msg rc' -> pure (msg, rc') + SMMessage r -> pure r where decryptRcMessage :: RatchetStep -> MsgHeader a -> EncRatchetMessage -> ExceptT CryptoError IO (DecryptResult a) decryptRcMessage rcStep MsgHeader {msgDHRs, msgPN, msgNs} encMsg = do -- if dh_ratchet: - rc' <- ratchetStep rcStep + (rc', smks1) <- ratchetStep rcStep case skipMessageKeys msgNs rc' of - Left e -> pure (Left e, rc') - Right rc''@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr}, rcNr} -> do + Left e -> pure (Left e, rc', smkDiff smks1) + Right (rc''@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr}, rcNr}, smks2) -> do -- state.CKr, mk = KDF_CK(state.CKr) let (rcCKr', mk, iv, _) = chainKdf rcCKr - -- return DECRYPT (mk, ciphertext, CONCAT (AD, enc_header)) + -- return DECRYPT (mk, cipher-text, CONCAT (AD, enc_header)) msg <- decryptMessage (MessageKey mk iv) encMsg -- state . Nr += 1 - pure (msg, rc'' {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr + 1}) - Right rc'' -> pure (Left CERatchetState, rc'') + pure (msg, rc'' {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr + 1}, smkDiff $ smks1 <> smks2) + Right (rc'', smks2) -> do + pure (Left CERatchetState, rc'', smkDiff $ smks1 <> smks2) where - ratchetStep :: RatchetStep -> ExceptT CryptoError IO (Ratchet a) - ratchetStep SameRatchet = pure rc + smkDiff :: SkippedMsgKeys -> SkippedMsgDiff + smkDiff smks = if M.null smks then SMDNoChange else SMDAdd smks + ratchetStep :: RatchetStep -> ExceptT CryptoError IO (Ratchet a, SkippedMsgKeys) + ratchetStep SameRatchet = pure (rc, M.empty) ratchetStep AdvanceRatchet = -- SkipMessageKeysHE(state, header.pn) case skipMessageKeys msgPN rc of Left e -> throwE e - Right rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr} -> do + Right (rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr}, hmks) -> do -- DHRatchetHE(state, header) - rcDHRs' <- liftIO $ generateKeyPair' @a + (_, rcDHRs') <- liftIO $ generateKeyPair' @a -- state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr)) - let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs (snd rcDHRs) + let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs rcDHRs -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr)) - (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs (snd rcDHRs') - pure - rc' - { rcDHRs = rcDHRs', - rcRK = rcRK'', - rcSnd = Just SndRatchet {rcDHRr = msgDHRs, rcCKs = rcCKs', rcHKs = rcNHKs}, - rcRcv = Just RcvRatchet {rcCKr = rcCKr', rcHKr = rcNHKr}, - rcPN = rcNs rc, - rcNs = 0, - rcNr = 0, - rcNHKs = rcNHKs', - rcNHKr = rcNHKr' - } - skipMessageKeys :: Word32 -> Ratchet a -> Either CryptoError (Ratchet a) - skipMessageKeys _ r@Ratchet {rcRcv = Nothing} = Right r - skipMessageKeys untilN r@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr, rcHKr}, rcNr, rcMKSkipped = mkSkipped} + (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' + rc'' = + rc' + { rcDHRs = rcDHRs', + rcRK = rcRK'', + rcSnd = Just SndRatchet {rcDHRr = msgDHRs, rcCKs = rcCKs', rcHKs = rcNHKs}, + rcRcv = Just RcvRatchet {rcCKr = rcCKr', rcHKr = rcNHKr}, + rcPN = rcNs rc, + rcNs = 0, + rcNr = 0, + rcNHKs = rcNHKs', + rcNHKr = rcNHKr' + } + pure (rc'', hmks) + skipMessageKeys :: Word32 -> Ratchet a -> Either CryptoError (Ratchet a, SkippedMsgKeys) + skipMessageKeys _ r@Ratchet {rcRcv = Nothing} = Right (r, M.empty) + skipMessageKeys untilN r@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr, rcHKr}, rcNr} | rcNr > untilN = Left CERatchetDuplicateMessage | rcNr + maxSkip < untilN = Left CERatchetTooManySkipped - | rcNr == untilN = Right r + | rcNr == untilN = Right (r, M.empty) | otherwise = - let mks = fromMaybe M.empty $ M.lookup rcHKr mkSkipped - (rcCKr', rcNr', mks') = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr mks - in Right - r - { rcRcv = Just rr {rcCKr = rcCKr'}, - rcNr = rcNr', - rcMKSkipped = M.insert rcHKr mks' mkSkipped - } - advanceRcvRatchet :: Word32 -> RatchetKey -> Word32 -> SkippedMsgKeys -> (RatchetKey, Word32, SkippedMsgKeys) + let (rcCKr', rcNr', mks) = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr M.empty + r' = r {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr'} + in Right (r', M.singleton rcHKr mks) + advanceRcvRatchet :: Word32 -> RatchetKey -> Word32 -> SkippedHdrMsgKeys -> (RatchetKey, Word32, SkippedHdrMsgKeys) advanceRcvRatchet 0 ck msgNs mks = (ck, msgNs, mks) advanceRcvRatchet n ck msgNs mks = let (ck', mk, iv, _) = chainKdf ck @@ -313,7 +431,7 @@ rcDecrypt' rc@Ratchet {rcRcv, rcMKSkipped, rcAD} msg' = do decryptSkipped :: EncMessageHeader -> EncRatchetMessage -> ExceptT CryptoError IO (SkippedMessage a) decryptSkipped encHdr encMsg = tryDecryptSkipped SMNone $ M.assocs rcMKSkipped where - tryDecryptSkipped :: SkippedMessage a -> [(HeaderKey, SkippedMsgKeys)] -> ExceptT CryptoError IO (SkippedMessage a) + tryDecryptSkipped :: SkippedMessage a -> [(HeaderKey, SkippedHdrMsgKeys)] -> ExceptT CryptoError IO (SkippedMessage a) tryDecryptSkipped SMNone ((hk, mks) : hks) = do tryE (decryptHeader hk encHdr) >>= \case Left CERatchetHeader -> tryDecryptSkipped SMNone hks @@ -327,13 +445,8 @@ rcDecrypt' rc@Ratchet {rcRcv, rcMKSkipped, rcAD} msg' = do | otherwise = Nothing in pure $ SMHeader nextRc hdr Just mk -> do - let mks' = M.delete msgNs mks - mksSkipped - | M.null mks' = M.delete hk rcMKSkipped - | otherwise = M.insert hk mks' rcMKSkipped - rc' = rc {rcMKSkipped = mksSkipped} msg <- decryptMessage mk encMsg - pure $ SMMessage msg rc' + pure $ SMMessage (msg, rc, SMDRemove hk msgNs) tryDecryptSkipped r _ = pure r decryptRcHeader :: Maybe RcvRatchet -> EncMessageHeader -> ExceptT CryptoError IO (RatchetStep, MsgHeader a) decryptRcHeader Nothing hdr = decryptNextHeader hdr @@ -349,19 +462,13 @@ rcDecrypt' rc@Ratchet {rcRcv, rcMKSkipped, rcAD} msg' = do parseE' CryptoHeaderError smpP header decryptMessage :: MessageKey -> EncRatchetMessage -> ExceptT CryptoError IO (Either CryptoError ByteString) decryptMessage (MessageKey mk iv) EncRatchetMessage {emHeader, emBody, emAuthTag} = - -- DECRYPT(mk, ciphertext, CONCAT(AD, enc_header)) + -- DECRYPT(mk, cipher-text, CONCAT(AD, enc_header)) -- TODO add associated data tryE $ decryptAEAD mk iv (rcAD <> emHeader) emBody emAuthTag -initKdf :: (AlgorithmI a, DhAlgorithm a) => ByteString -> PublicKey a -> PrivateKey a -> (RatchetKey, Key, Key) -initKdf salt k pk = - let dhOut = dhSecretBytes' $ dh' k pk - (sk, hk, nhk) = hkdf3 salt dhOut "SimpleXInitRatchet" - in (RatchetKey sk, Key hk, Key nhk) - rootKdf :: (AlgorithmI a, DhAlgorithm a) => RatchetKey -> PublicKey a -> PrivateKey a -> (RatchetKey, RatchetKey, Key) rootKdf (RatchetKey rk) k pk = - let dhOut = dhSecretBytes' $ dh' k pk + let dhOut = dhBytes' $ dh' k pk (rk', ck, nhk) = hkdf3 rk dhOut "SimpleXRootRatchet" in (RatchetKey rk', RatchetKey ck, Key nhk) diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index f737e183a..0f440af26 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -5,10 +5,18 @@ module Simplex.Messaging.Encoding.String ( StrEncoding (..), Str (..), strP_, + strToJSON, + strToJEncoding, + strParseJSON, + base64urlP, ) where import Control.Applicative (optional) +import Data.Aeson (FromJSON (..), ToJSON (..)) +import qualified Data.Aeson as J +import qualified Data.Aeson.Encoding as JE +import qualified Data.Aeson.Types as JT import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import qualified Data.ByteString.Base64.URL as U @@ -16,6 +24,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isAlphaNum) import qualified Data.List.NonEmpty as L +import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Word (Word16) import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util ((<$?>)) @@ -34,6 +43,7 @@ class StrEncoding a where -- base64url encoding/decoding of ByteStrings - the parser only allows non-empty strings instance StrEncoding ByteString where strEncode = U.encode + strDecode = U.decode strP = base64urlP base64urlP :: Parser ByteString @@ -43,11 +53,19 @@ base64urlP = do either fail pure $ U.decode (str <> pad) newtype Str = Str {unStr :: ByteString} + deriving (Eq, Show) instance StrEncoding Str where strEncode = unStr strP = Str <$> A.takeTill (== ' ') <* optional A.space +instance ToJSON Str where + toJSON (Str s) = strToJSON s + toEncoding (Str s) = strToJEncoding s + +instance FromJSON Str where + parseJSON = fmap Str . strParseJSON "Str" + instance StrEncoding a => StrEncoding (Maybe a) where strEncode = maybe "" strEncode strP = optional strP @@ -88,3 +106,12 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin strP_ :: StrEncoding a => Parser a strP_ = strP <* A.space + +strToJSON :: StrEncoding a => a -> J.Value +strToJSON = J.String . decodeLatin1 . strEncode + +strToJEncoding :: StrEncoding a => a -> J.Encoding +strToJEncoding = JE.text . decodeLatin1 . strEncode + +strParseJSON :: StrEncoding a => String -> J.Value -> JT.Parser a +strParseJSON name = J.withText name $ either fail pure . parseAll strP . encodeUtf8 diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index e8d0d8729..2fdcb1fe1 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -31,6 +31,7 @@ module Simplex.Messaging.Protocol ( -- * SMP protocol parameters smpClientVersion, + smpClientVRange, maxMessageLength, e2eEncMessageLength, @@ -109,8 +110,11 @@ import Simplex.Messaging.Util ((<$?>)) import Simplex.Messaging.Version import Test.QuickCheck (Arbitrary (..)) -smpClientVersion :: VersionRange -smpClientVersion = mkVersionRange 1 1 +smpClientVersion :: Version +smpClientVersion = 1 + +smpClientVRange :: VersionRange +smpClientVRange = mkVersionRange 1 smpClientVersion maxMessageLength :: Int maxMessageLength = 15968 @@ -337,9 +341,7 @@ instance Encoding ClientMsgEnvelope where smpEncode ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} = smpEncode (cmHeader, cmNonce, Tail cmEncBody) smpP = do - cmHeader <- smpP - cmNonce <- smpP - cmEncBody <- A.takeByteString + (cmHeader, cmNonce, Tail cmEncBody) <- smpP pure ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} data ClientMessage = ClientMessage PrivHeader ByteString diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 95a433f48..8a62a9e30 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -405,14 +405,15 @@ data ClientHandshake = ClientHandshake instance Encoding ClientHandshake where smpEncode ClientHandshake {smpVersion, keyHash} = smpEncode (smpVersion, keyHash) smpP = do - smpVersion <- smpP - keyHash <- smpP + (smpVersion, keyHash) <- smpP pure ClientHandshake {smpVersion, keyHash} instance Encoding ServerHandshake where smpEncode ServerHandshake {smpVersionRange, sessionId} = smpEncode (smpVersionRange, sessionId) - smpP = ServerHandshake <$> smpP <*> smpP + smpP = do + (smpVersionRange, sessionId) <- smpP + pure ServerHandshake {smpVersionRange, sessionId} -- | Error of SMP encrypted transport over TCP. data TransportError diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index b8194f1e3..7ea2a523b 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -1,7 +1,4 @@ -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} module Simplex.Messaging.Util where @@ -12,20 +9,6 @@ import Data.Bifunctor (first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import UnliftIO.Async -import UnliftIO.Exception (Exception) -import qualified UnliftIO.Exception as E - -newtype InternalException e = InternalException {unInternalException :: e} - deriving (Eq, Show) - -instance Exception e => Exception (InternalException e) - -instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where - withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b - withRunInIO exceptToIO = - withExceptT unInternalException . ExceptT . E.try $ - withRunInIO $ \run -> - exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT) raceAny_ :: MonadUnliftIO m => [m a] -> m () raceAny_ = r [] diff --git a/src/Simplex/Messaging/Version.hs b/src/Simplex/Messaging/Version.hs index 96966bcda..8f19e884d 100644 --- a/src/Simplex/Messaging/Version.hs +++ b/src/Simplex/Messaging/Version.hs @@ -22,10 +22,12 @@ module Simplex.Messaging.Version where import Control.Applicative (optional) +import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Word (Word16) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Util ((<$?>)) pattern VersionRange :: Word16 -> Word16 -> VersionRange pattern VersionRange v1 v2 <- VRange v1 v2 @@ -66,6 +68,15 @@ instance StrEncoding VersionRange where v2 <- maybe (pure v1) (const strP) =<< optional (A.char '-') maybe (fail "invalid version range") pure $ safeVersionRange v1 v2 +instance ToJSON VersionRange where + toJSON (VRange v1 v2) = toJSON (v1, v2) + toEncoding (VRange v1 v2) = toEncoding (v1, v2) + +instance FromJSON VersionRange where + parseJSON v = + (\(v1, v2) -> maybe (Left "bad VersionRange") Right $ safeVersionRange v1 v2) + <$?> parseJSON v + class VersionI a where type VersionRangeT a version :: a -> Version diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 07fbf5f01..93c073dfb 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -9,9 +9,9 @@ import Data.ByteString (ByteString) import Network.HTTP.Types (urlEncode) import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (e2eEncryptVRange) +import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (smpClientVersion) +import Simplex.Messaging.Protocol (smpClientVRange) import Simplex.Messaging.Version import Test.Hspec @@ -31,7 +31,7 @@ queue = SMPQueueUri { smpServer = srv, senderId = "\223\142z\251", - clientVersionRange = smpClientVersion, + clientVRange = smpClientVRange, dhPublicKey = testDhKey } @@ -55,20 +55,20 @@ connReqData = testDhPubKey :: C.PublicKeyX448 testDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U=" -testE2ERatchetParams :: E2ERatchetParamsUri +testE2ERatchetParams :: E2ERatchetParamsUri 'C.X448 testE2ERatchetParams = E2ERatchetParamsUri e2eEncryptVRange testDhPubKey testDhPubKey -testE2ERatchetParams13 :: E2ERatchetParamsUri +testE2ERatchetParams13 :: E2ERatchetParamsUri 'C.X448 testE2ERatchetParams13 = E2ERatchetParamsUri (mkVersionRange 1 3) testDhPubKey testDhPubKey connectionRequest :: AConnectionRequestUri connectionRequest = - ACRU SCMInvitation $ + ACR SCMInvitation $ CRInvitationUri connReqData testE2ERatchetParams connectionRequest12 :: AConnectionRequestUri connectionRequest12 = - ACRU SCMInvitation $ + ACR SCMInvitation $ CRInvitationUri connReqData {crAgentVRange = mkVersionRange 1 2, crSmpQueues = [queue, queue]} testE2ERatchetParams13 @@ -79,7 +79,7 @@ connectionRequestTests = it "should serialize SMP queue URIs" $ do strEncode (queue :: SMPQueueUri) {smpServer = srv {port = Nothing}} `shouldBe` "smp://1234-w==@smp.simplex.im/3456-w==#" <> testDhKeyStr - strEncode queue {clientVersionRange = mkVersionRange 1 2} + strEncode queue {clientVRange = mkVersionRange 1 2} `shouldBe` "smp://1234-w==@smp.simplex.im:5223/3456-w==#" <> testDhKeyStr it "should parse SMP queue URIs" $ do strDecode ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1&dh=" <> testDhKeyStr) @@ -91,9 +91,9 @@ connectionRequestTests = strDecode ("smp://1234-w==@smp.simplex.im:5223/3456-w==#" <> testDhKeyStr <> "/?v=1&extra_param=abc") `shouldBe` Right queue strDecode ("smp://1234-w==@smp.simplex.im:5223/3456-w==#/?extra_param=abc&v=1-2&dh=" <> testDhKeyStr) - `shouldBe` Right queue {clientVersionRange = mkVersionRange 1 2} + `shouldBe` Right queue {clientVRange = mkVersionRange 1 2} strDecode ("smp://1234-w==@smp.simplex.im:5223/3456-w==#" <> testDhKeyStr <> "/?v=1-2&extra_param=abc") - `shouldBe` Right queue {clientVersionRange = mkVersionRange 1 2} + `shouldBe` Right queue {clientVRange = mkVersionRange 1 2} it "should serialize connection requests" $ do strEncode connectionRequest `shouldBe` "https://simplex.chat/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23" diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index e4e3508dd..71a2d6bd1 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -11,14 +11,17 @@ module AgentTests.DoubleRatchetTests where import Control.Concurrent.STM import Control.Monad.Except -import Crypto.Random (getRandomBytes) +import Data.Aeson (FromJSON, ToJSON) +import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.Map.Strict as M import Simplex.Messaging.Crypto (Algorithm (..), AlgorithmI, CryptoError, DhAlgorithm) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Util ((<$$>)) import Test.Hspec doubleRatchetTests :: Spec @@ -35,12 +38,20 @@ doubleRatchetTests = do withRatchets @X25519 testManyMessages it "should allow skipped after ratchet advance" $ do withRatchets @X25519 testSkippedAfterRatchetAdvance + it "should encode/decode ratchet as JSON" $ do + testKeyJSON C.SX25519 + testKeyJSON C.SX448 + testRatchetJSON C.SX25519 + testRatchetJSON C.SX448 + it "should agree the same ratchet parameters" $ do + testX3dh C.SX25519 + testX3dh C.SX448 paddedMsgLen :: Int paddedMsgLen = 100 fullMsgLen :: Int -fullMsgLen = 1 + fullHeaderLen + 1 + paddedMsgLen + C.authTagSize +fullMsgLen = 1 + fullHeaderLen + C.authTagSize + paddedMsgLen testMessageHeader :: Expectation testMessageHeader = do @@ -51,7 +62,7 @@ testMessageHeader = do pattern Decrypted :: ByteString -> Either CryptoError (Either CryptoError ByteString) pattern Decrypted msg <- Right (Right msg) -type TestRatchets a = (AlgorithmI a, DhAlgorithm a) => TVar (Ratchet a) -> TVar (Ratchet a) -> IO () +type TestRatchets a = (AlgorithmI a, DhAlgorithm a) => TVar (Ratchet a, SkippedMsgKeys) -> TVar (Ratchet a, SkippedMsgKeys) -> IO () testEncryptDecrypt :: TestRatchets a testEncryptDecrypt alice bob = do @@ -137,54 +148,85 @@ testSkippedAfterRatchetAdvance alice bob = do Decrypted "b11" <- decrypt alice b11 pure () -(#>) :: (AlgorithmI a, DhAlgorithm a) => (TVar (Ratchet a), ByteString) -> TVar (Ratchet a) -> Expectation +testKeyJSON :: forall a. AlgorithmI a => C.SAlgorithm a -> IO () +testKeyJSON _ = do + (k, pk) <- C.generateKeyPair' @a + testEncodeDecode k + testEncodeDecode pk + +testRatchetJSON :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testRatchetJSON _ = do + (alice, bob) <- initRatchets @a + testEncodeDecode alice + testEncodeDecode bob + +testEncodeDecode :: (Eq a, Show a, ToJSON a, FromJSON a) => a -> Expectation +testEncodeDecode x = do + let j = J.encode x + x' = J.eitherDecode' j + x' `shouldBe` Right x + +testX3dh :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testX3dh _ = do + (pkBob1, pkBob2, e2eBob) <- generateE2EParams @a e2eEncryptVersion + (pkAlice1, pkAlice2, e2eAlice) <- generateE2EParams @a e2eEncryptVersion + let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice + paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + paramsAlice `shouldBe` paramsBob + +(#>) :: (AlgorithmI a, DhAlgorithm a) => (TVar (Ratchet a, SkippedMsgKeys), ByteString) -> TVar (Ratchet a, SkippedMsgKeys) -> Expectation (alice, msg) #> bob = do Right msg' <- encrypt alice msg Decrypted msg'' <- decrypt bob msg' msg'' `shouldBe` msg -withRatchets :: forall a. (AlgorithmI a, DhAlgorithm a) => (TVar (Ratchet a) -> TVar (Ratchet a) -> IO ()) -> Expectation +withRatchets :: forall a. (AlgorithmI a, DhAlgorithm a) => (TVar (Ratchet a, SkippedMsgKeys) -> TVar (Ratchet a, SkippedMsgKeys) -> IO ()) -> Expectation withRatchets test = do (a, b) <- initRatchets @a - alice <- newTVarIO a - bob <- newTVarIO b + alice <- newTVarIO (a, M.empty) + bob <- newTVarIO (b, M.empty) test alice bob `shouldReturn` () initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a) initRatchets = do - salt <- getRandomBytes 16 - (ak, apk) <- C.generateKeyPair' - (bk, bpk) <- C.generateKeyPair' - bob <- initSndRatchet' ak bpk salt "bob -> alice" - alice <- initRcvRatchet' bk (ak, apk) salt "bob -> alice" + (pkBob1, pkBob2, e2eBob) <- generateE2EParams e2eEncryptVersion + (pkAlice1, pkAlice2, e2eAlice) <- generateE2EParams e2eEncryptVersion + let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice + paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + (_, pkBob3) <- C.generateKeyPair' + let bob = initSndRatchet (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet pkAlice2 paramsAlice pure (alice, bob) -encrypt_ :: AlgorithmI a => Ratchet a -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a)) -encrypt_ rc msg = - runExceptT (rcEncrypt' rc paddedMsgLen msg) +encrypt_ :: AlgorithmI a => (Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) +encrypt_ (rc, _) msg = + runExceptT (rcEncrypt rc paddedMsgLen msg) >>= either (pure . Left) checkLength where - checkLength r@(msg', _) = do + checkLength (msg', rc') = do B.length msg' `shouldBe` fullMsgLen - pure $ Right r + pure $ Right (msg', rc', SMDNoChange) -decrypt_ :: (AlgorithmI a, DhAlgorithm a) => Ratchet a -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a)) -decrypt_ rc msg = runExceptT $ rcDecrypt' rc msg +decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff)) +decrypt_ (rc, smks) msg = runExceptT $ rcDecrypt rc smks msg -encrypt :: AlgorithmI a => TVar (Ratchet a) -> ByteString -> IO (Either CryptoError ByteString) +encrypt :: AlgorithmI a => TVar (Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString) encrypt = withTVar encrypt_ -decrypt :: (AlgorithmI a, DhAlgorithm a) => TVar (Ratchet a) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) +decrypt :: (AlgorithmI a, DhAlgorithm a) => TVar (Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) decrypt = withTVar decrypt_ withTVar :: - (Ratchet a -> ByteString -> IO (Either e (r, Ratchet a))) -> - TVar (Ratchet a) -> + AlgorithmI a => + ((Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e (r, Ratchet a, SkippedMsgDiff))) -> + TVar (Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e r) withTVar op rcVar msg = readTVarIO rcVar - >>= (`op` msg) + >>= (\(rc, smks) -> applyDiff smks <$$> (testEncodeDecode rc >> op (rc, smks) msg)) >>= \case - Right (res, rc') -> atomically (writeTVar rcVar rc') >> pure (Right res) + Right (res, rc', smks') -> atomically (writeTVar rcVar (rc', smks')) >> pure (Right res) Left e -> pure $ Left e + where + applyDiff smks (res, rc', smDiff) = (res, rc', applySMDiff smks smDiff) diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 1d8d21812..350c6f808 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -21,6 +21,7 @@ import Data.Word (Word32) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.QQ (sql) import SMPClient (testKeyHash) +import Simplex.Messaging.Agent.ExceptT () import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite