From 27d77e2d76ddb906cb5616aed2efae50c27572c0 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 8 Jan 2024 10:02:24 +0000 Subject: [PATCH] refactor --- src/Simplex/Messaging/Client.hs | 49 +++++-------------------------- src/Simplex/Messaging/Protocol.hs | 44 ++++++++++++++------------- tests/CoreTests/BatchingTests.hs | 47 +++++++++++++---------------- 3 files changed, 53 insertions(+), 87 deletions(-) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 471a9dbaa..eaeac8473 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -72,9 +72,7 @@ module Simplex.Messaging.Client ClientCommand, -- * For testing - ClientBatch (..), PCTransmission, - batchClientTransmissions, mkTransmission, clientStub, ) @@ -104,7 +102,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON) -import Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Protocol import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport @@ -175,7 +173,7 @@ clientStub sessionId = do } } -type SMPClient = ProtocolClient ErrorType SMP.BrokerMsg +type SMPClient = ProtocolClient ErrorType BrokerMsg -- | Type for client command data type ClientCommand msg = (Maybe C.APrivateSignKey, EntityId, ProtoCommand msg) @@ -636,7 +634,7 @@ type PCTransmission err msg = (SentRawTransmission, Request err msg) -- | Send multiple commands with batching and collect responses sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do - bs <- batchClientTransmissions batch blockSize <$> mapM (mkTransmission c) cs + bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs validate . concat =<< mapM (sendBatch c) bs where validate :: [Response err msg] -> IO (NonEmpty (Response err msg)) @@ -653,53 +651,22 @@ sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () streamProtocolCommands c@ProtocolClient {batch, blockSize} cs cb = do - bs <- batchClientTransmissions batch blockSize <$> mapM (mkTransmission c) cs + bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs mapM_ (cb <=< sendBatch c) bs -sendBatch :: ProtocolClient err msg -> ClientBatch err msg -> IO [Response err msg] +sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg] sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do case b of - CBLargeTransmission Request {entityId} -> do + TBLargeTransmission Request {entityId} -> do putStrLn "send error: large message" pure [Response entityId $ Left $ PCETransportError TELargeMsg] - CBTransmissions s n rs -> do + TBTransmissions s n rs -> do when (n > 0) $ atomically $ writeTBQueue sndQ $ tEncodeBatch n s mapConcurrently (getResponse c) rs - CBTransmission s r -> do + TBTransmission s r -> do atomically $ writeTBQueue sndQ s (: []) <$> getResponse c r -data ClientBatch err msg - = -- Builder in CBTransmissions does not include count byte, it is added by tEncodeBatch - CBTransmissions Builder Int [Request err msg] - | CBTransmission Builder (Request err msg) - | CBLargeTransmission (Request err msg) - -batchClientTransmissions :: forall err msg. Bool -> Int -> NonEmpty (PCTransmission err msg) -> [ClientBatch err msg] -batchClientTransmissions batch bSize - | batch = addBatch . foldr addTransmission ([], mempty, 0, 0, []) - | otherwise = map mkBatch1 . L.toList - where - mkBatch1 :: PCTransmission err msg -> ClientBatch err msg - mkBatch1 (t, r) - -- 2 bytes are reserved for pad size - | LB.length s <= fromIntegral (bSize - 2) = CBTransmission (lazyByteString s) r - | otherwise = CBLargeTransmission r - where - s = tEncode t - addTransmission :: PCTransmission err msg -> ([ClientBatch err msg], Builder, Int, Int, [Request err msg]) -> ([ClientBatch err msg], Builder, Int, Int, [Request err msg]) - addTransmission (t, r) acc@(bs, b, len, n, rs) - | len' <= bSize - 3 && n < 255 = (bs, s <> b, len', 1 + n, r : rs) - | sLen <= bSize - 3 = (addBatch acc, s, sLen, 1, [r]) - | otherwise = (CBLargeTransmission r : addBatch acc, mempty, 0, 0, []) - where - s = encodeLarge t' - sLen = 2 + fromIntegral (LB.length t') -- 2-bytes length is added by encodeLarge - t' = tEncode t - len' = sLen + len - addBatch :: ([ClientBatch err msg], Builder, Int, Int, [Request err msg]) -> [ClientBatch err msg] - addBatch (bs, b, _, n, rs) = if n == 0 then bs else CBTransmissions b n rs : bs - -- | Send Protocol command sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, batch, blockSize} pKey entId cmd = diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 6a16ea557..98f282ca5 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -146,6 +146,7 @@ module Simplex.Messaging.Protocol tEncode, tEncodeBatch, batchTransmissions, + batchTransmissions', -- * exports for tests CommandTag (..), @@ -1289,11 +1290,11 @@ instance Encoding CommandError where tPut :: Transport c => THandle c -> Maybe Int -> NonEmpty SentRawTransmission -> IO [Either TransportError ()] tPut th delay_ = fmap concat . mapM tPutBatch . batchTransmissions (batch th) (blockSize th) where - tPutBatch :: TransportBatch -> IO [Either TransportError ()] + tPutBatch :: TransportBatch () -> IO [Either TransportError ()] tPutBatch = \case - TBLargeTransmission -> [Left TELargeMsg] <$ putStrLn "tPut error: large message" - TBTransmissions n s -> replicate n <$> (tPutLog th (tEncodeBatch n s) <* mapM_ threadDelay delay_) - TBTransmission s -> (: []) <$> tPutLog th s + TBLargeTransmission _ -> [Left TELargeMsg] <$ putStrLn "tPut error: large message" + TBTransmissions s n _ -> replicate n <$> (tPutLog th (tEncodeBatch n s) <* mapM_ threadDelay delay_) + TBTransmission s _ -> (: []) <$> tPutLog th s tPutLog :: Transport c => THandle c -> Builder -> IO (Either TransportError ()) tPutLog th s = do @@ -1303,34 +1304,37 @@ tPutLog th s = do _ -> pure () pure r --- ByteString does not include length byte, it is added by tEncodeBatch -data TransportBatch = TBTransmissions Int Builder | TBTransmission Builder | TBLargeTransmission +-- Builder in TBTransmissions does not include byte with transmissions count, it is added by tEncodeBatch +data TransportBatch r = TBTransmissions Builder Int [r] | TBTransmission Builder r | TBLargeTransmission r + +batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch ()] +batchTransmissions batch bSize = batchTransmissions' batch bSize . L.map (,()) -- | encodes and batches transmissions into blocks, -batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch] -batchTransmissions batch bSize - | batch = addBatch . foldr addTransmission ([], mempty, 0, 0) +batchTransmissions' :: forall r. Bool -> Int -> NonEmpty (SentRawTransmission, r) -> [TransportBatch r] +batchTransmissions' batch bSize + | batch = addBatch . foldr addTransmission ([], mempty, 0, 0, []) | otherwise = map mkBatch1 . L.toList where - mkBatch1 :: SentRawTransmission -> TransportBatch - mkBatch1 t + mkBatch1 :: (SentRawTransmission, r) -> TransportBatch r + mkBatch1 (t, r) -- 2 bytes are reserved for pad size - | LB.length s <= fromIntegral (bSize - 2) = TBTransmission (lazyByteString s) - | otherwise = TBLargeTransmission + | LB.length s <= fromIntegral (bSize - 2) = TBTransmission (lazyByteString s) r + | otherwise = TBLargeTransmission r where s = tEncode t - addTransmission :: SentRawTransmission -> ([TransportBatch], Builder, Int, Int) -> ([TransportBatch], Builder, Int, Int) - addTransmission t acc@(bs, b, len, n) - | len' <= bSize - 3 && n < 255 = (bs, s <> b, len', 1 + n) - | sLen <= bSize - 3 = (addBatch acc, s, sLen, 1) - | otherwise = (TBLargeTransmission : addBatch acc, mempty, 0, 0) + addTransmission :: (SentRawTransmission, r) -> ([TransportBatch r], Builder, Int, Int, [r]) -> ([TransportBatch r], Builder, Int, Int, [r]) + addTransmission (t, r) acc@(bs, b, len, n, rs) + | len' <= bSize - 3 && n < 255 = (bs, s <> b, len', 1 + n, r : rs) + | sLen <= bSize - 3 = (addBatch acc, s, sLen, 1, [r]) + | otherwise = (TBLargeTransmission r : addBatch acc, mempty, 0, 0, []) where s = encodeLarge t' sLen = 2 + fromIntegral (LB.length t') -- 2-bytes length is added by encodeLarge t' = tEncode t len' = sLen + len - addBatch :: ([TransportBatch], Builder, Int, Int) -> [TransportBatch] - addBatch (bs, b, _, n) = if n == 0 then bs else TBTransmissions n b : bs + addBatch :: ([TransportBatch r], Builder, Int, Int, [r]) -> [TransportBatch r] + addBatch (bs, b, _, n, rs) = if n == 0 then bs else TBTransmissions b n rs : bs tEncode :: SentRawTransmission -> LB.ByteString tEncode (sig, t) = LB.chunk (smpEncode $ C.signatureBytes sig) (LB.fromStrict t) diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index 78a1079dd..28f9f0435 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -21,7 +21,7 @@ batchingTests = do it "should batch with 90 subscriptions per batch" testBatchSubscriptions it "should break on message that does not fit" testBatchWithMessage it "should break on large message" testBatchWithLargeMessage - describe "batchClientTransmissions" $ do + describe "batchTransmissions'" $ do it "should batch with 90 subscriptions per batch" testClientBatchSubscriptions it "should break on message that does not fit" testClientBatchWithMessage it "should break on large message" testClientBatchWithLargeMessage @@ -35,7 +35,7 @@ testBatchSubscriptions = do length batches1 `shouldBe` 200 let batches = batchTransmissions True smpBlockSize $ L.fromList subs length batches `shouldBe` 3 - [TBTransmissions n1 s1, TBTransmissions n2 s2, TBTransmissions n3 s3] <- pure batches + [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches (n1, n2, n3) `shouldBe` (20, 90, 90) all lenOk [s1, s2, s3] `shouldBe` True @@ -51,7 +51,7 @@ testBatchWithMessage = do length batches1 `shouldBe` 101 let batches = batchTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 2 - [TBTransmissions n1 s1, TBTransmissions n2 s2] <- pure batches + [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _] <- pure batches (n1, n2) `shouldBe` (55, 46) all lenOk [s1, s2] `shouldBe` True @@ -70,7 +70,7 @@ testBatchWithLargeMessage = do length batches1' `shouldBe` 160 let batches = batchTransmissions True smpBlockSize $ L.fromList cmds length batches `shouldBe` 4 - [TBTransmissions n1 s1, TBLargeTransmission, TBTransmissions n2 s2, TBTransmissions n3 s3] <- pure batches + [TBTransmissions s1 n1 _, TBLargeTransmission _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches (n1, n2, n3) `shouldBe` (60, 10, 90) all lenOk [s1, s2, s3] `shouldBe` True @@ -79,11 +79,11 @@ testClientBatchSubscriptions = do sessId <- atomically . C.randomBytes 32 =<< C.newRandom client <- atomically $ clientStub sessId subs <- replicateM 200 $ randomSUBCmd client - let batches1 = batchClientTransmissions False smpBlockSize $ L.fromList subs - all lenOk1' batches1 `shouldBe` True - let batches = batchClientTransmissions True smpBlockSize $ L.fromList subs + let batches1 = batchTransmissions' False smpBlockSize $ L.fromList subs + all lenOk1 batches1 `shouldBe` True + let batches = batchTransmissions' True smpBlockSize $ L.fromList subs length batches `shouldBe` 3 - [CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches + [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches (n1, n2, n3) `shouldBe` (20, 90, 90) (length rs1, length rs2, length rs3) `shouldBe` (20, 90, 90) all lenOk [s1, s2, s3] `shouldBe` True @@ -96,12 +96,12 @@ testClientBatchWithMessage = do send <- randomSENDCmd client 8000 subs2 <- replicateM 40 $ randomSUBCmd client let cmds = subs1 <> [send] <> subs2 - batches1 = batchClientTransmissions False smpBlockSize $ L.fromList cmds - all lenOk1' batches1 `shouldBe` True + batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 101 - let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds length batches `shouldBe` 2 - [CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2] <- pure batches + [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2] <- pure batches (n1, n2) `shouldBe` (55, 46) (length rs1, length rs2) `shouldBe` (55, 46) all lenOk [s1, s2] `shouldBe` True @@ -114,24 +114,24 @@ testClientBatchWithLargeMessage = do send <- randomSENDCmd client 17000 subs2 <- replicateM 100 $ randomSUBCmd client let cmds = subs1 <> [send] <> subs2 - batches1 = batchClientTransmissions False smpBlockSize $ L.fromList cmds - all lenOk1' batches1 `shouldBe` False + batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + all lenOk1 batches1 `shouldBe` False length batches1 `shouldBe` 161 let batches1' = take 60 batches1 <> drop 61 batches1 - all lenOk1' batches1' `shouldBe` True + all lenOk1 batches1' `shouldBe` True length batches1' `shouldBe` 160 -- - let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds length batches `shouldBe` 4 - [CBTransmissions s1 n1 rs1, CBLargeTransmission _, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches + [TBTransmissions s1 n1 rs1, TBLargeTransmission _, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches (n1, n2, n3) `shouldBe` (60, 10, 90) (length rs1, length rs2, length rs3) `shouldBe` (60, 10, 90) all lenOk [s1, s2, s3] `shouldBe` True -- let cmds' = [send] <> subs1 <> subs2 - let batches' = batchClientTransmissions True smpBlockSize $ L.fromList cmds' + let batches' = batchTransmissions' True smpBlockSize $ L.fromList cmds' length batches' `shouldBe` 3 - [CBLargeTransmission _, CBTransmissions s1' n1' rs1', CBTransmissions s2' n2' rs2'] <- pure batches' + [TBLargeTransmission _, TBTransmissions s1' n1' rs1', TBTransmissions s2' n2' rs2'] <- pure batches' (n1', n2') `shouldBe` (70, 90) (length rs1', length rs2') `shouldBe` (70, 90) all lenOk [s1', s2'] `shouldBe` True @@ -175,12 +175,7 @@ lenOk s = 0 < len && len <= smpBlockSize - 2 where len = fromIntegral . LB.length $ toLazyByteString s -lenOk1 :: TransportBatch -> Bool +lenOk1 :: TransportBatch r -> Bool lenOk1 = \case - TBTransmission s -> lenOk s - _ -> False - -lenOk1' :: ClientBatch err msg -> Bool -lenOk1' = \case - CBTransmission s _ -> lenOk s + TBTransmission s _ -> lenOk s _ -> False