mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-23 08:15:10 +00:00
deduplicate connections in connect/disconnect responses, log errors in tPut (#593)
* remove TODO for old handshake version (this HELLO is not sent now) * deduplicate connections in responses and verify server in the list of subscribed queues * log transport and LargeMsg in tPut (the results it returns are only used in the tests) * refactor * refactor
This commit is contained in:
committed by
GitHub
parent
61e0c346df
commit
acfa65200a
@@ -1045,7 +1045,6 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
withStore' c $ \db -> do
|
||||
setSndQueueStatus db sq Confirmed
|
||||
when (isJust rq_) $ removeConfirmations db connId
|
||||
-- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification
|
||||
unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO
|
||||
AM_CONN_INFO_REPLY -> pure ()
|
||||
AM_REPLY_ -> pure ()
|
||||
|
||||
@@ -98,7 +98,7 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (isRight, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (partition, (\\))
|
||||
import Data.List (partition)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
@@ -303,12 +303,9 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
qs <- RQ.getDelSrvQueues srv $ activeSubs c
|
||||
(qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c
|
||||
mapM_ (`RQ.addQueue` pendingSubs c) qs
|
||||
cs <- RQ.getConns (activeSubs c)
|
||||
-- TODO deduplicate conns
|
||||
let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs
|
||||
pure (qs, conns)
|
||||
pure (qs, S.toList conns)
|
||||
|
||||
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
|
||||
serverDown (qs, conns) = whenM (readTVarIO active) $ do
|
||||
@@ -345,8 +342,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
unless connected . forM_ client_ $ \cl -> do
|
||||
incClientStat c cl "CONNECT" ""
|
||||
notifySub "" $ hostEvent CONNECT cl
|
||||
-- TODO deduplicate okConns
|
||||
let conns = okConns \\ S.toList cs
|
||||
let conns = S.toList $ S.fromList okConns `S.difference` cs
|
||||
unless (null conns) $ notifySub "" $ UP srv conns
|
||||
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
|
||||
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
|
||||
@@ -647,8 +643,7 @@ temporaryOrHostError = \case
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())])
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
|
||||
forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do
|
||||
-- TODO check server is correct
|
||||
forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
case L.nonEmpty qs_ of
|
||||
@@ -667,9 +662,11 @@ subscribeQueues c srv qs = do
|
||||
pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs
|
||||
_ -> pure (Nothing, errs)
|
||||
where
|
||||
checkQueue rq@RcvQueue {rcvId, server} = do
|
||||
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
|
||||
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
|
||||
checkQueue rq@RcvQueue {rcvId, server}
|
||||
| server == srv = do
|
||||
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
|
||||
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
|
||||
| otherwise = pure $ Left (rq, Left $ INTERNAL "queue server does not match parameter")
|
||||
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
|
||||
|
||||
@@ -40,9 +40,9 @@ getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
|
||||
where
|
||||
addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs'
|
||||
|
||||
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
|
||||
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
|
||||
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId)
|
||||
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty)
|
||||
where
|
||||
addQ (removed, qs') rq@RcvQueue {server, rcvId}
|
||||
| srv == server = (rq : removed, qs')
|
||||
addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId}
|
||||
| srv == server = ((rq : remQs, S.insert connId remConns), qs')
|
||||
| otherwise = (removed, M.insert (server, rcvId) rq qs')
|
||||
|
||||
@@ -1160,26 +1160,33 @@ instance Encoding CommandError where
|
||||
_ -> fail "bad command error type"
|
||||
|
||||
-- | Send signed SMP transmission to TCP transport.
|
||||
tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO (NonEmpty (Either TransportError ()))
|
||||
tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO [Either TransportError ()]
|
||||
tPut th trs
|
||||
| batch th = tPutBatch [] $ L.map tEncode trs
|
||||
| otherwise = forM trs $ tPutBlock th . tEncode
|
||||
| otherwise = forM (L.toList trs) $ tPutLog . tEncode
|
||||
where
|
||||
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ()))
|
||||
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO [Either TransportError ()]
|
||||
tPutBatch rs ts = do
|
||||
let (n, s, ts_) = encodeBatch 0 "" ts
|
||||
r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s)
|
||||
r <- if n == 0 then largeMsg else replicate n <$> tPutLog (lenEncode n `B.cons` s)
|
||||
let rs' = rs <> r
|
||||
case ts_ of
|
||||
Just ts' -> tPutBatch rs' ts'
|
||||
_ -> pure $ L.fromList rs'
|
||||
_ -> pure rs'
|
||||
largeMsg = putStrLn "tPut error: large message" >> pure [Left TELargeMsg]
|
||||
tPutLog s = do
|
||||
r <- tPutBlock th s
|
||||
case r of
|
||||
Left e -> putStrLn ("tPut error: " <> show e)
|
||||
_ -> pure ()
|
||||
pure r
|
||||
encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString))
|
||||
encodeBatch n s ts@(t :| ts_)
|
||||
| n == 255 = (n, s, Just ts)
|
||||
| otherwise =
|
||||
let s' = s <> smpEncode (Large t)
|
||||
n' = n + 1
|
||||
in if B.length s' > blockSize th - 1
|
||||
in if B.length s' > blockSize th - 1 -- one byte is reserved for the number of messages in the batch
|
||||
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
|
||||
else case L.nonEmpty ts_ of
|
||||
Just ts' -> encodeBatch n' s' ts'
|
||||
|
||||
@@ -256,7 +256,6 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do
|
||||
send :: Transport c => THandle c -> Client -> IO ()
|
||||
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
|
||||
ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ
|
||||
-- TODO the line below can return Lefts, but we ignore it and do not disconnect the client
|
||||
void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
where
|
||||
|
||||
Reference in New Issue
Block a user