use throwE instead of throwError (#1187)

* use throwE instead of throwError

* test delay
This commit is contained in:
Evgeny Poberezkin
2024-06-05 11:20:50 +01:00
committed by GitHub
parent 44b8f265ae
commit 3dab330480
18 changed files with 155 additions and 145 deletions
+18 -17
View File
@@ -30,6 +30,7 @@ import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class
import Control.Monad.Trans.Except
import Crypto.Random (ChaChaDRG)
import qualified Data.Aeson as J
import Data.ByteString (ByteString)
@@ -106,9 +107,9 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
action <- liftIO $ runClient c r hostKeys
-- wait for the port to make invitation
portNum <- atomically $ readTMVar startedPort
signedInv@RCSignedInvitation {invitation} <- maybe (throwError RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum
signedInv@RCSignedInvitation {invitation} <- maybe (throwE RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum
when multicast $ case knownHost of
Nothing -> throwError RCENewController
Nothing -> throwE RCENewController
Just KnownHostPairing {hostDhPubKey} -> do
ann <- liftIO . async . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation
atomically $ putTMVar announcer ann
@@ -117,7 +118,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
findCtrlAddress :: ExceptT RCErrorType IO (NonEmpty RCCtrlAddress)
findCtrlAddress = do
found' <- liftIO $ getLocalAddress rcAddrPrefs_
maybe (throwError RCENoLocalAddress) pure $ L.nonEmpty found'
maybe (throwE RCENoLocalAddress) pure $ L.nonEmpty found'
mkClient :: IO RCHClient_
mkClient = do
startedPort <- newEmptyTMVarIO
@@ -211,10 +212,10 @@ prepareHostSession
let sharedKey = C.dh' dhPubKey dhPrivKey
helloBody <- liftEitherWith (const RCEDecrypt) $ C.cbDecrypt sharedKey nonce encBody
hostHello@RCHostHello {v, ca, kem = kemPubKey} <- liftEitherWith RCESyntax $ J.eitherDecodeStrict helloBody
unless (ca == tlsHostFingerprint) $ throwError RCEIdentity
unless (ca == tlsHostFingerprint) $ throwE RCEIdentity
(kemCiphertext, kemSharedKey) <- liftIO $ sntrup761Enc drg kemPubKey
let hybridKey = kemHybridSecret dhPubKey dhPrivKey kemSharedKey
unless (isCompatible v supportedRCPVRange) $ throwError RCEVersion
unless (isCompatible v supportedRCPVRange) $ throwE RCEVersion
let keys = HostSessKeys {hybridKey, idPrivKey, sessPrivKey}
knownHost' <- updateKnownHost ca dhPubKey
let ctrlHello = RCCtrlHello {}
@@ -227,7 +228,7 @@ prepareHostSession
updateKnownHost :: C.KeyHash -> C.PublicKeyX25519 -> ExceptT RCErrorType IO KnownHostPairing
updateKnownHost ca hostDhPubKey = case knownHost_ of
Just h -> do
unless (hostFingerprint h == tlsHostFingerprint) . throwError $
unless (hostFingerprint h == tlsHostFingerprint) . throwE $
RCEInternal "TLS host CA is different from host pairing, should be caught in TLS handshake"
pure (h :: KnownHostPairing) {hostDhPubKey}
Nothing -> pure KnownHostPairing {hostFingerprint = ca, hostDhPubKey}
@@ -257,7 +258,7 @@ connectRCCtrl drg (RCVerifiedInvitation inv@RCInvitation {ca, idkey}) pairing_ h
pure RCCtrlPairing {caKey, caCert, ctrlFingerprint = ca, idPubKey = idkey, dhPrivKey, prevDhPrivKey = Nothing}
updateCtrlPairing :: RCCtrlPairing -> ExceptT RCErrorType IO RCCtrlPairing
updateCtrlPairing pairing@RCCtrlPairing {ctrlFingerprint, idPubKey, dhPrivKey = currDhPrivKey} = do
unless (ca == ctrlFingerprint && idPubKey == idkey) $ throwError RCEIdentity
unless (ca == ctrlFingerprint && idPubKey == idkey) $ throwE RCEIdentity
(_, dhPrivKey) <- atomically $ C.generateKeyPair drg
pure pairing {dhPrivKey, prevDhPrivKey = Just currDhPrivKey}
@@ -278,7 +279,7 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
clientCredentials <-
liftIO (genTLSCredentials drg caKey caCert) >>= \case
TLS.Credentials (creds : _) -> pure $ Just creds
_ -> throwError $ RCEInternal "genTLSCredentials must generate credentials"
_ -> throwE $ RCEInternal "genTLSCredentials must generate credentials"
let clientConfig = defaultTransportClientConfig {clientCredentials}
ExceptT . runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> runExceptT $ do
-- pump socket to detect connection problems
@@ -307,7 +308,7 @@ catchRCError = catchAllErrors (RCEException . show)
{-# INLINE catchRCError #-}
putRCError :: ExceptT RCErrorType IO a -> TMVar (Either RCErrorType b) -> ExceptT RCErrorType IO a
a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwError e
a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
sendRCPacket :: Encoding a => TLS -> a -> ExceptT RCErrorType IO ()
sendRCPacket tls pkt = do
@@ -317,7 +318,7 @@ sendRCPacket tls pkt = do
receiveRCPacket :: Encoding a => TLS -> ExceptT RCErrorType IO a
receiveRCPacket tls = do
b <- liftIO $ cGet tls xrcpBlockSize
when (B.length b /= xrcpBlockSize) $ throwError RCEBlockSize
when (B.length b /= xrcpBlockSize) $ throwE RCEBlockSize
b' <- liftEitherWith (const RCEBlockSize) $ C.unPad b
liftEitherWith RCESyntax $ smpDecode b'
@@ -329,7 +330,7 @@ prepareHostHello
hostAppInfo = do
logDebug "Preparing session"
case compatibleVersion v supportedRCPVRange of
Nothing -> throwError RCEVersion
Nothing -> throwE RCEVersion
Just (Compatible v') -> do
nonce <- liftIO . atomically $ C.randomCbNonce drg
(kemPubKey, kemPrivKey) <- liftIO $ sntrup761Keypair drg
@@ -355,7 +356,7 @@ prepareCtrlSession
pure CtrlSessKeys {hybridKey, idPubKey, sessPubKey = skey}
RCCtrlEncError {nonce, encMessage} -> do
message <- liftEitherWith (const RCEDecrypt) $ C.cbDecrypt sharedKey nonce encMessage
throwError $ RCECtrlError $ T.unpack $ safeDecodeUtf8 message
throwE $ RCECtrlError $ T.unpack $ safeDecodeUtf8 message
-- * Multicast discovery
@@ -382,7 +383,7 @@ discoverRCCtrl subscribers pairings =
r@(_, RCVerifiedInvitation RCInvitation {host}) <- findRCCtrlPairing pairings encInvitation
case source of
SockAddrInet _ ha | THIPv4 (hostAddressToTuple ha) == host -> pure ()
_ -> throwError RCEInvitation
_ -> throwE RCEInvitation
pure r
where
loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a
@@ -392,8 +393,8 @@ findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErro
findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do
(pairing, signedInvStr) <- liftEither $ decrypt (L.toList pairings)
signedInv <- liftEitherWith RCESyntax $ strDecode signedInvStr
inv@(RCVerifiedInvitation RCInvitation {dh = invDh}) <- maybe (throwError RCEInvitation) pure $ verifySignedInvitation signedInv
unless (invDh == dhPubKey) $ throwError RCEInvitation
inv@(RCVerifiedInvitation RCInvitation {dh = invDh}) <- maybe (throwE RCEInvitation) pure $ verifySignedInvitation signedInv
unless (invDh == dhPubKey) $ throwE RCEInvitation
pure (pairing, inv)
where
decrypt :: [RCCtrlPairing] -> Either RCErrorType (RCCtrlPairing, ByteString)
@@ -433,7 +434,7 @@ rcEncryptBody drg hybridKey s = do
rcDecryptBody :: KEMHybridSecret -> C.CbNonce -> LazyByteString -> ExceptT RCErrorType IO LazyByteString
rcDecryptBody hybridKey nonce ct = do
let len = LB.length ct - 16
when (len < 0) $ throwError RCEDecrypt
when (len < 0) $ throwE RCEDecrypt
(ok, s) <- liftEitherWith (const RCEDecrypt) $ LC.kcbDecryptTailTag hybridKey nonce len ct
unless ok $ throwError RCEDecrypt
unless ok $ throwE RCEDecrypt
pure s