client: streaming batched commands, refactor (#826)

This commit is contained in:
Evgeny Poberezkin
2023-08-19 16:11:05 +01:00
committed by GitHub
parent f3111f4559
commit 40e6d16e48
2 changed files with 116 additions and 105 deletions
+108 -97
View File
@@ -10,6 +10,7 @@
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
-- |
@@ -39,6 +40,7 @@ module Simplex.Messaging.Client
createSMPQueue,
subscribeSMPQueue,
subscribeSMPQueues,
streamSubscribeSMPQueues,
getSMPMessage,
subscribeSMPQueueNotifications,
subscribeSMPQueuesNtfs,
@@ -176,10 +178,10 @@ clientStub sessionId = do
type SMPClient = ProtocolClient ErrorType SMP.BrokerMsg
-- | Type for client command data
type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg)
type ClientCommand msg = (Maybe C.APrivateSignKey, EntityId, ProtoCommand msg)
-- | Type synonym for transmission from some SPM server queue.
type ServerTransmission msg = (TransportSession msg, Version, SessionId, QueueId, msg)
type ServerTransmission msg = (TransportSession msg, Version, SessionId, EntityId, msg)
data HostMode
= -- | prefer (or require) onion hosts when connecting via SOCKS proxy
@@ -283,13 +285,14 @@ defaultClientConfig =
}
data Request err msg = Request
{ queueId :: QueueId,
responseVar :: TResponse err msg
{ entityId :: EntityId,
responseVar :: TMVar (Either (ProtocolClientError err) msg)
}
type Response err msg = Either (ProtocolClientError err) msg
type TResponse err msg = TMVar (Response err msg)
data Response err msg = Response
{ entityId :: EntityId,
response :: Either (ProtocolClientError err) msg
}
chooseTransportHost :: NetworkConfig -> NonEmpty TransportHost -> Either (ProtocolClientError err) TransportHost
chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts =
@@ -414,26 +417,27 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO ()
processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, qId, respOrErr)) =
processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) =
if B.null $ bs corrId
then sendMsg respOrErr
else do
atomically (TM.lookup corrId sentCommands) >>= \case
Nothing -> sendMsg respOrErr
Just Request {queueId, responseVar} -> atomically $ do
Just Request {entityId, responseVar} -> atomically $ do
TM.delete corrId sentCommands
putTMVar responseVar $
if queueId == qId
then case respOrErr of
Left e -> Left $ PCEResponseError e
Right r -> case protocolError r of
Just e -> Left $ PCEProtocolError e
_ -> Right r
else Left . PCEUnexpectedResponse $ bshow respOrErr
putTMVar responseVar $ response entityId
where
response entityId
| entityId == entId =
case respOrErr of
Left e -> Left $ PCEResponseError e
Right r -> case protocolError r of
Just e -> Left $ PCEProtocolError e
_ -> Right r
| otherwise = Left . PCEUnexpectedResponse $ bshow respOrErr
sendMsg :: Either err msg -> IO ()
sendMsg = \case
Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ
Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c entId msg) msgQ
Left e -> putStrLn $ "SMP client error: " <> show e
proxyUsername :: TransportSession msg -> ByteString
@@ -509,14 +513,22 @@ subscribeSMPQueue c rpKey rId =
-- | Subscribe to multiple SMP queues batching commands if supported.
subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs
subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM (processSUBResponse c)
where
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs
response ((_, rId), r) = case r of
Right OK -> pure $ Right ()
Right cmd@MSG {} -> writeSMPMessage c rId cmd $> Right ()
Right r' -> pure . Left . PCEUnexpectedResponse $ bshow r'
Left e -> pure $ Left e
streamSubscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> ([(RecipientId, Either SMPClientError ())] -> IO ()) -> IO ()
streamSubscribeSMPQueues c qs cb = streamProtocolCommands c cs $ mapM process >=> cb
where
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs
process r@(Response rId _) = (rId,) <$> processSUBResponse c r
processSUBResponse :: SMPClient -> Response ErrorType BrokerMsg -> IO (Either SMPClientError ())
processSUBResponse c (Response rId r) = case r of
Right OK -> pure $ Right ()
Right cmd@MSG {} -> writeSMPMessage c rId cmd $> Right ()
Right r' -> pure . Left . PCEUnexpectedResponse $ bshow r'
Left e -> pure $ Left e
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
@@ -562,12 +574,12 @@ enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey =
-- | Enable notifications for the multiple queues for push notifications server.
enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either SMPClientError (NotifierId, RcvNtfPublicDhKey)))
enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
enableSMPQueuesNtfs c qs = L.map process <$> sendProtocolCommands c cs
where
cs = L.map (\(rpKey, rId, notifierKey, rcvNtfPublicDhKey) -> (Just rpKey, rId, Cmd SRecipient $ NKEY notifierKey rcvNtfPublicDhKey)) qs
response = \case
process (Response _ r) = case r of
Right (NID nId rcvNtfSrvPublicDhKey) -> Right (nId, rcvNtfSrvPublicDhKey)
Right r -> Left . PCEUnexpectedResponse $ bshow r
Right r' -> Left . PCEUnexpectedResponse $ bshow r'
Left e -> Left e
-- | Disable notifications for the queue for push notifications server.
@@ -623,123 +635,122 @@ okSMPCommand cmd c pKey qId =
r -> throwE . PCEUnexpectedResponse $ bshow r
okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either SMPClientError ()))
okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs
okSMPCommands cmd c qs = L.map process <$> sendProtocolCommands c cs
where
aCmd = Cmd sParty cmd
cs = L.map (\(pKey, qId) -> (Just pKey, qId, aCmd)) qs
response = \case
process (Response _ r) = case r of
Right OK -> Right ()
Right r -> Left . PCEUnexpectedResponse $ bshow r
Right r' -> Left . PCEUnexpectedResponse $ bshow r'
Left e -> Left e
-- | Send SMP command
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
type PCTransmission err msg = (SentRawTransmission, TResponse err msg)
type PCTransmission err msg = (SentRawTransmission, Request err msg)
-- | Send multiple commands with batching and collect responses
-- TODO switch to timeout or TimeManager that supports Int64
sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg))
sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}, batch, blockSize} cs = do
bs <- batchClientTransmissions batch blockSize <$> mapM (runExceptT . mkTransmission c) cs
validate . concat =<< mapM sendBatch bs
sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do
bs <- batchClientTransmissions batch blockSize <$> mapM (mkTransmission c) cs
validate . concat =<< mapM (sendBatch c) bs
where
validate :: [Response err msg] -> IO (NonEmpty (Response err msg))
validate rs
| diff == 0 = pure $ L.fromList rs
| diff > 0 = do
putStrLn "send error: fewer responses than expected"
pure $ L.fromList $ rs <> replicate diff (Left $ PCETransportError TEBadBlock)
pure $ L.fromList $ rs <> replicate diff (Response "" $ Left $ PCETransportError TEBadBlock)
| otherwise = do
putStrLn "send error: more responses than expected"
pure $ L.fromList $ take (L.length cs) rs
where
diff = L.length cs - length rs
sendBatch :: ClientBatch err msg -> IO [Response err msg]
sendBatch b = do
case b of
CBLargeTransmission -> [Left (PCETransportError TELargeMsg)] <$ putStrLn "send error: large message"
CBTransmissions n s rs -> do
when (n > 0) $ atomically $ writeTBQueue sndQ $ tEncodeBatch n s
forConcurrently rs $ \case
Right r -> withTimeout c tcpTimeout (atomically $ takeTMVar r)
Left e -> pure $ Left e
CBTransmission s r -> do
atomically $ writeTBQueue sndQ s
(: []) <$> withTimeout c tcpTimeout (atomically $ takeTMVar r)
type PCTransmissionOrErr err msg = Either (ProtocolClientError err) (PCTransmission err msg)
streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
streamProtocolCommands c@ProtocolClient {batch, blockSize} cs cb = do
bs <- batchClientTransmissions batch blockSize <$> mapM (mkTransmission c) cs
mapM_ (cb <=< sendBatch c) bs
type TResponseOrErr err msg = Either (ProtocolClientError err) (TResponse err msg)
sendBatch :: ProtocolClient err msg -> ClientBatch err msg -> IO [Response err msg]
sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
case b of
CBLargeTransmission Request {entityId} -> do
putStrLn "send error: large message"
pure [Response entityId $ Left $ PCETransportError TELargeMsg]
CBTransmissions s n rs -> do
when (n > 0) $ atomically $ writeTBQueue sndQ $ tEncodeBatch n s
mapConcurrently (getResponse c) rs
CBTransmission s r -> do
atomically $ writeTBQueue sndQ s
(: []) <$> getResponse c r
data ClientBatch err msg
-- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch
= CBTransmissions Int ByteString [TResponseOrErr err msg]
| CBTransmission ByteString (TResponse err msg)
| CBLargeTransmission
= CBTransmissions ByteString Int [Request err msg]
| CBTransmission ByteString (Request err msg)
| CBLargeTransmission (Request err msg)
-- | encodes and batches transmissions into blocks
batchClientTransmissions :: forall err msg. Bool -> Int -> NonEmpty (PCTransmissionOrErr err msg) -> [ClientBatch err msg]
batchClientTransmissions batch bSize
batchClientTransmissions :: forall err msg. Bool -> Int -> NonEmpty (PCTransmission err msg) -> [ClientBatch err msg]
batchClientTransmissions batch blkSize
| batch = reverse . mkBatch []
| otherwise = map mkBatch1 . L.toList
where
mkBatch :: [ClientBatch err msg] -> NonEmpty (PCTransmissionOrErr err msg) -> [ClientBatch err msg]
mkBatch :: [ClientBatch err msg] -> NonEmpty (PCTransmission err msg) -> [ClientBatch err msg]
mkBatch bs ts =
let (b, ts_) = encodeBatch 0 "" [] ts
let (b, ts_) = encodeBatch "" 0 [] ts
bs' = b : bs
in maybe bs' (mkBatch bs') ts_
mkBatch1 :: PCTransmissionOrErr err msg -> ClientBatch err msg
mkBatch1 = \case
Left e -> CBTransmissions 0 "" [Left e]
Right (t, r) ->
let s = tEncode t
in if B.length s > bSize - 2 then CBLargeTransmission else CBTransmission s r
encodeBatch :: Int -> ByteString -> [TResponseOrErr err msg] -> NonEmpty (PCTransmissionOrErr err msg) -> (ClientBatch err msg, Maybe (NonEmpty (PCTransmissionOrErr err msg)))
encodeBatch n s rs ts@(t_ :| ts_)
| n == 255 = (res, Just ts)
| otherwise = case t_ of
Left e -> next n s (Left e : rs)
Right (t, r)
| B.length s' <= bSize - 3 -> next (n + 1) s' (Right r : rs)
| null rs -> (CBLargeTransmission, L.nonEmpty ts_)
| otherwise -> (res, Just ts)
where
s' = s <> smpEncode (Large $ tEncode t)
mkBatch1 :: PCTransmission err msg -> ClientBatch err msg
mkBatch1 (t, r)
| B.length s <= blkSize - 2 = CBTransmission s r
| otherwise = CBLargeTransmission r
where
res = CBTransmissions n s (reverse rs)
next n' s' rs' = case L.nonEmpty ts_ of
Just ts' -> encodeBatch n' s' rs' ts'
Nothing -> (CBTransmissions n' s' (reverse rs'), Nothing)
s = tEncode t
encodeBatch :: ByteString -> Int -> [Request err msg] -> NonEmpty (PCTransmission err msg) -> (ClientBatch err msg, Maybe (NonEmpty (PCTransmission err msg)))
encodeBatch s n rs ts@((t, r) :| ts_)
| B.length s' <= blkSize - 3 && n < 255 =
case L.nonEmpty ts_ of
Just ts' -> encodeBatch s' n' rs' ts'
Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing)
| n == 0 = (CBLargeTransmission r, L.nonEmpty ts_)
| otherwise = (CBTransmissions s n (reverse rs), Just ts)
where
s' = s <> smpEncode (Large $ tEncode t)
n' = n + 1
rs' = r : rs
-- | Send Protocol command
sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}, batch, blockSize} pKey qId cmd = do
(t, r) <- mkTransmission c (pKey, qId, cmd)
ExceptT $ sendRecv t r
sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, batch, blockSize} pKey entId cmd =
ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd)
where
-- two separate "atomically" needed to avoid blocking
sendRecv :: SentRawTransmission -> TResponse err msg -> IO (Response err msg)
sendRecv :: SentRawTransmission -> Request err msg -> IO (Either (ProtocolClientError err) msg)
sendRecv t r
| B.length s > blockSize - 2 = pure $ Left $ PCETransportError TELargeMsg
| otherwise = atomically (writeTBQueue sndQ s) >> withTimeout c tcpTimeout (atomically $ takeTMVar r)
| otherwise = atomically (writeTBQueue sndQ s) >> response <$> getResponse c r
where
s
| batch = tEncodeBatch 1 . smpEncode . Large $ tEncode t
| otherwise = tEncode t
withTimeout :: ProtocolClient err msg -> Int -> IO (Response err msg) -> IO (Response err msg)
withTimeout ProtocolClient {client_ = PClient {pingErrorCount}} t a = do
timeout t a >>= \case
Just r -> atomically (writeTVar pingErrorCount 0) >> pure r
_ -> pure $ Left PCEResponseTimeout
-- TODO switch to timeout or TimeManager that supports Int64
getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg)
getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Request {entityId, responseVar} = do
response <-
timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case
Just r -> atomically (writeTVar pingErrorCount 0) $> r
Nothing -> pure $ Left PCEResponseTimeout
pure Response {entityId, response}
mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> ExceptT (ProtocolClientError err) IO (PCTransmission err msg)
mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do
corrId <- liftIO $ atomically getNextCorrId
let t = signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
r <- liftIO . atomically $ mkRequest corrId
mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> IO (PCTransmission err msg)
mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, entId, cmd) = do
corrId <- atomically getNextCorrId
let t = signTransmission $ encodeTransmission thVersion sessionId (corrId, entId, cmd)
r <- atomically $ mkRequest corrId
pure (t, r)
where
getNextCorrId :: STM CorrId
@@ -748,8 +759,8 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo
pure . CorrId $ bshow i
signTransmission :: ByteString -> SentRawTransmission
signTransmission t = ((`C.sign` t) <$> pKey, t)
mkRequest :: CorrId -> STM (TResponse err msg)
mkRequest :: CorrId -> STM (Request err msg)
mkRequest corrId = do
r <- newEmptyTMVar
TM.insert corrId (Request qId r) sentCommands
r <- Request entId <$> newEmptyTMVar
TM.insert corrId r sentCommands
pure r
+8 -8
View File
@@ -83,7 +83,7 @@ testClientBatchSubscriptions = do
all lenOk1' batches1 `shouldBe` True
let batches = batchClientTransmissions True smpBlockSize $ L.fromList subs
length batches `shouldBe` 3
[CBTransmissions n1 s1 rs1, CBTransmissions n2 s2 rs2, CBTransmissions n3 s3 rs3] <- pure batches
[CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches
(n1, n2, n3) `shouldBe` (90, 90, 20)
(length rs1, length rs2, length rs3) `shouldBe` (90, 90, 20)
all lenOk [s1, s2, s3] `shouldBe` True
@@ -101,7 +101,7 @@ testClientBatchWithMessage = do
length batches1 `shouldBe` 101
let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds
length batches `shouldBe` 2
[CBTransmissions n1 s1 rs1, CBTransmissions n2 s2 rs2] <- pure batches
[CBTransmissions s1 n1 rs1, CBTransmissions s2 n2 rs2] <- pure batches
(n1, n2) `shouldBe` (60, 41)
(length rs1, length rs2) `shouldBe` (60, 41)
all lenOk [s1, s2] `shouldBe` True
@@ -123,7 +123,7 @@ testClientBatchWithLargeMessage = do
--
let batches = batchClientTransmissions True smpBlockSize $ L.fromList cmds
length batches `shouldBe` 4
[CBTransmissions n1 s1 rs1, CBLargeTransmission, CBTransmissions n2 s2 rs2, CBTransmissions n3 s3 rs3] <- pure batches
[CBTransmissions s1 n1 rs1, CBLargeTransmission _, CBTransmissions s2 n2 rs2, CBTransmissions s3 n3 rs3] <- pure batches
(n1, n2, n3) `shouldBe` (60, 90, 10)
(length rs1, length rs2, length rs3) `shouldBe` (60, 90, 10)
all lenOk [s1, s2, s3] `shouldBe` True
@@ -131,7 +131,7 @@ testClientBatchWithLargeMessage = do
let cmds' = [send] <> subs1 <> subs2
let batches' = batchClientTransmissions True smpBlockSize $ L.fromList cmds'
length batches' `shouldBe` 3
[CBLargeTransmission, CBTransmissions n1' s1' rs1', CBTransmissions n2' s2' rs2'] <- pure batches'
[CBLargeTransmission _, CBTransmissions s1' n1' rs1', CBTransmissions s2' n2' rs2'] <- pure batches'
(n1', n2') `shouldBe` (90, 70)
(length rs1', length rs2') `shouldBe` (90, 70)
all lenOk [s1', s2'] `shouldBe` True
@@ -144,11 +144,11 @@ randomSUB sessId = do
let s = encodeTransmission (maxVersion supportedSMPServerVRange) sessId (corrId, rId, Cmd SRecipient SUB)
pure (Just $ C.sign rpKey s, s)
randomSUBCmd :: ProtocolClient ErrorType BrokerMsg -> IO (Either (ProtocolClientError ErrorType) (PCTransmission ErrorType BrokerMsg))
randomSUBCmd :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
randomSUBCmd c = do
rId <- getRandomBytes 24
(_, rpKey) <- C.generateSignatureKeyPair C.SEd448
runExceptT $ mkTransmission c (Just rpKey, rId, Cmd SRecipient SUB)
mkTransmission c (Just rpKey, rId, Cmd SRecipient SUB)
randomSEND :: ByteString -> Int -> IO (Maybe C.ASignature, ByteString)
randomSEND sessId len = do
@@ -159,12 +159,12 @@ randomSEND sessId len = do
let s = encodeTransmission (maxVersion supportedSMPServerVRange) sessId (corrId, sId, Cmd SSender $ SEND noMsgFlags msg)
pure (Just $ C.sign rpKey s, s)
randomSENDCmd :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (Either (ProtocolClientError ErrorType) (PCTransmission ErrorType BrokerMsg))
randomSENDCmd :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
randomSENDCmd c len = do
sId <- getRandomBytes 24
(_, rpKey) <- C.generateSignatureKeyPair C.SEd448
msg <- getRandomBytes len
runExceptT $ mkTransmission c (Just rpKey, sId, Cmd SSender $ SEND noMsgFlags msg)
mkTransmission c (Just rpKey, sId, Cmd SSender $ SEND noMsgFlags msg)
lenOk :: ByteString -> Bool
lenOk s = 0 < B.length s && B.length s <= smpBlockSize - 2