smp server: possible race when creating client that might lead to memory leak (#1260)

This commit is contained in:
Evgeny
2024-08-20 12:25:58 +01:00
committed by GitHub
parent 1cbf8c0015
commit ac930dff30
2 changed files with 23 additions and 18 deletions
+20 -14
View File
@@ -175,7 +175,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
$>>= endPreviousSubscriptions
>>= liftIO . mapM_ unsub
where
updateSubscribers :: TVar (IM.IntMap Client) -> STM (Maybe (QueueId, Client))
updateSubscribers :: TVar (IM.IntMap (Maybe Client)) -> STM (Maybe (QueueId, Client))
updateSubscribers cls = do
(qId, clnt, subscribed) <- readTQueue $ subQ s
current <- IM.member (clientId clnt) <$> readTVar cls
@@ -412,7 +412,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
CPClients -> withAdminRole $ do
active <- unliftIO u (asks clients) >>= readTVarIO
hPutStrLn h "clientId,sessionId,connected,createdAt,rcvActiveAt,sndActiveAt,age,subscriptions"
forM_ (IM.toList active) $ \(cid, Client {sessionId, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions}) -> do
forM_ (IM.toList active) $ \(cid, cl) -> forM_ cl $ \Client {sessionId, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions} -> do
connected' <- bshow <$> readTVarIO connected
rcvActiveAt' <- strEncode <$> readTVarIO rcvActiveAt
sndActiveAt' <- strEncode <$> readTVarIO sndActiveAt
@@ -507,7 +507,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
activeClients <- readTVarIO clients
hPutStrLn h $ "Clients: " <> show (IM.size activeClients)
when (r == CPRAdmin) $ do
clQs <- clientTBQueueLengths activeClients
clQs <- clientTBQueueLengths' activeClients
hPutStrLn h $ "Client queues (rcvQ, sndQ, msgQ): " <> show clQs
(smpSubCnt, smpSubCntByGroup, smpClCnt, smpClQs) <- countClientSubs subscriptions (Just countSMPSubs) activeClients
hPutStrLn h $ "SMP subscriptions (via clients): " <> show smpSubCnt
@@ -542,11 +542,12 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
| otherwise = (cl : cls, IS.insert clientId clSet)
countSubClients :: M.Map QueueId Client -> Int
countSubClients = IS.size . M.foldr' (IS.insert . clientId) IS.empty
countClientSubs :: (Client -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
countClientSubs :: (Client -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap (Maybe Client) -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
countClientSubs subSel countSubs_ = foldM addSubs (0, (0, 0, 0, 0), 0, (0, 0, 0))
where
addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) cl = do
addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Maybe Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
addSubs acc Nothing = pure acc
addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) (Just cl) = do
subs <- readTVarIO $ subSel cl
cnts' <- case countSubs_ of
Nothing -> pure cnts
@@ -559,6 +560,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
pure (subCnt + cnt, cnts', clCnt', qs')
clientTBQueueLengths :: Foldable t => t Client -> IO (Natural, Natural, Natural)
clientTBQueueLengths = foldM addQueueLengths (0, 0, 0)
clientTBQueueLengths' :: Foldable t => t (Maybe Client) -> IO (Natural, Natural, Natural)
clientTBQueueLengths' = foldM (\acc -> maybe (pure acc) (addQueueLengths acc)) (0, 0, 0)
addQueueLengths (!rl, !sl, !ml) cl = do
(rl', sl', ml') <- queueLengths cl
pure (rl + rl', sl + sl', ml + ml')
@@ -619,15 +622,18 @@ runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessio
ts <- liftIO getSystemTime
active <- asks clients
nextClientId <- asks clientSeq
c@Client {clientId} <- liftIO $ newClient nextClientId q thVersion sessionId ts
atomically $ modifyTVar' active $ IM.insert clientId c
s <- asks server
expCfg <- asks $ inactiveClientExpiration . config
th <- newMVar h -- put TH under a fair lock to interleave messages and command responses
labelMyThread . B.unpack $ "client $" <> encode sessionId
raceAny_ ([liftIO $ send th c, liftIO $ sendMsg th c, client thParams c s, receive h c] <> disconnectThread_ c expCfg)
`finally` clientDisconnected c
clientId <- atomically $ stateTVar nextClientId $ \next -> (next, next + 1)
atomically $ modifyTVar' active $ IM.insert clientId Nothing
c <- liftIO $ newClient clientId q thVersion sessionId ts
runClientThreads active c clientId `finally` clientDisconnected c
where
runClientThreads active c clientId = do
atomically $ modifyTVar' active $ IM.insert clientId $ Just c
s <- asks server
expCfg <- asks $ inactiveClientExpiration . config
th <- newMVar h -- put TH under a fair lock to interleave messages and command responses
labelMyThread . B.unpack $ "client $" <> encode sessionId
raceAny_ $ [liftIO $ send th c, liftIO $ sendMsg th c, client thParams c s, receive h c] <> disconnectThread_ c expCfg
disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport h (rcvActiveAt c) (sndActiveAt c) expCfg (noSubscriptions c)]
disconnectThread_ _ _ = []
noSubscriptions c = atomically $ (&&) <$> TM.null (ntfSubscriptions c) <*> (not . hasSubs <$> readTVar (subscriptions c))
+3 -4
View File
@@ -127,7 +127,7 @@ data Env = Env
serverStats :: ServerStats,
sockets :: SocketState,
clientSeq :: TVar ClientId,
clients :: TVar (IntMap Client),
clients :: TVar (IntMap (Maybe Client)),
proxyAgent :: ProxyAgent -- senders served on this proxy
}
@@ -183,9 +183,8 @@ newServer = do
savingLock <- atomically createLock
return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock}
newClient :: TVar ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> IO Client
newClient nextClientId qSize thVersion sessionId createdAt = do
clientId <- atomically $ stateTVar nextClientId $ \next -> (next, next + 1)
newClient :: ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> IO Client
newClient clientId qSize thVersion sessionId createdAt = do
subscriptions <- TM.emptyIO
ntfSubscriptions <- TM.emptyIO
rcvQ <- newTBQueueIO qSize