diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 9489f52c1..1e5aa5bb7 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -13,7 +13,7 @@ import Control.Monad import Control.Monad.Except import Crypto.Random (ChaChaDRG) import Data.Bifunctor (first) -import Data.ByteString.Builder (Builder, byteString) +import Data.ByteString.Builder (Builder) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Int (Int64) @@ -138,7 +138,7 @@ sendXFTPCommand c@XFTPClient {http2Client = HTTP2Client {sessionId}} pKey fId cm xftpEncodeTransmission sessionId (Just pKey) ("", fId, FileCmd (sFileParty @p) cmd) sendXFTPTransmission c t chunkSpec_ -sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body) +sendXFTPTransmission :: XFTPClient -> Builder -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body) sendXFTPTransmission XFTPClient {config, http2Client = http2@HTTP2Client {sessionId}} t chunkSpec_ = do let req = H.requestStreaming N.methodPost "/" [] streamBody reqTimeout = (\XFTPChunkSpec {chunkSize} -> chunkTimeout config chunkSize) <$> chunkSpec_ @@ -154,7 +154,7 @@ sendXFTPTransmission XFTPClient {config, http2Client = http2@HTTP2Client {sessio where streamBody :: (Builder -> IO ()) -> IO () -> IO () streamBody send done = do - send $ byteString t + send t forM_ chunkSpec_ $ \XFTPChunkSpec {filePath, chunkOffset, chunkSize} -> withFile filePath ReadMode $ \h -> do hSeek h AbsoluteSeek $ fromIntegral chunkOffset diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index 58392d685..659db2201 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -17,6 +17,7 @@ import Control.Applicative ((<|>)) import qualified Data.Aeson.TH as J import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) +import Data.ByteString.Builder (Builder) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Kind (Type) @@ -394,7 +395,7 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Just c _ -> Nothing -xftpEncodeTransmission :: ProtocolEncoding e c => SessionId -> Maybe C.APrivateSignKey -> Transmission c -> Either TransportError ByteString +xftpEncodeTransmission :: ProtocolEncoding e c => SessionId -> Maybe C.APrivateSignKey -> Transmission c -> Either TransportError Builder xftpEncodeTransmission sessionId pKey (corrId, fId, msg) = do let t = encodeTransmission currentXFTPVersion sessionId (corrId, fId, msg) xftpEncodeBatch1 $ signTransmission t @@ -403,10 +404,10 @@ xftpEncodeTransmission sessionId pKey (corrId, fId, msg) = do signTransmission t = ((`C.sign` t) <$> pKey, t) -- this function uses batch syntax but puts only one transmission in the batch -xftpEncodeBatch1 :: (Maybe C.ASignature, ByteString) -> Either TransportError ByteString +xftpEncodeBatch1 :: (Maybe C.ASignature, ByteString) -> Either TransportError Builder xftpEncodeBatch1 (sig, t) = - let t' = tEncodeBatch 1 . smpEncode . Large $ tEncode (sig, t) - in first (const TELargeMsg) $ C.pad t' xftpBlockSize + let t' = tEncodeBatch 1 . encodeLarge $ tEncode (sig, t) + in first (const TELargeMsg) $ C.pad' t' xftpBlockSize xftpDecodeTransmission :: ProtocolEncoding e c => SessionId -> ByteString -> Either XFTPErrorType (SignedTransmission e c) xftpDecodeTransmission sessionId t = do diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index f4b725462..661b4dc98 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -18,7 +18,6 @@ import Control.Monad.Except import Control.Monad.Reader import Data.Bifunctor (first) import qualified Data.ByteString.Base64.URL as B64 -import Data.ByteString.Builder (byteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) @@ -242,7 +241,7 @@ processRequest HTTP2Request {sessionId, reqBody = body@HTTP2Body {bodyHead}, sen send "padding error" -- TODO respond with BLOCK error? done Right t -> do - send $ byteString t + send t -- timeout sending file in the same way as receiving forM_ serverFile_ $ \ServerFile {filePath, fileSize, sbState} -> do withFile filePath ReadMode $ \h -> sendEncFile h send sbState (fromIntegral fileSize) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index bddce8f34..996e0e3c3 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -87,8 +87,10 @@ import Control.Monad import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Except import qualified Data.Aeson.TH as J +import Data.ByteString.Builder (Builder, lazyByteString, toLazyByteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB import Data.Functor (($>)) import Data.Int (Int64) import Data.List (find) @@ -136,7 +138,7 @@ data PClient err msg = PClient pingErrorCount :: TVar Int, clientCorrId :: TVar Natural, sentCommands :: TMap CorrId (Request err msg), - sndQ :: TBQueue ByteString, + sndQ :: TBQueue Builder, rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)), msgQ :: Maybe (TBQueue (ServerTransmission msg)) } @@ -668,40 +670,36 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do (: []) <$> getResponse c r data ClientBatch err msg - = -- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch - CBTransmissions ByteString Int [Request err msg] - | CBTransmission ByteString (Request err msg) + = -- Builder in CBTransmissions does not include count byte, it is added by tEncodeBatch + CBTransmissions Builder Int [Request err msg] + | CBTransmission Builder (Request err msg) | CBLargeTransmission (Request err msg) --- | encodes and batches transmissions into blocks batchClientTransmissions :: forall err msg. Bool -> Int -> NonEmpty (PCTransmission err msg) -> [ClientBatch err msg] -batchClientTransmissions batch blkSize - | batch = reverse . mkBatch [] - | otherwise = map mkBatch1 . L.toList +batchClientTransmissions batch blkSize ts + | batch = + let (bs, b, _, n, rs) = foldr addToBatch ([], mempty, 0, 0, []) ts + in if n == 0 then bs else CBTransmissions b n rs : bs + | otherwise = map mkBatch1 $ L.toList ts where - mkBatch :: [ClientBatch err msg] -> NonEmpty (PCTransmission err msg) -> [ClientBatch err msg] - mkBatch bs ts = - let (b, ts_) = encodeBatch "" 0 [] ts - bs' = b : bs - in maybe bs' (mkBatch bs') ts_ mkBatch1 :: PCTransmission err msg -> ClientBatch err msg mkBatch1 (t, r) - | B.length s <= blkSize - 2 = CBTransmission s r + -- 2 bytes are reserved for pad size + | LB.length s <= fromIntegral (blkSize - 2) = CBTransmission (lazyByteString s) r | otherwise = CBLargeTransmission r where s = tEncode t - encodeBatch :: ByteString -> Int -> [Request err msg] -> NonEmpty (PCTransmission err msg) -> (ClientBatch err msg, Maybe (NonEmpty (PCTransmission err msg))) - encodeBatch s n rs ts@((t, r) :| ts_) - | B.length s' <= blkSize - 3 && n < 255 = - case L.nonEmpty ts_ of - Just ts' -> encodeBatch s' n' rs' ts' - Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing) - | n == 0 = (CBLargeTransmission r, L.nonEmpty ts_) - | otherwise = (CBTransmissions s n (reverse rs), Just ts) + addToBatch :: PCTransmission err msg -> ([ClientBatch err msg], Builder, Int, Int, [Request err msg]) -> ([ClientBatch err msg], Builder, Int, Int, [Request err msg]) + addToBatch (t, r) (bs, b, len, n, rs) + | len' <= blkSize - 3 && n < 255 = (bs, s <> b, len', 1 + n, r : rs) + | sLen <= blkSize - 3 = (bs', s, sLen, 1, [r]) + | otherwise = (CBLargeTransmission r : (if n == 0 then bs else bs'), mempty, 0, 0, []) where - s' = s <> smpEncode (Large $ tEncode t) - n' = n + 1 - rs' = r : rs + s = encodeLarge s' + sLen = 2 + (fromIntegral $ LB.length s') + s' = tEncode t + len' = sLen + len + bs' = CBTransmissions b n rs : bs -- | Send Protocol command sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg @@ -711,12 +709,12 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, batch, blockSize -- two separate "atomically" needed to avoid blocking sendRecv :: SentRawTransmission -> Request err msg -> IO (Either (ProtocolClientError err) msg) sendRecv t r - | B.length s > blockSize - 2 = pure $ Left $ PCETransportError TELargeMsg + | LB.length (toLazyByteString s) > fromIntegral (blockSize - 2) = pure $ Left $ PCETransportError TELargeMsg | otherwise = atomically (writeTBQueue sndQ s) >> response <$> getResponse c r where s - | batch = tEncodeBatch 1 . smpEncode . Large $ tEncode t - | otherwise = tEncode t + | batch = tEncodeBatch 1 . encodeLarge $ tEncode t + | otherwise = lazyByteString $ tEncode t -- TODO switch to timeout or TimeManager that supports Int64 getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg) diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 3564454ff..506e8e7cd 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -140,6 +140,7 @@ module Simplex.Messaging.Crypto -- * Message padding / un-padding pad, + pad', unPad, -- * X509 Certificates @@ -190,8 +191,10 @@ import Data.ByteArray (ByteArrayAccess) import qualified Data.ByteArray as BA import Data.ByteString.Base64 (decode, encode) import qualified Data.ByteString.Base64.URL as U +import Data.ByteString.Builder (Builder, byteString, toLazyByteString, word16BE) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB import Data.ByteString.Lazy (fromStrict, toStrict) import Data.Constraint (Dict (..)) import Data.Kind (Constraint, Type) @@ -919,6 +922,14 @@ pad msg paddedLen len = B.length msg padLen = paddedLen - len - 2 +pad' :: Builder -> Int -> Either CryptoError Builder +pad' msg paddedLen + | len <= maxMsgLen && padLen >= 0 = Right $ word16BE (fromIntegral len) <> msg <> byteString (B.replicate padLen '#') + | otherwise = Left CryptoLargeMsgError + where + len = fromIntegral $ LB.length $ toLazyByteString msg + padLen = paddedLen - len - 2 + unPad :: ByteString -> Either CryptoError ByteString unPad padded | B.length lenWrd == 2 && B.length rest >= len = Right $ B.take len rest diff --git a/src/Simplex/Messaging/Crypto/Lazy.hs b/src/Simplex/Messaging/Crypto/Lazy.hs index ebb3692c1..c83d93a07 100644 --- a/src/Simplex/Messaging/Crypto/Lazy.hs +++ b/src/Simplex/Messaging/Crypto/Lazy.hs @@ -41,7 +41,6 @@ import qualified Crypto.MAC.Poly1305 as Poly1305 import Data.Bifunctor (first) import Data.ByteArray (ByteArrayAccess) import qualified Data.ByteArray as BA -import qualified Data.ByteString as S import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB @@ -175,7 +174,7 @@ secretBoxTailTag sbProcess secret nonce msg = run <$> sbInit_ secret nonce -- passes lazy bytestring via initialized secret box returning the reversed list of chunks secretBoxLazy_ :: (SbState -> ByteString -> (ByteString, SbState)) -> SbState -> LazyByteString -> ([ByteString], SbState) -secretBoxLazy_ sbProcess state = foldlChunks update ([], state) +secretBoxLazy_ sbProcess state = LB.foldlChunks update ([], state) where update (cs, st) chunk = let (!c, !st') = sbProcess st chunk in (c : cs, st') @@ -231,10 +230,3 @@ cryptoPassed :: CE.CryptoFailable b -> Either CryptoError b cryptoPassed = \case CE.CryptoPassed a -> Right a CE.CryptoFailed e -> Left $ CryptoPoly1305Error e - -foldlChunks :: (a -> S.ByteString -> a) -> a -> LazyByteString -> a -foldlChunks f = go - where - go !a LB.Empty = a - go !a (LB.Chunk c cs) = go (f a c) cs -{-# INLINE foldlChunks #-} diff --git a/src/Simplex/Messaging/Encoding.hs b/src/Simplex/Messaging/Encoding.hs index 814a536c4..2d9e822a1 100644 --- a/src/Simplex/Messaging/Encoding.hs +++ b/src/Simplex/Messaging/Encoding.hs @@ -11,6 +11,7 @@ module Simplex.Messaging.Encoding ( Encoding (..), Tail (..), Large (..), + encodeLarge, _smpP, smpEncodeList, smpListP, @@ -21,8 +22,10 @@ where import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bits (shiftL, shiftR, (.|.)) +import Data.ByteString.Builder (Builder, lazyByteString, word16BE) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB import Data.ByteString.Internal (c2w, w2c) import Data.Int (Int64) import qualified Data.List.NonEmpty as L @@ -138,6 +141,10 @@ instance Encoding Large where Large <$> A.take len {-# INLINE smpP #-} +encodeLarge :: LB.ByteString -> Builder +encodeLarge s = word16BE (fromIntegral $ LB.length s) <> lazyByteString s +{-# INLINE encodeLarge #-} + instance Encoding SystemTime where smpEncode = smpEncode . systemSeconds {-# INLINE smpEncode #-} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index e65385cba..8afd44e56 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -160,8 +160,11 @@ import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.ByteString.Builder (Builder, char8, lazyByteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB +import qualified Data.ByteString.Lazy.Internal as LB import Data.Char (isPrint, isSpace) import Data.Constraint (Dict (..)) import Data.Functor (($>)) @@ -1292,7 +1295,7 @@ tPut th delay_ = fmap concat . mapM tPutBatch . batchTransmissions (batch th) (b TBTransmissions n s -> replicate n <$> (tPutLog th (tEncodeBatch n s) <* mapM_ threadDelay delay_) TBTransmission s -> (: []) <$> tPutLog th s -tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutLog :: Transport c => THandle c -> Builder -> IO (Either TransportError ()) tPutLog th s = do r <- tPutBlock th s case r of @@ -1301,42 +1304,38 @@ tPutLog th s = do pure r -- ByteString does not include length byte, it is added by tEncodeBatch -data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission +data TransportBatch = TBTransmissions Int Builder | TBTransmission Builder | TBLargeTransmission -- | encodes and batches transmissions into blocks, batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch] -batchTransmissions batch bSize - | batch = reverse . mkBatch [] . L.map tEncode - | otherwise = map (mkBatch1 . tEncode) . L.toList +batchTransmissions batch bSize ts + | batch = + let (bs, b, _, n) = foldr addToBatch ([], mempty, 0, 0) ts + in if n == 0 then bs else TBTransmissions n b : bs + | otherwise = map (mkBatch1 . tEncode) (L.toList ts) where - mkBatch :: [TransportBatch] -> NonEmpty ByteString -> [TransportBatch] - mkBatch rs ts = - let (n, s, ts_) = encodeBatch 0 "" ts - r = if n == 0 then TBLargeTransmission else TBTransmissions n s - rs' = r : rs - in case ts_ of - Just ts' -> mkBatch rs' ts' - _ -> rs' - mkBatch1 :: ByteString -> TransportBatch - mkBatch1 s = if B.length s > bSize - 2 then TBLargeTransmission else TBTransmission s - encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString)) - encodeBatch n s ts@(t :| ts_) - | n == 255 = (n, s, Just ts) - | otherwise = - let s' = s <> smpEncode (Large t) - n' = n + 1 - in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch - then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts - else case L.nonEmpty ts_ of - Just ts' -> encodeBatch n' s' ts' - _ -> (n', s', Nothing) + mkBatch1 :: LB.ByteString -> TransportBatch + mkBatch1 s + | LB.length s > fromIntegral (bSize - 2) = TBLargeTransmission + | otherwise = TBTransmission $ lazyByteString s + addToBatch :: SentRawTransmission -> ([TransportBatch], Builder, Int, Int) -> ([TransportBatch], Builder, Int, Int) + addToBatch t (bs, b, len, n) + | len' <= bSize - 3 && n < 255 = (bs, s <> b, len', 1 + n) + | sLen <= bSize - 3 = (bs', s, sLen, 1) + | otherwise = (TBLargeTransmission : (if n == 0 then bs else bs'), mempty, 0, 0) + where + s = encodeLarge s' + sLen = 2 + fromIntegral (LB.length s') + s' = tEncode t + len' = sLen + len + bs' = TBTransmissions n b : bs -tEncode :: SentRawTransmission -> ByteString -tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t +tEncode :: SentRawTransmission -> LB.ByteString +tEncode (sig, t) = LB.chunk (smpEncode $ C.signatureBytes sig) (LB.fromStrict t) {-# INLINE tEncode #-} -tEncodeBatch :: Int -> ByteString -> ByteString -tEncodeBatch n s = lenEncode n `B.cons` s +tEncodeBatch :: Int -> Builder -> Builder +tEncodeBatch n s = char8 (lenEncode n) <> s {-# INLINE tEncodeBatch #-} encodeTransmission :: ProtocolEncoding e c => Version -> ByteString -> Transmission c -> ByteString diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 6509c1f6f..a1ff259ad 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -67,9 +67,10 @@ import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser) import Data.Bifunctor (first) import Data.Bitraversable (bimapM) +import Data.ByteString.Builder (Builder, byteString, toLazyByteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy.Char8 as LB import Data.Default (def) import Data.Functor (($>)) import Data.Version (showVersion) @@ -134,6 +135,9 @@ class Transport c where -- | Write bytes to connection cPut :: c -> ByteString -> IO () + -- | Write bytes to connection + cPut' :: c -> LB.ByteString -> IO () + -- | Receive ByteString from connection, allowing LF or CRLF termination. getLn :: c -> IO ByteString @@ -217,8 +221,11 @@ instance Transport TLS where getBuffered tlsBuffer n t_ (T.recvData tlsContext) cPut :: TLS -> ByteString -> IO () - cPut TLS {tlsContext, tlsTransportConfig = TransportConfig {transportTimeout = t_}} s = - withTimedErr t_ . T.sendData tlsContext $ BL.fromStrict s + cPut cxt = cPut' cxt . LB.fromStrict + + cPut' :: TLS -> LB.ByteString -> IO () + cPut' TLS {tlsContext, tlsTransportConfig = TransportConfig {transportTimeout = t_}} s = + withTimedErr t_ $ T.sendData tlsContext s getLn :: TLS -> IO ByteString getLn TLS {tlsContext, tlsBuffer} = do @@ -309,10 +316,10 @@ serializeTransportError = \case TEHandshake e -> "HANDSHAKE " <> bshow e -- | Pad and send block to SMP transport. -tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutBlock :: Transport c => THandle c -> Builder -> IO (Either TransportError ()) tPutBlock THandle {connection = c, blockSize} block = - bimapM (const $ pure TELargeMsg) (cPut c) $ - C.pad block blockSize + bimapM (const $ pure TELargeMsg) (cPut' c . toLazyByteString) $ + C.pad' block blockSize -- | Receive block from SMP transport. tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString) @@ -356,7 +363,7 @@ smpThHandle :: forall c. THandle c -> Version -> THandle c smpThHandle th v = (th :: THandle c) {thVersion = v, batch = v >= 4} sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO () -sendHandshake th = ExceptT . tPutBlock th . smpEncode +sendHandshake th = ExceptT . tPutBlock th . byteString . smpEncode getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock th diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index ae78da1fe..486ae1f20 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -7,7 +7,7 @@ module Simplex.Messaging.Transport.WebSockets (WS (..)) where import qualified Control.Exception as E import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy as LB import qualified Network.TLS as T import Network.WebSockets import Network.WebSockets.Stream (Stream) @@ -72,6 +72,9 @@ instance Transport WS where cPut :: WS -> ByteString -> IO () cPut = sendBinaryData . wsConnection + cPut' :: WS -> LB.ByteString -> IO () + cPut' = sendBinaryData . wsConnection + getLn :: WS -> IO ByteString getLn c = do s <- trimCR <$> receiveData (wsConnection c) @@ -101,5 +104,5 @@ makeTLSContextStream cxt = (Just <$> T.recvData cxt) `E.catch` \case T.Error_EOF -> pure Nothing e -> E.throwIO e - writeStream :: Maybe BL.ByteString -> IO () + writeStream :: Maybe LB.ByteString -> IO () writeStream = maybe (closeTLS cxt) (T.sendData cxt) diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index fea45d4d5..78a1079dd 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -4,8 +4,9 @@ module CoreTests.BatchingTests (batchingTests) where import Control.Concurrent.STM import Control.Monad +import Data.ByteString.Builder (Builder, toLazyByteString) import Data.ByteString.Char8 (ByteString) -import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB import qualified Data.List.NonEmpty as L import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C @@ -35,7 +36,7 @@ testBatchSubscriptions = do let batches = batchTransmissions True smpBlockSize $ L.fromList subs length batches `shouldBe` 3 [TBTransmissions n1 s1, TBTransmissions n2 s2, TBTransmissions n3 s3] <- pure batches - (n1, n2, n3) `shouldBe` (90, 90, 20) + (n1, n2, n3) `shouldBe` (20, 90, 90) all lenOk [s1, s2, s3] `shouldBe` True testBatchWithMessage :: IO () @@ -51,7 +52,7 @@ testBatchWithMessage = do let batches = batchTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 2 [TBTransmissions n1 s1, TBTransmissions n2 s2] <- pure batches - (n1, n2) `shouldBe` (60, 41) + (n1, n2) `shouldBe` (55, 46) all lenOk [s1, s2] `shouldBe` True testBatchWithLargeMessage :: IO () @@ -70,7 +71,7 @@ testBatchWithLargeMessage = do let batches = batchTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 4 [TBTransmissions n1 s1, TBLargeTransmission, TBTransmissions n2 s2, TBTransmissions n3 s3] <- pure batches - (n1, n2, n3) `shouldBe` (60, 90, 10) + (n1, n2, n3) `shouldBe` (60, 10, 90) all lenOk [s1, s2, s3] `shouldBe` True testClientBatchSubscriptions :: IO () @@ -83,8 +84,8 @@ testClientBatchSubscriptions = do let batches = batchClientTransmissions True smpBlockSize $ L.fromList subs length batches `shouldBe` 3 [CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches - (n1, n2, n3) `shouldBe` (90, 90, 20) - (length rs1, length rs2, length rs3) `shouldBe` (90, 90, 20) + (n1, n2, n3) `shouldBe` (20, 90, 90) + (length rs1, length rs2, length rs3) `shouldBe` (20, 90, 90) all lenOk [s1, s2, s3] `shouldBe` True testClientBatchWithMessage :: IO () @@ -101,8 +102,8 @@ testClientBatchWithMessage = do let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 2 [CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2] <- pure batches - (n1, n2) `shouldBe` (60, 41) - (length rs1, length rs2) `shouldBe` (60, 41) + (n1, n2) `shouldBe` (55, 46) + (length rs1, length rs2) `shouldBe` (55, 46) all lenOk [s1, s2] `shouldBe` True testClientBatchWithLargeMessage :: IO () @@ -123,16 +124,16 @@ testClientBatchWithLargeMessage = do let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 4 [CBTransmissions s1 n1 rs1, CBLargeTransmission _, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches - (n1, n2, n3) `shouldBe` (60, 90, 10) - (length rs1, length rs2, length rs3) `shouldBe` (60, 90, 10) + (n1, n2, n3) `shouldBe` (60, 10, 90) + (length rs1, length rs2, length rs3) `shouldBe` (60, 10, 90) all lenOk [s1, s2, s3] `shouldBe` True -- let cmds' = [send] <> subs1 <> subs2 let batches' = batchClientTransmissions True smpBlockSize $ L.fromList cmds' length batches' `shouldBe` 3 [CBLargeTransmission _, CBTransmissions s1' n1' rs1', CBTransmissions s2' n2' rs2'] <- pure batches' - (n1', n2') `shouldBe` (90, 70) - (length rs1', length rs2') `shouldBe` (90, 70) + (n1', n2') `shouldBe` (70, 90) + (length rs1', length rs2') `shouldBe` (70, 90) all lenOk [s1', s2'] `shouldBe` True randomSUB :: ByteString -> IO (Maybe C.ASignature, ByteString) @@ -169,8 +170,10 @@ randomSENDCmd c len = do msg <- atomically $ C.randomBytes len g mkTransmission c (Just rpKey, sId, Cmd SSender $ SEND noMsgFlags msg) -lenOk :: ByteString -> Bool -lenOk s = 0 < B.length s && B.length s <= smpBlockSize - 2 +lenOk :: Builder -> Bool +lenOk s = 0 < len && len <= smpBlockSize - 2 + where + len = fromIntegral . LB.length $ toLazyByteString s lenOk1 :: TransportBatch -> Bool lenOk1 = \case