deduplicate connections in connect/disconnect responses, log errors in tPut (#593)

* remove TODO for old handshake version (this HELLO is not sent now)

* deduplicate connections in responses and verify server in the list of subscribed queues

* log transport and LargeMsg in tPut (the results it returns are only used in the tests)

* refactor

* refactor
This commit is contained in:
Evgeny Poberezkin
2023-01-06 17:14:49 +00:00
committed by GitHub
parent 61e0c346df
commit acfa65200a
5 changed files with 27 additions and 25 deletions
-1
View File
@@ -1045,7 +1045,6 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
withStore' c $ \db -> do
setSndQueueStatus db sq Confirmed
when (isJust rq_) $ removeConfirmations db connId
-- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification
unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO
AM_CONN_INFO_REPLY -> pure ()
AM_REPLY_ -> pure ()
+10 -13
View File
@@ -98,7 +98,7 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (isRight, partitionEithers)
import Data.Functor (($>))
import Data.List (partition, (\\))
import Data.List (partition)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
@@ -303,12 +303,9 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
qs <- RQ.getDelSrvQueues srv $ activeSubs c
(qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c
mapM_ (`RQ.addQueue` pendingSubs c) qs
cs <- RQ.getConns (activeSubs c)
-- TODO deduplicate conns
let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs
pure (qs, conns)
pure (qs, S.toList conns)
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
serverDown (qs, conns) = whenM (readTVarIO active) $ do
@@ -345,8 +342,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
unless connected . forM_ client_ $ \cl -> do
incClientStat c cl "CONNECT" ""
notifySub "" $ hostEvent CONNECT cl
-- TODO deduplicate okConns
let conns = okConns \\ S.toList cs
let conns = S.toList $ S.fromList okConns `S.difference` cs
unless (null conns) $ notifySub "" $ UP srv conns
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
@@ -647,8 +643,7 @@ temporaryOrHostError = \case
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())])
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do
-- TODO check server is correct
forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
RQ.addQueue rq $ pendingSubs c
case L.nonEmpty qs_ of
@@ -667,9 +662,11 @@ subscribeQueues c srv qs = do
pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs
_ -> pure (Nothing, errs)
where
checkQueue rq@RcvQueue {rcvId, server} = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
checkQueue rq@RcvQueue {rcvId, server}
| server == srv = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
| otherwise = pure $ Left (rq, Left $ INTERNAL "queue server does not match parameter")
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
+4 -4
View File
@@ -40,9 +40,9 @@ getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
where
addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs'
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId)
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty)
where
addQ (removed, qs') rq@RcvQueue {server, rcvId}
| srv == server = (rq : removed, qs')
addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId}
| srv == server = ((rq : remQs, S.insert connId remConns), qs')
| otherwise = (removed, M.insert (server, rcvId) rq qs')
+13 -6
View File
@@ -1160,26 +1160,33 @@ instance Encoding CommandError where
_ -> fail "bad command error type"
-- | Send signed SMP transmission to TCP transport.
tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO (NonEmpty (Either TransportError ()))
tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO [Either TransportError ()]
tPut th trs
| batch th = tPutBatch [] $ L.map tEncode trs
| otherwise = forM trs $ tPutBlock th . tEncode
| otherwise = forM (L.toList trs) $ tPutLog . tEncode
where
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ()))
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO [Either TransportError ()]
tPutBatch rs ts = do
let (n, s, ts_) = encodeBatch 0 "" ts
r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s)
r <- if n == 0 then largeMsg else replicate n <$> tPutLog (lenEncode n `B.cons` s)
let rs' = rs <> r
case ts_ of
Just ts' -> tPutBatch rs' ts'
_ -> pure $ L.fromList rs'
_ -> pure rs'
largeMsg = putStrLn "tPut error: large message" >> pure [Left TELargeMsg]
tPutLog s = do
r <- tPutBlock th s
case r of
Left e -> putStrLn ("tPut error: " <> show e)
_ -> pure ()
pure r
encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString))
encodeBatch n s ts@(t :| ts_)
| n == 255 = (n, s, Just ts)
| otherwise =
let s' = s <> smpEncode (Large t)
n' = n + 1
in if B.length s' > blockSize th - 1
in if B.length s' > blockSize th - 1 -- one byte is reserved for the number of messages in the batch
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
else case L.nonEmpty ts_ of
Just ts' -> encodeBatch n' s' ts'
-1
View File
@@ -256,7 +256,6 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do
send :: Transport c => THandle c -> Client -> IO ()
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ
-- TODO the line below can return Lefts, but we ignore it and do not disconnect the client
void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts
atomically . writeTVar activeAt =<< liftIO getSystemTime
where