mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-20 10:55:10 +00:00
xftp: streaming file encryption/decryption to avoid memory spikes (#687)
* xftp: streaming file decryption to avoid memory spikes * refactor, enable tests * streaming encryption * refactor
This commit is contained in:
committed by
GitHub
parent
a0eb53b891
commit
bab689099f
@@ -25,7 +25,6 @@ import Control.Monad.Except
|
||||
import Crypto.Random (getRandomBytes)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Char (toLower)
|
||||
@@ -91,7 +90,7 @@ newtype CLIError = CLIError String
|
||||
|
||||
cliCryptoError :: FTCryptoError -> CLIError
|
||||
cliCryptoError = \case
|
||||
FTCEDecryptionError e -> CLIError $ "Error decrypting file: " <> show e
|
||||
FTCECryptoError e -> CLIError $ "Error decrypting file: " <> show e
|
||||
FTCEInvalidHeader e -> CLIError $ "Invalid file header: " <> e
|
||||
FTCEInvalidAuthTag -> CLIError "Error decrypting file: incorrect auth tag"
|
||||
FTCEFileIOError e -> CLIError $ "File IO error: " <> show e
|
||||
@@ -261,7 +260,7 @@ cliSendFile :: SendOptions -> ExceptT CLIError IO ()
|
||||
cliSendFile SendOptions {filePath, outputDir, numRecipients, xftpServers, retryCount, tempPath, verbose} = do
|
||||
let (_, fileName) = splitFileName filePath
|
||||
liftIO $ printNoNewLine "Encrypting file..."
|
||||
(encPath, fdRcv, fdSnd, chunkSpecs, encSize) <- encryptFile fileName
|
||||
(encPath, fdRcv, fdSnd, chunkSpecs, encSize) <- encryptFileForUpload fileName
|
||||
liftIO $ printNoNewLine "Uploading file..."
|
||||
uploadedChunks <- newTVarIO []
|
||||
sentChunks <- uploadFile chunkSpecs uploadedChunks encSize
|
||||
@@ -276,8 +275,8 @@ cliSendFile SendOptions {filePath, outputDir, numRecipients, xftpServers, retryC
|
||||
putStrLn "Pass file descriptions to the recipient(s):"
|
||||
forM_ fdRcvPaths putStrLn
|
||||
where
|
||||
encryptFile :: String -> ExceptT CLIError IO (FilePath, FileDescription 'FRecipient, FileDescription 'FSender, [XFTPChunkSpec], Int64)
|
||||
encryptFile fileName = do
|
||||
encryptFileForUpload :: String -> ExceptT CLIError IO (FilePath, FileDescription 'FRecipient, FileDescription 'FSender, [XFTPChunkSpec], Int64)
|
||||
encryptFileForUpload fileName = do
|
||||
fileSize <- fromInteger <$> getFileSize filePath
|
||||
when (fileSize > maxFileSize) $ throwError $ CLIError $ "Files bigger than " <> maxFileSizeStr <> " are not supported"
|
||||
encPath <- getEncPath tempPath "xftp"
|
||||
@@ -289,20 +288,13 @@ cliSendFile SendOptions {filePath, outputDir, numRecipients, xftpServers, retryC
|
||||
defChunkSize = head chunkSizes
|
||||
chunkSizes' = map fromIntegral chunkSizes
|
||||
encSize = sum chunkSizes'
|
||||
encrypt fileHdr key nonce fileSize' encSize encPath
|
||||
withExceptT (CLIError . show) $ encryptFile filePath fileHdr key nonce fileSize' encSize encPath
|
||||
digest <- liftIO $ LC.sha512Hash <$> LB.readFile encPath
|
||||
let chunkSpecs = prepareChunkSpecs encPath chunkSizes
|
||||
fdRcv = FileDescription {party = SFRecipient, size = FileSize encSize, digest = FileDigest digest, key, nonce, chunkSize = FileSize defChunkSize, chunks = []}
|
||||
fdSnd = FileDescription {party = SFSender, size = FileSize encSize, digest = FileDigest digest, key, nonce, chunkSize = FileSize defChunkSize, chunks = []}
|
||||
logInfo $ "encrypted file to " <> tshow encPath
|
||||
pure (encPath, fdRcv, fdSnd, chunkSpecs, encSize)
|
||||
where
|
||||
encrypt :: ByteString -> C.SbKey -> C.CbNonce -> Int64 -> Int64 -> FilePath -> ExceptT CLIError IO ()
|
||||
encrypt fileHdr key nonce fileSize' encSize encFile = do
|
||||
f <- liftIO $ LB.readFile filePath
|
||||
let f' = LB.fromStrict fileHdr <> f
|
||||
c <- liftEither $ first (CLIError . show) $ LC.sbEncryptTailTag key nonce f' fileSize' $ encSize - authTagSize
|
||||
liftIO $ LB.writeFile encFile c
|
||||
uploadFile :: [XFTPChunkSpec] -> TVar [Int64] -> Int64 -> ExceptT CLIError IO [SentFileChunk]
|
||||
uploadFile chunks uploadedChunks encSize = do
|
||||
a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig
|
||||
|
||||
@@ -1,42 +1,115 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.FileTransfer.Crypto where
|
||||
|
||||
import Control.Monad.Except
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Int (Int64)
|
||||
import Simplex.FileTransfer.Types (FileHeader (..), authTagSize)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Lazy (LazyByteString)
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Util (liftEitherWith)
|
||||
import UnliftIO
|
||||
import UnliftIO.Directory (removeFile)
|
||||
|
||||
encryptFile :: FilePath -> ByteString -> C.SbKey -> C.CbNonce -> Int64 -> Int64 -> FilePath -> ExceptT FTCryptoError IO ()
|
||||
encryptFile filePath fileHdr key nonce fileSize' encSize encFile = do
|
||||
sb <- liftEitherWith FTCECryptoError $ LC.sbInit key nonce
|
||||
withFile filePath ReadMode $ \r -> withFile encFile WriteMode $ \w -> do
|
||||
let lenStr = smpEncode fileSize'
|
||||
(hdr, !sb') = LC.sbEncryptChunk sb $ lenStr <> fileHdr
|
||||
padLen = encSize - authTagSize - fileSize' - 8
|
||||
liftIO $ B.hPut w hdr
|
||||
sb2 <- encryptChunks r w (sb', fileSize' - fromIntegral (B.length fileHdr))
|
||||
sb3 <- encryptPad w (sb2, padLen)
|
||||
let tag = BA.convert $ LC.sbAuth sb3
|
||||
liftIO $ B.hPut w tag
|
||||
where
|
||||
encryptChunks r = encryptChunks_ $ liftIO . B.hGet r . fromIntegral
|
||||
encryptPad = encryptChunks_ $ \sz -> pure $ B.replicate (fromIntegral sz) '#'
|
||||
encryptChunks_ :: (Int64 -> IO ByteString) -> Handle -> (LC.SbState, Int64) -> ExceptT FTCryptoError IO LC.SbState
|
||||
encryptChunks_ get w (!sb, !len)
|
||||
| len == 0 = pure sb
|
||||
| otherwise = do
|
||||
let chSize = min len 65536
|
||||
ch <- liftIO $ get chSize
|
||||
when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF"
|
||||
let (ch', sb') = LC.sbEncryptChunk sb ch
|
||||
liftIO $ B.hPut w ch'
|
||||
encryptChunks_ get w (sb', len - chSize)
|
||||
|
||||
decryptChunks :: Int64 -> [FilePath] -> C.SbKey -> C.CbNonce -> (String -> ExceptT String IO String) -> ExceptT FTCryptoError IO FilePath
|
||||
decryptChunks encSize chunkPaths key nonce getFilePath = do
|
||||
(authOk, f) <- liftEither . first FTCEDecryptionError . LC.sbDecryptTailTag key nonce (encSize - authTagSize) =<< liftIO (readChunks chunkPaths)
|
||||
let (fileHdr, f') = LB.splitAt 1024 f
|
||||
-- withFile encPath ReadMode $ \r -> do
|
||||
-- fileHdr <- liftIO $ B.hGet r 1024
|
||||
case A.parse smpP $ LB.toStrict fileHdr of
|
||||
A.Fail _ _ e -> throwError $ FTCEInvalidHeader e
|
||||
A.Partial _ -> throwError $ FTCEInvalidHeader "incomplete"
|
||||
A.Done rest FileHeader {fileName} -> do
|
||||
path <- withExceptT FTCEFileIOError $ getFilePath fileName
|
||||
liftIO $ LB.writeFile path $ LB.fromStrict rest <> f'
|
||||
unless authOk $ do
|
||||
removeFile path
|
||||
throwError FTCEInvalidAuthTag
|
||||
pure path
|
||||
decryptChunks _ [] _ _ _ = throwError $ FTCEInvalidHeader "empty"
|
||||
decryptChunks encSize (chPath : chPaths) key nonce getFilePath = case reverse chPaths of
|
||||
[] -> do
|
||||
(!authOk, !f) <- liftEither . first FTCECryptoError . LC.sbDecryptTailTag key nonce (encSize - authTagSize) =<< liftIO (LB.readFile chPath)
|
||||
unless authOk $ throwError FTCEInvalidAuthTag
|
||||
(FileHeader {fileName}, !f') <- parseFileHeader f
|
||||
path <- withExceptT FTCEFileIOError $ getFilePath fileName
|
||||
liftIO $ LB.writeFile path f'
|
||||
pure path
|
||||
lastPath : chPaths' -> do
|
||||
(state, expectedLen, ch) <- decryptFirstChunk
|
||||
(FileHeader {fileName}, ch') <- parseFileHeader ch
|
||||
path <- withExceptT FTCEFileIOError $ getFilePath fileName
|
||||
authOk <- liftIO . withFile path WriteMode $ \h -> do
|
||||
liftIO $ LB.hPut h ch'
|
||||
state' <- foldM (decryptChunk h) state $ reverse chPaths'
|
||||
decryptLastChunk h state' expectedLen
|
||||
unless authOk $ do
|
||||
removeFile path
|
||||
throwError FTCEInvalidAuthTag
|
||||
pure path
|
||||
where
|
||||
decryptFirstChunk = do
|
||||
sb <- liftEitherWith FTCECryptoError $ LC.sbInit key nonce
|
||||
ch <- liftIO $ LB.readFile chPath
|
||||
let (ch1, !sb') = LC.sbDecryptChunkLazy sb ch
|
||||
(!expectedLen, ch2) <- liftEitherWith FTCECryptoError $ LC.splitLen ch1
|
||||
let len1 = LB.length ch2
|
||||
pure ((sb', len1), expectedLen, ch2)
|
||||
decryptChunk h (!sb, !len) chPth = do
|
||||
ch <- LB.readFile chPth
|
||||
let len' = len + LB.length ch
|
||||
(ch', sb') = LC.sbDecryptChunkLazy sb ch
|
||||
LB.hPut h ch'
|
||||
pure (sb', len')
|
||||
decryptLastChunk h (!sb, !len) expectedLen = do
|
||||
ch <- LB.readFile lastPath
|
||||
let (ch1, tag') = LB.splitAt (LB.length ch - authTagSize) ch
|
||||
tag'' = LB.toStrict tag'
|
||||
(ch2, sb') = LC.sbDecryptChunkLazy sb ch1
|
||||
len' = len + LB.length ch2
|
||||
ch3 = LB.take (LB.length ch2 - len' + expectedLen) ch2
|
||||
tag :: ByteString = BA.convert (LC.sbAuth sb')
|
||||
LB.hPut h ch3
|
||||
pure $ B.length tag'' == 16 && BA.constEq tag'' tag
|
||||
where
|
||||
parseFileHeader :: LazyByteString -> ExceptT FTCryptoError IO (FileHeader, LazyByteString)
|
||||
parseFileHeader s = do
|
||||
let (hdrStr, s') = LB.splitAt 1024 s
|
||||
case A.parse smpP $ LB.toStrict hdrStr of
|
||||
A.Fail _ _ e -> throwError $ FTCEInvalidHeader e
|
||||
A.Partial _ -> throwError $ FTCEInvalidHeader "incomplete"
|
||||
A.Done rest hdr -> pure (hdr, LB.fromStrict rest <> s')
|
||||
|
||||
readChunks :: [FilePath] -> IO LB.ByteString
|
||||
readChunks = foldM (\s path -> (s <>) <$> LB.readFile path) ""
|
||||
|
||||
data FTCryptoError
|
||||
= FTCEDecryptionError C.CryptoError
|
||||
= FTCECryptoError C.CryptoError
|
||||
| FTCEInvalidHeader String
|
||||
| FTCEInvalidAuthTag
|
||||
| FTCEFileIOError String
|
||||
deriving (Show, Eq)
|
||||
deriving (Show, Eq, Exception)
|
||||
|
||||
@@ -11,6 +11,7 @@ module Simplex.Messaging.Crypto.Lazy
|
||||
sha512Hash,
|
||||
pad,
|
||||
unPad,
|
||||
splitLen,
|
||||
sbEncrypt,
|
||||
sbDecrypt,
|
||||
sbEncryptTailTag,
|
||||
@@ -21,7 +22,10 @@ module Simplex.Messaging.Crypto.Lazy
|
||||
sbInit,
|
||||
sbEncryptChunk,
|
||||
sbDecryptChunk,
|
||||
sbEncryptChunkLazy,
|
||||
sbDecryptChunkLazy,
|
||||
sbAuth,
|
||||
LazyByteString,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -30,6 +34,7 @@ import qualified Crypto.Error as CE
|
||||
import Crypto.Hash (Digest, hashlazy)
|
||||
import Crypto.Hash.Algorithms (SHA256, SHA512)
|
||||
import qualified Crypto.MAC.Poly1305 as Poly1305
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteArray (ByteArrayAccess)
|
||||
import qualified Data.ByteArray as BA
|
||||
import qualified Data.ByteString as S
|
||||
@@ -37,6 +42,7 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import qualified Data.ByteString.Lazy.Internal as LB
|
||||
import Data.Composition ((.:.))
|
||||
import Data.Int (Int64)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Foreign (sizeOf)
|
||||
@@ -77,11 +83,14 @@ fastReplicate n c
|
||||
-- this function does not validate the length of the message to avoid consuming all chunks,
|
||||
-- so it can return a shorter string than expected
|
||||
unPad :: LazyByteString -> Either CryptoError LazyByteString
|
||||
unPad padded
|
||||
unPad = fmap snd . splitLen
|
||||
|
||||
splitLen :: LazyByteString -> Either CryptoError (Int64, LazyByteString)
|
||||
splitLen padded
|
||||
| LB.length lenStr == 8 = case smpDecode $ LB.toStrict lenStr of
|
||||
Right len
|
||||
| len < 0 -> Left CryptoInvalidMsgError
|
||||
| otherwise -> Right $ LB.take len rest
|
||||
| otherwise -> Right (len, LB.take len rest)
|
||||
Left _ -> Left CryptoInvalidMsgError
|
||||
| otherwise = Left CryptoInvalidMsgError
|
||||
where
|
||||
@@ -111,9 +120,9 @@ sbDecrypt (SbKey key) (CbNonce nonce) packet
|
||||
secretBox :: ByteArrayAccess key => (SbState -> ByteString -> (ByteString, SbState)) -> key -> ByteString -> LazyByteString -> Either CryptoError (NonEmpty ByteString)
|
||||
secretBox sbProcess secret nonce msg = run <$> sbInit_ secret nonce
|
||||
where
|
||||
process state = foldlChunks update ([], state) msg
|
||||
update (cs, st) chunk = let (c, st') = sbProcess st chunk in (c : cs, st')
|
||||
run state = let (cs, state') = process state in BA.convert (sbAuth state') :| reverse cs
|
||||
run state =
|
||||
let (!cs, !state') = secretBoxLazy_ sbProcess state msg
|
||||
in BA.convert (sbAuth state') :| reverse cs
|
||||
|
||||
-- | NaCl @secret_box@ lazy encrypt with a symmetric 256-bit key and 192-bit nonce with appended auth tag (more efficient with large files).
|
||||
sbEncryptTailTag :: SbKey -> CbNonce -> LazyByteString -> Int64 -> Int64 -> Either CryptoError LazyByteString
|
||||
@@ -135,9 +144,15 @@ sbDecryptTailTag (SbKey key) (CbNonce nonce) paddedLen packet =
|
||||
secretBoxTailTag :: ByteArrayAccess key => (SbState -> ByteString -> (ByteString, SbState)) -> key -> ByteString -> LazyByteString -> Either CryptoError [ByteString]
|
||||
secretBoxTailTag sbProcess secret nonce msg = run <$> sbInit_ secret nonce
|
||||
where
|
||||
process state = foldlChunks update ([], state) msg
|
||||
update (cs, st) chunk = let (c, st') = sbProcess st chunk in (c : cs, st')
|
||||
run state = let (cs, state') = process state in reverse $ BA.convert (sbAuth state') : cs
|
||||
run state =
|
||||
let (cs, state') = secretBoxLazy_ sbProcess state msg
|
||||
in reverse $ BA.convert (sbAuth state') : cs
|
||||
|
||||
-- passes lazy bytestring via initialized secret box returning the reversed list of chunks
|
||||
secretBoxLazy_ :: (SbState -> ByteString -> (ByteString, SbState)) -> SbState -> LazyByteString -> ([ByteString], SbState)
|
||||
secretBoxLazy_ sbProcess state = foldlChunks update ([], state)
|
||||
where
|
||||
update (cs, st) chunk = let (!c, !st') = sbProcess st chunk in (c : cs, st')
|
||||
|
||||
type SbState = (XSalsa.State, Poly1305.State)
|
||||
|
||||
@@ -158,16 +173,26 @@ sbInit_ secret nonce = (state2,) <$> cryptoPassed (Poly1305.initialize rs)
|
||||
state1 = XSalsa.derive state0 iv1
|
||||
(rs :: ByteString, state2) = XSalsa.generate state1 32
|
||||
|
||||
sbEncryptChunkLazy :: SbState -> LazyByteString -> (LazyByteString, SbState)
|
||||
sbEncryptChunkLazy = sbProcessChunkLazy_ sbEncryptChunk
|
||||
|
||||
sbDecryptChunkLazy :: SbState -> LazyByteString -> (LazyByteString, SbState)
|
||||
sbDecryptChunkLazy = sbProcessChunkLazy_ sbDecryptChunk
|
||||
|
||||
sbProcessChunkLazy_ :: (SbState -> ByteString -> (ByteString, SbState)) -> SbState -> LazyByteString -> (LazyByteString, SbState)
|
||||
sbProcessChunkLazy_ = first (LB.fromChunks . reverse) .:. secretBoxLazy_
|
||||
{-# INLINE sbProcessChunkLazy_ #-}
|
||||
|
||||
sbEncryptChunk :: SbState -> ByteString -> (ByteString, SbState)
|
||||
sbEncryptChunk (st, authSt) chunk =
|
||||
let (c, st') = XSalsa.combine st chunk
|
||||
authSt' = Poly1305.update authSt c
|
||||
let (!c, !st') = XSalsa.combine st chunk
|
||||
!authSt' = Poly1305.update authSt c
|
||||
in (c, (st', authSt'))
|
||||
|
||||
sbDecryptChunk :: SbState -> ByteString -> (ByteString, SbState)
|
||||
sbDecryptChunk (st, authSt) chunk =
|
||||
let (s, st') = XSalsa.combine st chunk
|
||||
authSt' = Poly1305.update authSt chunk
|
||||
let (!s, !st') = XSalsa.combine st chunk
|
||||
!authSt' = Poly1305.update authSt chunk
|
||||
in (s, (st', authSt'))
|
||||
|
||||
sbAuth :: SbState -> Poly1305.Auth
|
||||
|
||||
@@ -55,6 +55,10 @@ liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a)
|
||||
liftEitherError f a = liftIOEither (first f <$> a)
|
||||
{-# INLINE liftEitherError #-}
|
||||
|
||||
liftEitherWith :: (MonadError e' m) => (e -> e') -> Either e a -> m a
|
||||
liftEitherWith f = liftEither . first f
|
||||
{-# INLINE liftEitherWith #-}
|
||||
|
||||
tryError :: MonadError e m => m a -> m (Either e a)
|
||||
tryError action = (Right <$> action) `catchError` (pure . Left)
|
||||
{-# INLINE tryError #-}
|
||||
|
||||
Reference in New Issue
Block a user