xftp: streaming file encryption/decryption to avoid memory spikes (#687)

* xftp: streaming file decryption to avoid memory spikes

* refactor, enable tests

* streaming encryption

* refactor
This commit is contained in:
Evgeny Poberezkin
2023-03-16 13:57:21 +00:00
committed by GitHub
parent a0eb53b891
commit bab689099f
4 changed files with 136 additions and 42 deletions
+5 -13
View File
@@ -25,7 +25,6 @@ import Control.Monad.Except
import Crypto.Random (getRandomBytes)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Char (toLower)
@@ -91,7 +90,7 @@ newtype CLIError = CLIError String
-- | Converts an 'FTCryptoError' into a 'CLIError' carrying the message shown to the CLI user.
-- NOTE(review): this listing contains alternatives for both FTCEDecryptionError and
-- FTCECryptoError with the same message; presumably one of them is the pre-rename line
-- retained by the diff rendering - confirm before relying on either name.
cliCryptoError :: FTCryptoError -> CLIError
cliCryptoError = \case
FTCEDecryptionError e -> CLIError $ "Error decrypting file: " <> show e
FTCECryptoError e -> CLIError $ "Error decrypting file: " <> show e
-- the header error is already a String, it is appended as is
FTCEInvalidHeader e -> CLIError $ "Invalid file header: " <> e
FTCEInvalidAuthTag -> CLIError "Error decrypting file: incorrect auth tag"
-- the error is a String passed through 'show', so it is rendered quoted and escaped
FTCEFileIOError e -> CLIError $ "File IO error: " <> show e
@@ -261,7 +260,7 @@ cliSendFile :: SendOptions -> ExceptT CLIError IO ()
cliSendFile SendOptions {filePath, outputDir, numRecipients, xftpServers, retryCount, tempPath, verbose} = do
let (_, fileName) = splitFileName filePath
liftIO $ printNoNewLine "Encrypting file..."
(encPath, fdRcv, fdSnd, chunkSpecs, encSize) <- encryptFile fileName
(encPath, fdRcv, fdSnd, chunkSpecs, encSize) <- encryptFileForUpload fileName
liftIO $ printNoNewLine "Uploading file..."
uploadedChunks <- newTVarIO []
sentChunks <- uploadFile chunkSpecs uploadedChunks encSize
@@ -276,8 +275,8 @@ cliSendFile SendOptions {filePath, outputDir, numRecipients, xftpServers, retryC
putStrLn "Pass file descriptions to the recipient(s):"
forM_ fdRcvPaths putStrLn
where
encryptFile :: String -> ExceptT CLIError IO (FilePath, FileDescription 'FRecipient, FileDescription 'FSender, [XFTPChunkSpec], Int64)
encryptFile fileName = do
encryptFileForUpload :: String -> ExceptT CLIError IO (FilePath, FileDescription 'FRecipient, FileDescription 'FSender, [XFTPChunkSpec], Int64)
encryptFileForUpload fileName = do
fileSize <- fromInteger <$> getFileSize filePath
when (fileSize > maxFileSize) $ throwError $ CLIError $ "Files bigger than " <> maxFileSizeStr <> " are not supported"
encPath <- getEncPath tempPath "xftp"
@@ -289,20 +288,13 @@ cliSendFile SendOptions {filePath, outputDir, numRecipients, xftpServers, retryC
defChunkSize = head chunkSizes
chunkSizes' = map fromIntegral chunkSizes
encSize = sum chunkSizes'
encrypt fileHdr key nonce fileSize' encSize encPath
withExceptT (CLIError . show) $ encryptFile filePath fileHdr key nonce fileSize' encSize encPath
digest <- liftIO $ LC.sha512Hash <$> LB.readFile encPath
let chunkSpecs = prepareChunkSpecs encPath chunkSizes
fdRcv = FileDescription {party = SFRecipient, size = FileSize encSize, digest = FileDigest digest, key, nonce, chunkSize = FileSize defChunkSize, chunks = []}
fdSnd = FileDescription {party = SFSender, size = FileSize encSize, digest = FileDigest digest, key, nonce, chunkSize = FileSize defChunkSize, chunks = []}
logInfo $ "encrypted file to " <> tshow encPath
pure (encPath, fdRcv, fdSnd, chunkSpecs, encSize)
where
encrypt :: ByteString -> C.SbKey -> C.CbNonce -> Int64 -> Int64 -> FilePath -> ExceptT CLIError IO ()
encrypt fileHdr key nonce fileSize' encSize encFile = do
f <- liftIO $ LB.readFile filePath
let f' = LB.fromStrict fileHdr <> f
c <- liftEither $ first (CLIError . show) $ LC.sbEncryptTailTag key nonce f' fileSize' $ encSize - authTagSize
liftIO $ LB.writeFile encFile c
uploadFile :: [XFTPChunkSpec] -> TVar [Int64] -> Int64 -> ExceptT CLIError IO [SentFileChunk]
uploadFile chunks uploadedChunks encSize = do
a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig
+90 -17
View File
@@ -1,42 +1,115 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.FileTransfer.Crypto where
import Control.Monad.Except
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import qualified Data.ByteArray as BA
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Int (Int64)
import Simplex.FileTransfer.Types (FileHeader (..), authTagSize)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Lazy (LazyByteString)
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Encoding
import Simplex.Messaging.Util (liftEitherWith)
import UnliftIO
import UnliftIO.Directory (removeFile)
-- | Encrypts the file at @filePath@ into @encFile@ without loading it into memory.
--
-- The output consists of, in order: the encrypted encoding of @fileSize'@ followed by
-- @fileHdr@, the encrypted file content, encrypted @'#'@ padding, and the auth tag
-- computed over everything that was encrypted.
--
-- @fileSize'@ is the size that is encoded at the start of the output; the number of bytes
-- read from @filePath@ is @fileSize'@ minus the length of @fileHdr@.
-- @encSize@ is the total size of the output, used only to compute the padding length.
--
-- Fails with 'FTCECryptoError' if the secret box cannot be initialized, and with
-- 'FTCEFileIOError' if fewer bytes than requested could be obtained for a chunk.
encryptFile :: FilePath -> ByteString -> C.SbKey -> C.CbNonce -> Int64 -> Int64 -> FilePath -> ExceptT FTCryptoError IO ()
encryptFile filePath fileHdr key nonce fileSize' encSize encFile = do
  sbInitial <- liftEitherWith FTCECryptoError $ LC.sbInit key nonce
  withFile filePath ReadMode $ \src ->
    withFile encFile WriteMode $ \dst -> do
      let (encHdr, !sbHdr) = LC.sbEncryptChunk sbInitial $ smpEncode fileSize' <> fileHdr
          -- the header is counted in fileSize', so it is excluded from the bytes read from the source
          contentLen = fileSize' - fromIntegral (B.length fileHdr)
          -- 8 is subtracted for the encoded size written before the header
          padLen = encSize - authTagSize - fileSize' - 8
      liftIO $ B.hPut dst encHdr
      sbContent <- encryptStream (B.hGet src . fromIntegral) dst (sbHdr, contentLen)
      sbPadded <- encryptStream (pure . padding) dst (sbContent, padLen)
      liftIO . B.hPut dst . BA.convert $ LC.sbAuth sbPadded
  where
    padding n = B.replicate (fromIntegral n) '#'
    -- encrypts the requested number of bytes, obtained from the passed action
    -- in pieces of at most 65536 bytes, and writes the result to the handle
    encryptStream :: (Int64 -> IO ByteString) -> Handle -> (LC.SbState, Int64) -> ExceptT FTCryptoError IO LC.SbState
    encryptStream getChunk dst = loop
      where
        loop (!sb, !remaining)
          | remaining == 0 = pure sb
          | otherwise = do
            let wanted = min remaining 65536
            chunk <- liftIO $ getChunk wanted
            if B.length chunk == fromIntegral wanted
              then do
                let (encrypted, sbNext) = LC.sbEncryptChunk sb chunk
                liftIO $ B.hPut dst encrypted
                loop (sbNext, remaining - wanted)
              else throwError $ FTCEFileIOError "encrypting file: unexpected EOF"
-- | Decrypts a file from its encrypted chunk files and writes the result to the path
-- returned by @getFilePath@ for the file name found in the decrypted file header.
-- Returns that path. @encSize@ is the total size of the encrypted file, including the auth tag.
-- NOTE(review): two bodies for decryptChunks follow the signature in this listing
-- (a whole-file one and a streaming one); presumably the first is the pre-change
-- version retained by the diff rendering - confirm.
decryptChunks :: Int64 -> [FilePath] -> C.SbKey -> C.CbNonce -> (String -> ExceptT String IO String) -> ExceptT FTCryptoError IO FilePath
-- Whole-file variant: concatenates all chunk files with readChunks
-- and decrypts them with a single sbDecryptTailTag call.
decryptChunks encSize chunkPaths key nonce getFilePath = do
(authOk, f) <- liftEither . first FTCEDecryptionError . LC.sbDecryptTailTag key nonce (encSize - authTagSize) =<< liftIO (readChunks chunkPaths)
-- the file header is parsed from the first 1024 bytes of the decrypted data
let (fileHdr, f') = LB.splitAt 1024 f
-- withFile encPath ReadMode $ \r -> do
-- fileHdr <- liftIO $ B.hGet r 1024
case A.parse smpP $ LB.toStrict fileHdr of
A.Fail _ _ e -> throwError $ FTCEInvalidHeader e
A.Partial _ -> throwError $ FTCEInvalidHeader "incomplete"
A.Done rest FileHeader {fileName} -> do
path <- withExceptT FTCEFileIOError $ getFilePath fileName
-- the bytes left unparsed after the header belong to the file content
liftIO $ LB.writeFile path $ LB.fromStrict rest <> f'
-- here the file is written before the auth tag result is checked, so it is removed on failure
unless authOk $ do
removeFile path
throwError FTCEInvalidAuthTag
pure path
-- Streaming variant: there is nothing to decrypt without chunks.
decryptChunks _ [] _ _ _ = throwError $ FTCEInvalidHeader "empty"
-- the remaining paths are reversed to separate the last chunk, which ends with the auth tag
decryptChunks encSize (chPath : chPaths) key nonce getFilePath = case reverse chPaths of
-- single chunk: it is decrypted as a whole and the auth tag is checked before anything is written
[] -> do
(!authOk, !f) <- liftEither . first FTCECryptoError . LC.sbDecryptTailTag key nonce (encSize - authTagSize) =<< liftIO (LB.readFile chPath)
unless authOk $ throwError FTCEInvalidAuthTag
(FileHeader {fileName}, !f') <- parseFileHeader f
path <- withExceptT FTCEFileIOError $ getFilePath fileName
liftIO $ LB.writeFile path f'
pure path
-- several chunks: each chunk is decrypted and written to the file in turn,
-- passing the secret box state and the number of bytes processed so far
lastPath : chPaths' -> do
(state, expectedLen, ch) <- decryptFirstChunk
(FileHeader {fileName}, ch') <- parseFileHeader ch
path <- withExceptT FTCEFileIOError $ getFilePath fileName
authOk <- liftIO . withFile path WriteMode $ \h -> do
liftIO $ LB.hPut h ch'
-- chPaths' is in reverse order, it is reversed again to process the middle chunks in order
state' <- foldM (decryptChunk h) state $ reverse chPaths'
decryptLastChunk h state' expectedLen
-- the file is already written when the auth tag is checked, so it is removed on failure
unless authOk $ do
removeFile path
throwError FTCEInvalidAuthTag
pure path
where
-- initializes the secret box, decrypts the first chunk and splits off the length
-- at its start; returns the state with the number of bytes that follow the length,
-- the expected length and these bytes
decryptFirstChunk = do
sb <- liftEitherWith FTCECryptoError $ LC.sbInit key nonce
ch <- liftIO $ LB.readFile chPath
let (ch1, !sb') = LC.sbDecryptChunkLazy sb ch
(!expectedLen, ch2) <- liftEitherWith FTCECryptoError $ LC.splitLen ch1
let len1 = LB.length ch2
pure ((sb', len1), expectedLen, ch2)
-- decrypts a middle chunk and writes it to the handle, adding the chunk length to the running total
decryptChunk h (!sb, !len) chPth = do
ch <- LB.readFile chPth
let len' = len + LB.length ch
(ch', sb') = LC.sbDecryptChunkLazy sb ch
LB.hPut h ch'
pure (sb', len')
-- decrypts the last chunk without its trailing authTagSize bytes, which are the auth tag,
-- and returns whether this tag matches the one computed from the final state
decryptLastChunk h (!sb, !len) expectedLen = do
ch <- LB.readFile lastPath
let (ch1, tag') = LB.splitAt (LB.length ch - authTagSize) ch
tag'' = LB.toStrict tag'
(ch2, sb') = LC.sbDecryptChunkLazy sb ch1
len' = len + LB.length ch2
-- this is the same as taking (expectedLen - len) bytes:
-- only the bytes up to the expected length are written, the rest is dropped
ch3 = LB.take (LB.length ch2 - len' + expectedLen) ch2
tag :: ByteString = BA.convert (LC.sbAuth sb')
LB.hPut h ch3
-- a tag of a wrong length is rejected before the comparison with constEq
pure $ B.length tag'' == 16 && BA.constEq tag'' tag
where
-- parses the file header from the first 1024 bytes of the passed string;
-- returns the header and the rest of the string, including the unparsed bytes after the header
parseFileHeader :: LazyByteString -> ExceptT FTCryptoError IO (FileHeader, LazyByteString)
parseFileHeader s = do
let (hdrStr, s') = LB.splitAt 1024 s
case A.parse smpP $ LB.toStrict hdrStr of
A.Fail _ _ e -> throwError $ FTCEInvalidHeader e
A.Partial _ -> throwError $ FTCEInvalidHeader "incomplete"
A.Done rest hdr -> pure (hdr, LB.fromStrict rest <> s')
-- | Reads the files at the passed paths with lazy 'LB.readFile' and returns
-- their contents concatenated in the order of the paths (empty for no paths).
readChunks :: [FilePath] -> IO LB.ByteString
readChunks paths = foldM appendFile' "" paths
  where
    appendFile' acc path = do
      contents <- LB.readFile path
      pure $ acc <> contents
-- | Errors returned by file encryption and decryption.
-- NOTE(review): this listing contains two alternatives for the first constructor and
-- for the deriving clause; presumably the first of each is the pre-change line
-- retained by the diff rendering - confirm.
data FTCryptoError
= FTCEDecryptionError C.CryptoError
-- error returned by the secret box functions (initialization, decryption, parsing the length)
= FTCECryptoError C.CryptoError
-- the file header could not be parsed, or there were no chunks to decrypt
| FTCEInvalidHeader String
-- the auth tag at the end of the encrypted file does not match the computed one
| FTCEInvalidAuthTag
-- the destination file path could not be obtained, or the source file ended early during encryption
| FTCEFileIOError String
deriving (Show, Eq)
-- NOTE(review): the Exception instance is presumably needed to use UnliftIO functions
-- in ExceptT FTCryptoError IO - confirm against the MonadUnliftIO instance used by the project
deriving (Show, Eq, Exception)
+37 -12
View File
@@ -11,6 +11,7 @@ module Simplex.Messaging.Crypto.Lazy
sha512Hash,
pad,
unPad,
splitLen,
sbEncrypt,
sbDecrypt,
sbEncryptTailTag,
@@ -21,7 +22,10 @@ module Simplex.Messaging.Crypto.Lazy
sbInit,
sbEncryptChunk,
sbDecryptChunk,
sbEncryptChunkLazy,
sbDecryptChunkLazy,
sbAuth,
LazyByteString,
)
where
@@ -30,6 +34,7 @@ import qualified Crypto.Error as CE
import Crypto.Hash (Digest, hashlazy)
import Crypto.Hash.Algorithms (SHA256, SHA512)
import qualified Crypto.MAC.Poly1305 as Poly1305
import Data.Bifunctor (first)
import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as S
@@ -37,6 +42,7 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import qualified Data.ByteString.Lazy.Internal as LB
import Data.Composition ((.:.))
import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty (..))
import Foreign (sizeOf)
@@ -77,11 +83,14 @@ fastReplicate n c
-- this function does not validate the length of the message to avoid consuming all chunks,
-- so it can return a shorter string than expected
unPad :: LazyByteString -> Either CryptoError LazyByteString
unPad padded
unPad = fmap snd . splitLen
splitLen :: LazyByteString -> Either CryptoError (Int64, LazyByteString)
splitLen padded
| LB.length lenStr == 8 = case smpDecode $ LB.toStrict lenStr of
Right len
| len < 0 -> Left CryptoInvalidMsgError
| otherwise -> Right $ LB.take len rest
| otherwise -> Right (len, LB.take len rest)
Left _ -> Left CryptoInvalidMsgError
| otherwise = Left CryptoInvalidMsgError
where
@@ -111,9 +120,9 @@ sbDecrypt (SbKey key) (CbNonce nonce) packet
-- Initializes the secret box state from the secret and nonce, passes every chunk of the
-- lazy message via sbProcess, and returns the auth tag of the final state followed by
-- the processed chunks in their original order (the tag is the head of the list).
-- Fails if the state cannot be initialized.
-- NOTE(review): this listing contains two definitions of run (one folding the chunks
-- locally, one using secretBoxLazy_); presumably the first is the pre-change version
-- retained by the diff rendering - confirm.
secretBox :: ByteArrayAccess key => (SbState -> ByteString -> (ByteString, SbState)) -> key -> ByteString -> LazyByteString -> Either CryptoError (NonEmpty ByteString)
secretBox sbProcess secret nonce msg = run <$> sbInit_ secret nonce
where
process state = foldlChunks update ([], state) msg
-- processed chunks are accumulated in reverse order
update (cs, st) chunk = let (c, st') = sbProcess st chunk in (c : cs, st')
run state = let (cs, state') = process state in BA.convert (sbAuth state') :| reverse cs
run state =
-- secretBoxLazy_ returns the chunks in reverse order, so they are reversed below
let (!cs, !state') = secretBoxLazy_ sbProcess state msg
in BA.convert (sbAuth state') :| reverse cs
-- | NaCl @secret_box@ lazy encrypt with a symmetric 256-bit key and 192-bit nonce with appended auth tag (more efficient with large files).
sbEncryptTailTag :: SbKey -> CbNonce -> LazyByteString -> Int64 -> Int64 -> Either CryptoError LazyByteString
@@ -135,9 +144,15 @@ sbDecryptTailTag (SbKey key) (CbNonce nonce) paddedLen packet =
-- Initializes the secret box state from the secret and nonce, passes every chunk of the
-- lazy message via sbProcess, and returns the processed chunks in their original order
-- followed by the auth tag of the final state (the tag is the last element of the list).
-- Fails if the state cannot be initialized.
-- NOTE(review): this listing contains two definitions of run (one folding the chunks
-- locally, one using secretBoxLazy_); presumably the first is the pre-change version
-- retained by the diff rendering - confirm.
secretBoxTailTag :: ByteArrayAccess key => (SbState -> ByteString -> (ByteString, SbState)) -> key -> ByteString -> LazyByteString -> Either CryptoError [ByteString]
secretBoxTailTag sbProcess secret nonce msg = run <$> sbInit_ secret nonce
where
process state = foldlChunks update ([], state) msg
-- processed chunks are accumulated in reverse order
update (cs, st) chunk = let (c, st') = sbProcess st chunk in (c : cs, st')
-- the tag is prepended to the reversed chunks, so that reversing the list puts it last
run state = let (cs, state') = process state in reverse $ BA.convert (sbAuth state') : cs
run state =
let (cs, state') = secretBoxLazy_ sbProcess state msg
in reverse $ BA.convert (sbAuth state') : cs
-- folds the chunks of a lazy bytestring through the passed processing function,
-- starting from an already initialized secret box state;
-- returns the processed chunks in REVERSE order together with the final state
secretBoxLazy_ :: (SbState -> ByteString -> (ByteString, SbState)) -> SbState -> LazyByteString -> ([ByteString], SbState)
secretBoxLazy_ sbProcess initialState msg = foldlChunks step ([], initialState) msg
  where
    step (processed, st) chunk = (out : processed, st')
      where
        -- lazy pattern binding: both components are forced once either of them is demanded
        (!out, !st') = sbProcess st chunk
-- | Secret box state passed between chunks: the XSalsa state used to combine the data
-- with the key stream and the Poly1305 state used to compute the auth tag.
type SbState = (XSalsa.State, Poly1305.State)
@@ -158,16 +173,26 @@ sbInit_ secret nonce = (state2,) <$> cryptoPassed (Poly1305.initialize rs)
state1 = XSalsa.derive state0 iv1
(rs :: ByteString, state2) = XSalsa.generate state1 32
-- | Encrypts a lazy bytestring chunk by chunk, returning the result and the updated state.
sbEncryptChunkLazy :: SbState -> LazyByteString -> (LazyByteString, SbState)
sbEncryptChunkLazy sb = sbProcessChunkLazy_ sbEncryptChunk sb

-- | Decrypts a lazy bytestring chunk by chunk, returning the result and the updated state.
sbDecryptChunkLazy :: SbState -> LazyByteString -> (LazyByteString, SbState)
sbDecryptChunkLazy sb = sbProcessChunkLazy_ sbDecryptChunk sb

-- passes a lazy bytestring via the processing function and rebuilds a lazy bytestring
-- from the processed chunks (secretBoxLazy_ returns them in reverse order)
sbProcessChunkLazy_ :: (SbState -> ByteString -> (ByteString, SbState)) -> SbState -> LazyByteString -> (LazyByteString, SbState)
sbProcessChunkLazy_ sbProcess sb msg =
  let (reversedChunks, sb') = secretBoxLazy_ sbProcess sb msg
   in (LB.fromChunks $ reverse reversedChunks, sb')
{-# INLINE sbProcessChunkLazy_ #-}
-- | Encrypts one strict chunk: combines it with the XSalsa state and updates the
-- Poly1305 state with the resulting output; returns the output and the updated state.
-- NOTE(review): this listing contains both lazy and bang-pattern variants of the let
-- bindings; presumably the lazy ones are the pre-change lines retained by the diff
-- rendering - confirm.
sbEncryptChunk :: SbState -> ByteString -> (ByteString, SbState)
sbEncryptChunk (st, authSt) chunk =
let (c, st') = XSalsa.combine st chunk
authSt' = Poly1305.update authSt c
let (!c, !st') = XSalsa.combine st chunk
-- the auth state is updated with the output c, not with the input chunk
!authSt' = Poly1305.update authSt c
in (c, (st', authSt'))
-- | Decrypts one strict chunk: combines it with the XSalsa state and updates the
-- Poly1305 state with the input chunk; returns the output and the updated state.
-- NOTE(review): this listing contains both lazy and bang-pattern variants of the let
-- bindings; presumably the lazy ones are the pre-change lines retained by the diff
-- rendering - confirm.
sbDecryptChunk :: SbState -> ByteString -> (ByteString, SbState)
sbDecryptChunk (st, authSt) chunk =
let (s, st') = XSalsa.combine st chunk
authSt' = Poly1305.update authSt chunk
let (!s, !st') = XSalsa.combine st chunk
-- the auth state is updated with the input chunk, not with the output s
!authSt' = Poly1305.update authSt chunk
in (s, (st', authSt'))
sbAuth :: SbState -> Poly1305.Auth
+4
View File
@@ -55,6 +55,10 @@ liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a)
liftEitherError f a = liftIOEither (first f <$> a)
{-# INLINE liftEitherError #-}
-- | Lifts an 'Either' into a 'MonadError' monad: 'Left' is thrown as an error
-- converted with the passed function, 'Right' is returned as the result.
liftEitherWith :: (MonadError e' m) => (e -> e') -> Either e a -> m a
liftEitherWith f = either (throwError . f) pure
{-# INLINE liftEitherWith #-}
-- | Runs the action and returns its result as 'Right', or the error it threw as 'Left';
-- the returned action itself does not throw that error.
tryError :: MonadError e m => m a -> m (Either e a)
tryError action = catchError (fmap Right action) $ \e -> pure (Left e)
{-# INLINE tryError #-}