diff --git a/simplexmq.cabal b/simplexmq.cabal index b7160fb05..b4923024c 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -103,6 +103,7 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items Simplex.Messaging.Agent.TAsyncs Simplex.Messaging.Agent.TRcvQueues + Simplex.Messaging.Builder Simplex.Messaging.Client Simplex.Messaging.Client.Agent Simplex.Messaging.Crypto diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 9489f52c1..69896408c 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -13,7 +13,7 @@ import Control.Monad import Control.Monad.Except import Crypto.Random (ChaChaDRG) import Data.Bifunctor (first) -import Data.ByteString.Builder (Builder, byteString) +import qualified Data.ByteString.Builder as BB import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Int (Int64) @@ -25,6 +25,7 @@ import qualified Network.HTTP2.Client as H import Simplex.FileTransfer.Description (mb) import Simplex.FileTransfer.Protocol import Simplex.FileTransfer.Transport +import Simplex.Messaging.Builder (Builder, builder) import Simplex.Messaging.Client ( NetworkConfig (..), ProtocolClientError (..), @@ -138,7 +139,7 @@ sendXFTPCommand c@XFTPClient {http2Client = HTTP2Client {sessionId}} pKey fId cm xftpEncodeTransmission sessionId (Just pKey) ("", fId, FileCmd (sFileParty @p) cmd) sendXFTPTransmission c t chunkSpec_ -sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body) +sendXFTPTransmission :: XFTPClient -> Builder -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body) sendXFTPTransmission XFTPClient {config, http2Client = http2@HTTP2Client {sessionId}} t chunkSpec_ = do let req = H.requestStreaming N.methodPost "/" [] streamBody reqTimeout = (\XFTPChunkSpec {chunkSize} -> chunkTimeout config chunkSize) <$> chunkSpec_ @@ -152,9 +153,9 @@ sendXFTPTransmission XFTPClient {config, http2Client = http2@HTTP2Client {sessio _ -> pure (r, body) Left e -> throwError $ PCEResponseError e where - streamBody :: (Builder -> IO ()) -> IO () -> IO () + streamBody :: (BB.Builder -> IO ()) -> IO () -> IO () streamBody send done = do - send $ byteString t + send $ builder t forM_ chunkSpec_ $ \XFTPChunkSpec {filePath, chunkOffset, chunkSize} -> withFile filePath ReadMode $ \h -> do hSeek h AbsoluteSeek $ fromIntegral chunkOffset diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index 58392d685..19d458107 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -24,6 +24,7 @@ import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe (isNothing) import Data.Type.Equality import Data.Word (Word32) +import Simplex.Messaging.Builder (Builder) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -394,7 +395,7 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Just c _ -> Nothing -xftpEncodeTransmission :: ProtocolEncoding e c => SessionId -> Maybe C.APrivateSignKey -> Transmission c -> Either TransportError ByteString +xftpEncodeTransmission :: ProtocolEncoding e c => SessionId -> Maybe C.APrivateSignKey -> Transmission c -> Either TransportError Builder xftpEncodeTransmission sessionId pKey (corrId, fId, msg) = do let t = encodeTransmission currentXFTPVersion sessionId (corrId, fId, msg) xftpEncodeBatch1 $ signTransmission t @@ -403,10 +404,10 @@ xftpEncodeTransmission sessionId pKey (corrId, fId, msg) = do signTransmission t = ((`C.sign` t) <$> pKey, t) -- this function uses batch syntax but puts only one transmission in the batch -xftpEncodeBatch1 :: (Maybe C.ASignature, ByteString) -> Either TransportError ByteString +xftpEncodeBatch1 :: (Maybe C.ASignature, ByteString) -> Either TransportError Builder xftpEncodeBatch1 (sig, t) = - let t' = tEncodeBatch 1 . smpEncode . Large $ tEncode (sig, t) - in first (const TELargeMsg) $ C.pad t' xftpBlockSize + let t' = tEncodeBatch 1 . encodeLarge $ tEncode (sig, t) + in first (const TELargeMsg) $ C.pad' t' xftpBlockSize xftpDecodeTransmission :: ProtocolEncoding e c => SessionId -> ByteString -> Either XFTPErrorType (SignedTransmission e c) xftpDecodeTransmission sessionId t = do diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index f4b725462..c1fc4be41 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -18,7 +18,6 @@ import Control.Monad.Except import Control.Monad.Reader import Data.Bifunctor (first) import qualified Data.ByteString.Base64.URL as B64 -import Data.ByteString.Builder (byteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) @@ -44,6 +43,7 @@ import Simplex.FileTransfer.Server.Stats import Simplex.FileTransfer.Server.Store import Simplex.FileTransfer.Server.StoreLog import Simplex.FileTransfer.Transport +import Simplex.Messaging.Builder (builder) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC import Simplex.Messaging.Encoding.String @@ -242,7 +242,7 @@ processRequest HTTP2Request {sessionId, reqBody = body@HTTP2Body {bodyHead}, sen send "padding error" -- TODO respond with BLOCK error? done Right t -> do - send $ byteString t + send $ builder t -- timeout sending file in the same way as receiving forM_ serverFile_ $ \ServerFile {filePath, fileSize, sbState} -> do withFile filePath ReadMode $ \h -> sendEncFile h send sbState (fromIntegral fileSize) diff --git a/src/Simplex/Messaging/Builder.hs b/src/Simplex/Messaging/Builder.hs new file mode 100644 index 000000000..31e94bf7d --- /dev/null +++ b/src/Simplex/Messaging/Builder.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE StrictData #-} + +module Simplex.Messaging.Builder + ( Builder (length, builder), + byteString, + lazyByteString, + word16BE, + char8, + toLazyByteString, + ) +where + +import qualified Data.ByteString as B +import qualified Data.ByteString.Builder as BB +import qualified Data.ByteString.Lazy as LB +import Data.Word (Word16) + + +-- length-aware builder +data Builder = Builder {length :: Int, builder :: BB.Builder} + +instance Semigroup Builder where + Builder l1 b1 <> Builder l2 b2 = Builder (l1 + l2) (b1 <> b2) + {-# INLINE (<>) #-} + +instance Monoid Builder where + mempty = Builder 0 mempty + {-# INLINE mempty #-} + mconcat bs = Builder (sum ls) (mconcat bbs) + where + (ls, bbs) = foldr (\(Builder l b) ~(ls', bbs') -> (l : ls', b : bbs')) ([], []) bs + {-# INLINE mconcat #-} + +byteString :: B.ByteString -> Builder +byteString s = Builder (B.length s) (BB.byteString s) +{-# INLINE byteString #-} + +lazyByteString :: LB.ByteString -> Builder +lazyByteString s = Builder (fromIntegral $ LB.length s) (BB.lazyByteString s) +{-# INLINE lazyByteString #-} + +word16BE :: Word16 -> Builder +word16BE = Builder 2 . BB.word16BE +{-# INLINE word16BE #-} + +char8 :: Char -> Builder +char8 = Builder 1 . BB.char8 +{-# INLINE char8 #-} + +toLazyByteString :: Builder -> LB.ByteString +toLazyByteString = BB.toLazyByteString . builder +{-# INLINE toLazyByteString #-} diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index bddce8f34..1d941ca08 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -72,9 +72,7 @@ module Simplex.Messaging.Client ClientCommand, -- * For testing - ClientBatch (..), PCTransmission, - batchClientTransmissions, mkTransmission, clientStub, ) @@ -98,11 +96,13 @@ import Data.Maybe (fromMaybe) import Data.Time.Clock (UTCTime, getCurrentTime) import Network.Socket (ServiceName) import Numeric.Natural +import Simplex.Messaging.Builder (Builder) +import qualified Simplex.Messaging.Builder as BB import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON) -import Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Protocol import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport @@ -136,7 +136,7 @@ data PClient err msg = PClient pingErrorCount :: TVar Int, clientCorrId :: TVar Natural, sentCommands :: TMap CorrId (Request err msg), - sndQ :: TBQueue ByteString, + sndQ :: TBQueue Builder, rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)), msgQ :: Maybe (TBQueue (ServerTransmission msg)) } @@ -173,7 +173,7 @@ clientStub sessionId = do } } -type SMPClient = ProtocolClient ErrorType SMP.BrokerMsg +type SMPClient = ProtocolClient ErrorType BrokerMsg -- | Type for client command data type ClientCommand msg = (Maybe C.APrivateSignKey, EntityId, ProtoCommand msg) @@ -634,7 +634,7 @@ type PCTransmission err msg = (SentRawTransmission, Request err msg) -- | Send multiple commands with batching and collect responses sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do - bs <- batchClientTransmissions batch blockSize <$> mapM (mkTransmission c) cs + bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs validate . concat =<< mapM (sendBatch c) bs where validate :: [Response err msg] -> IO (NonEmpty (Response err msg)) @@ -651,58 +651,22 @@ sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () streamProtocolCommands c@ProtocolClient {batch, blockSize} cs cb = do - bs <- batchClientTransmissions batch blockSize <$> mapM (mkTransmission c) cs + bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs mapM_ (cb <=< sendBatch c) bs -sendBatch :: ProtocolClient err msg -> ClientBatch err msg -> IO [Response err msg] +sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg] sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do case b of - CBLargeTransmission Request {entityId} -> do + TBLargeTransmission Request {entityId} -> do putStrLn "send error: large message" pure [Response entityId $ Left $ PCETransportError TELargeMsg] - CBTransmissions s n rs -> do + TBTransmissions s n rs -> do when (n > 0) $ atomically $ writeTBQueue sndQ $ tEncodeBatch n s mapConcurrently (getResponse c) rs - CBTransmission s r -> do + TBTransmission s r -> do atomically $ writeTBQueue sndQ s (: []) <$> getResponse c r -data ClientBatch err msg - = -- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch - CBTransmissions ByteString Int [Request err msg] - | CBTransmission ByteString (Request err msg) - | CBLargeTransmission (Request err msg) - --- | encodes and batches transmissions into blocks -batchClientTransmissions :: forall err msg. Bool -> Int -> NonEmpty (PCTransmission err msg) -> [ClientBatch err msg] -batchClientTransmissions batch blkSize - | batch = reverse . mkBatch [] - | otherwise = map mkBatch1 . L.toList - where - mkBatch :: [ClientBatch err msg] -> NonEmpty (PCTransmission err msg) -> [ClientBatch err msg] - mkBatch bs ts = - let (b, ts_) = encodeBatch "" 0 [] ts - bs' = b : bs - in maybe bs' (mkBatch bs') ts_ - mkBatch1 :: PCTransmission err msg -> ClientBatch err msg - mkBatch1 (t, r) - | B.length s <= blkSize - 2 = CBTransmission s r - | otherwise = CBLargeTransmission r - where - s = tEncode t - encodeBatch :: ByteString -> Int -> [Request err msg] -> NonEmpty (PCTransmission err msg) -> (ClientBatch err msg, Maybe (NonEmpty (PCTransmission err msg))) - encodeBatch s n rs ts@((t, r) :| ts_) - | B.length s' <= blkSize - 3 && n < 255 = - case L.nonEmpty ts_ of - Just ts' -> encodeBatch s' n' rs' ts' - Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing) - | n == 0 = (CBLargeTransmission r, L.nonEmpty ts_) - | otherwise = (CBTransmissions s n (reverse rs), Just ts) - where - s' = s <> smpEncode (Large $ tEncode t) - n' = n + 1 - rs' = r : rs - -- | Send Protocol command sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, batch, blockSize} pKey entId cmd = @@ -711,11 +675,11 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, batch, blockSize -- two separate "atomically" needed to avoid blocking sendRecv :: SentRawTransmission -> Request err msg -> IO (Either (ProtocolClientError err) msg) sendRecv t r - | B.length s > blockSize - 2 = pure $ Left $ PCETransportError TELargeMsg + | BB.length s > blockSize - 2 = pure $ Left $ PCETransportError TELargeMsg | otherwise = atomically (writeTBQueue sndQ s) >> response <$> getResponse c r where s - | batch = tEncodeBatch 1 . smpEncode . Large $ tEncode t + | batch = tEncodeBatch 1 . encodeLarge $ tEncode t | otherwise = tEncode t -- TODO switch to timeout or TimeManager that supports Int64 diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 3564454ff..d6a24f772 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -140,6 +140,7 @@ module Simplex.Messaging.Crypto -- * Message padding / un-padding pad, + pad', unPad, -- * X509 Certificates @@ -205,6 +206,8 @@ import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+)) import Network.Transport.Internal (decodeWord16, encodeWord16) +import Simplex.Messaging.Builder (Builder, byteString, word16BE) +import qualified Simplex.Messaging.Builder as BB import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString) @@ -919,6 +922,14 @@ pad msg paddedLen len = B.length msg padLen = paddedLen - len - 2 +pad' :: Builder -> Int -> Either CryptoError Builder +pad' msg paddedLen + | len <= maxMsgLen && padLen >= 0 = Right $ word16BE (fromIntegral len) <> msg <> byteString (B.replicate padLen '#') + | otherwise = Left CryptoLargeMsgError + where + len = BB.length msg + padLen = paddedLen - len - 2 + unPad :: ByteString -> Either CryptoError ByteString unPad padded | B.length lenWrd == 2 && B.length rest >= len = Right $ B.take len rest diff --git a/src/Simplex/Messaging/Crypto/Lazy.hs b/src/Simplex/Messaging/Crypto/Lazy.hs index ebb3692c1..c83d93a07 100644 --- a/src/Simplex/Messaging/Crypto/Lazy.hs +++ b/src/Simplex/Messaging/Crypto/Lazy.hs @@ -41,7 +41,6 @@ import qualified Crypto.MAC.Poly1305 as Poly1305 import Data.Bifunctor (first) import Data.ByteArray (ByteArrayAccess) import qualified Data.ByteArray as BA -import qualified Data.ByteString as S import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB @@ -175,7 +174,7 @@ secretBoxTailTag sbProcess secret nonce msg = run <$> sbInit_ secret nonce -- passes lazy bytestring via initialized secret box returning the reversed list of chunks secretBoxLazy_ :: (SbState -> ByteString -> (ByteString, SbState)) -> SbState -> LazyByteString -> ([ByteString], SbState) -secretBoxLazy_ sbProcess state = foldlChunks update ([], state) +secretBoxLazy_ sbProcess state = LB.foldlChunks update ([], state) where update (cs, st) chunk = let (!c, !st') = sbProcess st chunk in (c : cs, st') @@ -231,10 +230,3 @@ cryptoPassed :: CE.CryptoFailable b -> Either CryptoError b cryptoPassed = \case CE.CryptoPassed a -> Right a CE.CryptoFailed e -> Left $ CryptoPoly1305Error e - -foldlChunks :: (a -> S.ByteString -> a) -> a -> LazyByteString -> a -foldlChunks f = go - where - go !a LB.Empty = a - go !a (LB.Chunk c cs) = go (f a c) cs -{-# INLINE foldlChunks #-} diff --git a/src/Simplex/Messaging/Encoding.hs b/src/Simplex/Messaging/Encoding.hs index 814a536c4..846d071a1 100644 --- a/src/Simplex/Messaging/Encoding.hs +++ b/src/Simplex/Messaging/Encoding.hs @@ -11,6 +11,7 @@ module Simplex.Messaging.Encoding ( Encoding (..), Tail (..), Large (..), + encodeLarge, _smpP, smpEncodeList, smpListP, @@ -29,6 +30,8 @@ import qualified Data.List.NonEmpty as L import Data.Time.Clock.System (SystemTime (..)) import Data.Word (Word16, Word32) import Network.Transport.Internal (decodeWord16, decodeWord32, encodeWord16, encodeWord32) +import Simplex.Messaging.Builder (Builder, word16BE) +import qualified Simplex.Messaging.Builder as BB import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util ((<$?>)) @@ -138,6 +141,10 @@ instance Encoding Large where Large <$> A.take len {-# INLINE smpP #-} +encodeLarge :: Builder -> Builder +encodeLarge s = word16BE (fromIntegral $ BB.length s) <> s +{-# INLINE encodeLarge #-} + instance Encoding SystemTime where smpEncode = smpEncode . systemSeconds {-# INLINE smpEncode #-} @@ -174,37 +181,37 @@ instance (Encoding a, Encoding b) => Encoding (a, b) where {-# INLINE smpP #-} instance (Encoding a, Encoding b, Encoding c) => Encoding (a, b, c) where - smpEncode (a, b, c) = smpEncode a <> smpEncode b <> smpEncode c + smpEncode (a, b, c) = B.concat [smpEncode a, smpEncode b, smpEncode c] {-# INLINE smpEncode #-} smpP = (,,) <$> smpP <*> smpP <*> smpP {-# INLINE smpP #-} instance (Encoding a, Encoding b, Encoding c, Encoding d) => Encoding (a, b, c, d) where - smpEncode (a, b, c, d) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d + smpEncode (a, b, c, d) = B.concat [smpEncode a, smpEncode b, smpEncode c, smpEncode d] {-# INLINE smpEncode #-} smpP = (,,,) <$> smpP <*> smpP <*> smpP <*> smpP {-# INLINE smpP #-} instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e) => Encoding (a, b, c, d, e) where - smpEncode (a, b, c, d, e) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e + smpEncode (a, b, c, d, e) = B.concat [smpEncode a, smpEncode b, smpEncode c, smpEncode d, smpEncode e] {-# INLINE smpEncode #-} smpP = (,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP {-# INLINE smpP #-} instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e, Encoding f) => Encoding (a, b, c, d, e, f) where - smpEncode (a, b, c, d, e, f) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e <> smpEncode f + smpEncode (a, b, c, d, e, f) = B.concat [smpEncode a, smpEncode b, smpEncode c, smpEncode d, smpEncode e, smpEncode f] {-# INLINE smpEncode #-} smpP = (,,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP {-# INLINE smpP #-} instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e, Encoding f, Encoding g) => Encoding (a, b, c, d, e, f, g) where - smpEncode (a, b, c, d, e, f, g) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e <> smpEncode f <> smpEncode g + smpEncode (a, b, c, d, e, f, g) = B.concat [smpEncode a, smpEncode b, smpEncode c, smpEncode d, smpEncode e, smpEncode f, smpEncode g] {-# INLINE smpEncode #-} smpP = (,,,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP {-# INLINE smpP #-} instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e, Encoding f, Encoding g, Encoding h) => Encoding (a, b, c, d, e, f, g, h) where - smpEncode (a, b, c, d, e, f, g, h) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e <> smpEncode f <> smpEncode g <> smpEncode h + smpEncode (a, b, c, d, e, f, g, h) = B.concat [smpEncode a, smpEncode b, smpEncode c, smpEncode d, smpEncode e, smpEncode f, smpEncode g, smpEncode h] {-# INLINE smpEncode #-} smpP = (,,,,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP {-# INLINE smpP #-} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index e65385cba..539d6d348 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -146,6 +146,7 @@ module Simplex.Messaging.Protocol tEncode, tEncodeBatch, batchTransmissions, + batchTransmissions', -- * exports for tests CommandTag (..), @@ -162,6 +163,8 @@ import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB +import qualified Data.ByteString.Lazy.Internal as LB import Data.Char (isPrint, isSpace) import Data.Constraint (Dict (..)) import Data.Functor (($>)) @@ -174,6 +177,8 @@ import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) import Network.Socket (HostName, ServiceName) +import Simplex.Messaging.Builder (Builder, char8, lazyByteString) +import qualified Simplex.Messaging.Builder as BB import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -1286,13 +1291,13 @@ instance Encoding CommandError where tPut :: Transport c => THandle c -> Maybe Int -> NonEmpty SentRawTransmission -> IO [Either TransportError ()] tPut th delay_ = fmap concat . mapM tPutBatch . batchTransmissions (batch th) (blockSize th) where - tPutBatch :: TransportBatch -> IO [Either TransportError ()] + tPutBatch :: TransportBatch () -> IO [Either TransportError ()] tPutBatch = \case - TBLargeTransmission -> [Left TELargeMsg] <$ putStrLn "tPut error: large message" - TBTransmissions n s -> replicate n <$> (tPutLog th (tEncodeBatch n s) <* mapM_ threadDelay delay_) - TBTransmission s -> (: []) <$> tPutLog th s + TBLargeTransmission _ -> [Left TELargeMsg] <$ putStrLn "tPut error: large message" + TBTransmissions s n _ -> replicate n <$> (tPutLog th (tEncodeBatch n s) <* mapM_ threadDelay delay_) + TBTransmission s _ -> (: []) <$> tPutLog th s -tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutLog :: Transport c => THandle c -> Builder -> IO (Either TransportError ()) tPutLog th s = do r <- tPutBlock th s case r of @@ -1300,43 +1305,43 @@ tPutLog th s = do _ -> pure () pure r --- ByteString does not include length byte, it is added by tEncodeBatch -data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission +-- Builder in TBTransmissions does not include byte with transmissions count, it is added by tEncodeBatch +data TransportBatch r = TBTransmissions Builder Int [r] | TBTransmission Builder r | TBLargeTransmission r + +batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch ()] +batchTransmissions batch bSize = batchTransmissions' batch bSize . L.map (,()) -- | encodes and batches transmissions into blocks, -batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch] -batchTransmissions batch bSize - | batch = reverse . mkBatch [] . L.map tEncode - | otherwise = map (mkBatch1 . tEncode) . L.toList +batchTransmissions' :: forall r. Bool -> Int -> NonEmpty (SentRawTransmission, r) -> [TransportBatch r] +batchTransmissions' batch bSize + | batch = addBatch . foldr addTransmission ([], mempty, 0, []) + | otherwise = map mkBatch1 . L.toList where - mkBatch :: [TransportBatch] -> NonEmpty ByteString -> [TransportBatch] - mkBatch rs ts = - let (n, s, ts_) = encodeBatch 0 "" ts - r = if n == 0 then TBLargeTransmission else TBTransmissions n s - rs' = r : rs - in case ts_ of - Just ts' -> mkBatch rs' ts' - _ -> rs' - mkBatch1 :: ByteString -> TransportBatch - mkBatch1 s = if B.length s > bSize - 2 then TBLargeTransmission else TBTransmission s - encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString)) - encodeBatch n s ts@(t :| ts_) - | n == 255 = (n, s, Just ts) - | otherwise = - let s' = s <> smpEncode (Large t) - n' = n + 1 - in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch - then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts - else case L.nonEmpty ts_ of - Just ts' -> encodeBatch n' s' ts' - _ -> (n', s', Nothing) + mkBatch1 :: (SentRawTransmission, r) -> TransportBatch r + mkBatch1 (t, r) + -- 2 bytes are reserved for pad size + | BB.length s <= bSize - 2 = TBTransmission s r + | otherwise = TBLargeTransmission r + where + s = tEncode t + addTransmission :: (SentRawTransmission, r) -> ([TransportBatch r], Builder, Int, [r]) -> ([TransportBatch r], Builder, Int, [r]) + addTransmission (t, r) acc@(bs, b, n, rs) + -- 3 = 2 bytes reserved for pad size + 1 for transmission count + | len + BB.length b <= bSize - 3 && n < 255 = (bs, s <> b, 1 + n, r : rs) + | len <= bSize - 3 = (addBatch acc, s, 1, [r]) + | otherwise = (TBLargeTransmission r : addBatch acc, mempty, 0, []) + where + s = encodeLarge $ tEncode t + len = BB.length s + addBatch :: ([TransportBatch r], Builder, Int, [r]) -> [TransportBatch r] + addBatch (bs, b, n, rs) = if n == 0 then bs else TBTransmissions b n rs : bs -tEncode :: SentRawTransmission -> ByteString -tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t +tEncode :: SentRawTransmission -> Builder +tEncode (sig, t) = lazyByteString $ LB.chunk (smpEncode $ C.signatureBytes sig) (LB.fromStrict t) {-# INLINE tEncode #-} -tEncodeBatch :: Int -> ByteString -> ByteString -tEncodeBatch n s = lenEncode n `B.cons` s +tEncodeBatch :: Int -> Builder -> Builder +tEncodeBatch n s = char8 (lenEncode n) <> s {-# INLINE tEncodeBatch #-} encodeTransmission :: ProtocolEncoding e c => Version -> ByteString -> Transmission c -> ByteString diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 6509c1f6f..554b9a00a 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -69,7 +69,7 @@ import Data.Bifunctor (first) import Data.Bitraversable (bimapM) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy.Char8 as LB import Data.Default (def) import Data.Functor (($>)) import Data.Version (showVersion) @@ -78,6 +78,7 @@ import Network.Socket import qualified Network.TLS as T import qualified Network.TLS.Extra as TE import qualified Paths_simplexmq as SMQ +import Simplex.Messaging.Builder (Builder, byteString, toLazyByteString) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON) @@ -134,6 +135,9 @@ class Transport c where -- | Write bytes to connection cPut :: c -> ByteString -> IO () + -- | Write bytes to connection + cPut' :: c -> LB.ByteString -> IO () + -- | Receive ByteString from connection, allowing LF or CRLF termination. getLn :: c -> IO ByteString @@ -217,8 +221,11 @@ instance Transport TLS where getBuffered tlsBuffer n t_ (T.recvData tlsContext) cPut :: TLS -> ByteString -> IO () - cPut TLS {tlsContext, tlsTransportConfig = TransportConfig {transportTimeout = t_}} s = - withTimedErr t_ . T.sendData tlsContext $ BL.fromStrict s + cPut cxt = cPut' cxt . LB.fromStrict + + cPut' :: TLS -> LB.ByteString -> IO () + cPut' TLS {tlsContext, tlsTransportConfig = TransportConfig {transportTimeout = t_}} s = + withTimedErr t_ $ T.sendData tlsContext s getLn :: TLS -> IO ByteString getLn TLS {tlsContext, tlsBuffer} = do @@ -309,10 +316,10 @@ serializeTransportError = \case TEHandshake e -> "HANDSHAKE " <> bshow e -- | Pad and send block to SMP transport. -tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutBlock :: Transport c => THandle c -> Builder -> IO (Either TransportError ()) tPutBlock THandle {connection = c, blockSize} block = - bimapM (const $ pure TELargeMsg) (cPut c) $ - C.pad block blockSize + bimapM (const $ pure TELargeMsg) (cPut' c . toLazyByteString) $ + C.pad' block blockSize -- | Receive block from SMP transport. tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString) @@ -356,7 +363,7 @@ smpThHandle :: forall c. THandle c -> Version -> THandle c smpThHandle th v = (th :: THandle c) {thVersion = v, batch = v >= 4} sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO () -sendHandshake th = ExceptT . tPutBlock th . smpEncode +sendHandshake th = ExceptT . tPutBlock th . byteString . smpEncode getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock th diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index ae78da1fe..486ae1f20 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -7,7 +7,7 @@ module Simplex.Messaging.Transport.WebSockets (WS (..)) where import qualified Control.Exception as E import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy as LB import qualified Network.TLS as T import Network.WebSockets import Network.WebSockets.Stream (Stream) @@ -72,6 +72,9 @@ instance Transport WS where cPut :: WS -> ByteString -> IO () cPut = sendBinaryData . wsConnection + cPut' :: WS -> LB.ByteString -> IO () + cPut' = sendBinaryData . wsConnection + getLn :: WS -> IO ByteString getLn c = do s <- trimCR <$> receiveData (wsConnection c) @@ -101,5 +104,5 @@ makeTLSContextStream cxt = (Just <$> T.recvData cxt) `E.catch` \case T.Error_EOF -> pure Nothing e -> E.throwIO e - writeStream :: Maybe BL.ByteString -> IO () + writeStream :: Maybe LB.ByteString -> IO () writeStream = maybe (closeTLS cxt) (T.sendData cxt) diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index da9f4c322..8d1ef241f 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -503,10 +503,11 @@ testNotificationsSMPRestartBatch :: Int -> ATransport -> APNSMockServer -> IO () testNotificationsSMPRestartBatch n t APNSMockServer {apnsQ} = do a <- getSMPAgentClient' agentCfg initAgentServers2 testDB b <- getSMPAgentClient' agentCfg initAgentServers2 testDB2 + threadDelay 1000000 conns <- runServers $ do conns <- replicateM (n :: Int) $ makeConnection a b _ <- registerTestToken a "abcd" NMInstant apnsQ - liftIO $ threadDelay 1500000 + liftIO $ threadDelay 5000000 forM_ conns $ \(aliceId, bobId) -> do msgId <- sendMessage b aliceId (SMP.MsgFlags True) "hello" get b ##> ("", aliceId, SENT msgId) @@ -572,7 +573,7 @@ testSwitchNotifications servers APNSMockServer {apnsQ} = do messageNotification :: TBQueue APNSMockRequest -> ExceptT AgentErrorType IO (C.CbNonce, ByteString) messageNotification apnsQ = do - 750000 `timeout` atomically (readTBQueue apnsQ) >>= \case + 1000000 `timeout` atomically (readTBQueue apnsQ) >>= \case Nothing -> error "no notification" Just APNSMockRequest {notification = APNSNotification {aps = APNSMutableContent {}, notificationData = Just ntfData}, sendApnsResponse} -> do nonce <- C.cbNonce <$> ntfData .-> "nonce" diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index fea45d4d5..c21ec8c7a 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -5,8 +5,9 @@ module CoreTests.BatchingTests (batchingTests) where import Control.Concurrent.STM import Control.Monad import Data.ByteString.Char8 (ByteString) -import qualified Data.ByteString.Char8 as B import qualified Data.List.NonEmpty as L +import Simplex.Messaging.Builder (Builder) +import qualified Simplex.Messaging.Builder as BB import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol @@ -20,7 +21,7 @@ batchingTests = do it "should batch with 90 subscriptions per batch" testBatchSubscriptions it "should break on message that does not fit" testBatchWithMessage it "should break on large message" testBatchWithLargeMessage - describe "batchClientTransmissions" $ do + describe "batchTransmissions'" $ do it "should batch with 90 subscriptions per batch" testClientBatchSubscriptions it "should break on message that does not fit" testClientBatchWithMessage it "should break on large message" testClientBatchWithLargeMessage @@ -34,8 +35,8 @@ testBatchSubscriptions = do length batches1 `shouldBe` 200 let batches = batchTransmissions True smpBlockSize $ L.fromList subs length batches `shouldBe` 3 - [TBTransmissions n1 s1, TBTransmissions n2 s2, TBTransmissions n3 s3] <- pure batches - (n1, n2, n3) `shouldBe` (90, 90, 20) + [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches + (n1, n2, n3) `shouldBe` (20, 90, 90) all lenOk [s1, s2, s3] `shouldBe` True testBatchWithMessage :: IO () @@ -50,8 +51,8 @@ testBatchWithMessage = do length batches1 `shouldBe` 101 let batches = batchTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 2 - [TBTransmissions n1 s1, TBTransmissions n2 s2] <- pure batches - (n1, n2) `shouldBe` (60, 41) + [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _] <- pure batches + (n1, n2) `shouldBe` (55, 46) all lenOk [s1, s2] `shouldBe` True testBatchWithLargeMessage :: IO () @@ -69,8 +70,8 @@ testBatchWithLargeMessage = do length batches1' `shouldBe` 160 let batches = batchTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 4 - [TBTransmissions n1 s1, TBLargeTransmission, TBTransmissions n2 s2, TBTransmissions n3 s3] <- pure batches - (n1, n2, n3) `shouldBe` (60, 90, 10) + [TBTransmissions s1 n1 _, TBLargeTransmission _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches + (n1, n2, n3) `shouldBe` (60, 10, 90) all lenOk [s1, s2, s3] `shouldBe` True testClientBatchSubscriptions :: IO () @@ -78,13 +79,13 @@ testClientBatchSubscriptions = do sessId <- atomically . C.randomBytes 32 =<< C.newRandom client <- atomically $ clientStub sessId subs <- replicateM 200 $ randomSUBCmd client - let batches1 = batchClientTransmissions False smpBlockSize $ L.fromList subs - all lenOk1' batches1 `shouldBe` True - let batches = batchClientTransmissions True smpBlockSize $ L.fromList subs + let batches1 = batchTransmissions' False smpBlockSize $ L.fromList subs + all lenOk1 batches1 `shouldBe` True + let batches = batchTransmissions' True smpBlockSize $ L.fromList subs length batches `shouldBe` 3 - [CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches - (n1, n2, n3) `shouldBe` (90, 90, 20) - (length rs1, length rs2, length rs3) `shouldBe` (90, 90, 20) + [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches + (n1, n2, n3) `shouldBe` (20, 90, 90) + (length rs1, length rs2, length rs3) `shouldBe` (20, 90, 90) all lenOk [s1, s2, s3] `shouldBe` True testClientBatchWithMessage :: IO () @@ -95,14 +96,14 @@ testClientBatchWithMessage = do send <- randomSENDCmd client 8000 subs2 <- replicateM 40 $ randomSUBCmd client let cmds = subs1 <> [send] <> subs2 - batches1 = batchClientTransmissions False smpBlockSize $ L.fromList cmds - all lenOk1' batches1 `shouldBe` True + batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 101 - let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds length batches `shouldBe` 2 - [CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2] <- pure batches - (n1, n2) `shouldBe` (60, 41) - (length rs1, length rs2) `shouldBe` (60, 41) + [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2] <- pure batches + (n1, n2) `shouldBe` (55, 46) + (length rs1, length rs2) `shouldBe` (55, 46) all lenOk [s1, s2] `shouldBe` True testClientBatchWithLargeMessage :: IO () @@ -113,26 +114,26 @@ testClientBatchWithLargeMessage = do send <- randomSENDCmd client 17000 subs2 <- replicateM 100 $ randomSUBCmd client let cmds = subs1 <> [send] <> subs2 - batches1 = batchClientTransmissions False smpBlockSize $ L.fromList cmds - all lenOk1' batches1 `shouldBe` False + batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + all lenOk1 batches1 `shouldBe` False length batches1 `shouldBe` 161 let batches1' = take 60 batches1 <> drop 61 batches1 - all lenOk1' batches1' `shouldBe` True + all lenOk1 batches1' `shouldBe` True length batches1' `shouldBe` 160 -- - let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds length batches `shouldBe` 4 - [CBTransmissions s1 n1 rs1, CBLargeTransmission _, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches - (n1, n2, n3) `shouldBe` (60, 90, 10) - (length rs1, length rs2, length rs3) `shouldBe` (60, 90, 10) + [TBTransmissions s1 n1 rs1, TBLargeTransmission _, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches + (n1, n2, n3) `shouldBe` (60, 10, 90) + (length rs1, length rs2, length rs3) `shouldBe` (60, 10, 90) all lenOk [s1, s2, s3] `shouldBe` True -- let cmds' = [send] <> subs1 <> subs2 - let batches' = batchClientTransmissions True smpBlockSize $ L.fromList cmds' + let batches' = batchTransmissions' True smpBlockSize $ L.fromList cmds' length batches' `shouldBe` 3 - [CBLargeTransmission _, CBTransmissions s1' n1' rs1', CBTransmissions s2' n2' rs2'] <- pure batches' - (n1', n2') `shouldBe` (90, 70) - (length rs1', length rs2') `shouldBe` (90, 70) + [TBLargeTransmission _, TBTransmissions s1' n1' rs1', TBTransmissions s2' n2' rs2'] <- pure batches' + (n1', n2') `shouldBe` (70, 90) + (length rs1', length rs2') `shouldBe` (70, 90) all lenOk [s1', s2'] `shouldBe` True randomSUB :: ByteString -> IO (Maybe C.ASignature, ByteString) @@ -169,15 +170,10 @@ randomSENDCmd c len = do msg <- atomically $ C.randomBytes len g mkTransmission c (Just rpKey, sId, Cmd SSender $ SEND noMsgFlags msg) -lenOk :: ByteString -> Bool -lenOk s = 0 < B.length s && B.length s <= smpBlockSize - 2 +lenOk :: Builder -> Bool +lenOk s = 0 < BB.length s && BB.length s <= smpBlockSize - 2 -lenOk1 :: TransportBatch -> Bool +lenOk1 :: TransportBatch r -> Bool lenOk1 = \case - TBTransmission s -> lenOk s - _ -> False - -lenOk1' :: ClientBatch err msg -> Bool -lenOk1' = \case - CBTransmission s _ -> lenOk s + TBTransmission s _ -> lenOk s _ -> False diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index a767218c0..c39f6fa43 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -198,8 +198,8 @@ agentCfg = ntfCfg = defaultClientConfig {qSize = 1, defaultTransport = (ntfTestPort, transport @TLS)}, reconnectInterval = defaultReconnectInterval {initialInterval = 50_000}, xftpNotifyErrsOnRetry = False, - ntfWorkerDelay = 1000, - ntfSMPWorkerDelay = 1000, + ntfWorkerDelay = 100, + ntfSMPWorkerDelay = 100, caCertificateFile = "tests/fixtures/ca.crt", privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt"