From 17b429afe771079efb1f0ab95a2c8b8f030833f1 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 25 Jan 2021 19:06:26 +0000 Subject: [PATCH] handle TCP disconnections (WIP) (#29) * handle TCP disconnections (WIP) * agent: handle SMP server disconnections * agent: notify client about lost subscriptions when SMP server disconnects * comments for testing functions * remove test apps * chore: reorder functions in Transport * add comment Co-authored-by: Efim Poberezkin --- src/Simplex/Messaging/Agent.hs | 14 ++++-- src/Simplex/Messaging/Agent/Client.hs | 65 +++++++++++++++++++++------ src/Simplex/Messaging/Client.hs | 38 ++++++++++++---- src/Simplex/Messaging/Transport.hs | 38 ++++++++-------- tests/AgentTests.hs | 7 +++ 5 files changed, 117 insertions(+), 45 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 839f5e31f..82d229dd9 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -17,6 +17,7 @@ import Control.Monad.Reader import Crypto.Random import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.Text as T import Data.Text.Encoding import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite @@ -46,13 +47,17 @@ runSMPAgent cfg@AgentConfig {tcpPort} = do q <- asks $ tbqSize . config n <- asks clientCounter c <- atomically $ newAgentClient n q - logInfo $ "client " <> showText (clientId c) <> " connected to Agent" + logConnection c True race_ (connectClient h c) (runClient c) + `E.finally` (closeSMPServerClients c >> logConnection c False) connectClient :: MonadUnliftIO m => Handle -> AgentClient -> m () -connectClient h c = do - race_ (send h c) (receive h c) - logInfo $ "client " <> showText (clientId c) <> " disconnected from Agent" +connectClient h c = race_ (send h c) (receive h c) + +logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m () +logConnection c connected = + let event = if connected then "connected to" else "disconnected from" + in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"] runClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () runClient c = race_ (subscriber c) (client c) @@ -183,6 +188,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) = where delete rq = do deleteQueue c rq + removeSubscription c connAlias withStore (`deleteConn` connAlias) respond OK diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index d185ec35d..f95f3d25a 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -12,6 +12,7 @@ module Simplex.Messaging.Agent.Client newAgentClient, AgentMonad, getSMPServerClient, + closeSMPServerClients, newReceiveQueue, subscribeQueue, sendConfirmation, @@ -37,6 +38,8 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Data.Set (Set) +import qualified Data.Set as S import Data.Text.Encoding import Data.Time.Clock import Numeric.Natural @@ -44,7 +47,7 @@ import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client -import Simplex.Messaging.Protocol (QueueId, RecipientId) +import Simplex.Messaging.Protocol (QueueId) import Simplex.Messaging.Server (randomBytes) import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, SenderKey) import UnliftIO.Concurrent @@ -57,7 +60,8 @@ data AgentClient = AgentClient sndQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue SMPServerTransmission, smpClients :: TVar (Map SMPServer SMPClient), - subscribed :: TVar (Map ConnAlias (SMPServer, RecipientId)), + subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)), + subscrConns :: TVar (Map ConnAlias SMPServer), clientId :: Int } @@ -67,32 +71,53 @@ newAgentClient cc qSize = do sndQ <- newTBQueue qSize msgQ <- newTBQueue qSize smpClients <- newTVar M.empty - subscribed <- newTVar M.empty + subscrSrvrs <- newTVar M.empty + subscrConns <- newTVar M.empty clientId <- (+ 1) <$> readTVar cc writeTVar cc clientId - return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscribed, clientId} + return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId} type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient -getSMPServerClient AgentClient {smpClients, msgQ} srv = +getSMPServerClient c@AgentClient {smpClients, msgQ} srv = readTVarIO smpClients >>= maybe newSMPClient return . M.lookup srv where newSMPClient :: m SMPClient newSMPClient = do - c <- connectClient + smp <- connectClient logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv -- TODO how can agent know client lost the connection? - atomically . modifyTVar smpClients $ M.insert srv c - return c + atomically . modifyTVar smpClients $ M.insert srv smp + return smp connectClient :: m SMPClient connectClient = do cfg <- asks $ smpCfg . config - liftIO (getSMPClient srv cfg msgQ) + liftIO (getSMPClient srv cfg msgQ clientDisconnected) `E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection) + clientDisconnected :: IO () + clientDisconnected = do + removeServerSubs >>= mapM_ (mapM_ removeNotifySub) + logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv + + removeServerSubs :: IO (Maybe (Set ConnAlias)) + removeServerSubs = atomically $ do + modifyTVar smpClients $ M.delete srv + ss <- readTVar (subscrSrvrs c) + writeTVar (subscrSrvrs c) $ M.delete srv ss + return $ M.lookup srv ss + + removeNotifySub :: ConnAlias -> IO () + removeNotifySub connAlias = atomically $ do + modifyTVar (subscrConns c) $ M.delete connAlias + writeTBQueue (sndQ c) ("", connAlias, END) + +closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m () +closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient + withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a withSMP c srv action = (getSMPServerClient c srv >>= runAction) `catchError` logServerError @@ -152,12 +177,26 @@ subscribeQueue c rq@ReceiveQueue {server, rcvPrivateKey, rcvId} connAlias = do addSubscription c rq connAlias addSubscription :: MonadUnliftIO m => AgentClient -> ReceiveQueue -> ConnAlias -> m () -addSubscription c ReceiveQueue {server, rcvId} connAlias = - atomically . modifyTVar (subscribed c) $ M.insert connAlias (server, rcvId) +addSubscription c ReceiveQueue {server} connAlias = atomically $ do + modifyTVar (subscrConns c) $ M.insert connAlias server + modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server + where + addSub :: Maybe (Set ConnAlias) -> Set ConnAlias + addSub (Just cs) = S.insert connAlias cs + addSub _ = S.singleton connAlias removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m () -removeSubscription c connAlias = - atomically . modifyTVar (subscribed c) $ M.delete connAlias +removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do + cs <- readTVar subscrConns + writeTVar subscrConns $ M.delete connAlias cs + mapM_ + (modifyTVar subscrSrvrs . M.alter (>>= delSub)) + (M.lookup connAlias cs) + where + delSub :: Set ConnAlias -> Maybe (Set ConnAlias) + delSub cs = + let cs' = S.delete connAlias cs + in if S.null cs' then Nothing else Just cs' logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 7abd54aa2..1d17a825f 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -10,6 +10,7 @@ module Simplex.Messaging.Client ( SMPClient, getSMPClient, + closeSMPClient, createSMPQueue, subscribeSMPQueue, secureSMPQueue, @@ -49,6 +50,7 @@ import System.Timeout data SMPClient = SMPClient { action :: Async (), + connected :: TVar Bool, smpServer :: SMPServer, clientCorrId :: TVar Natural, sentCommands :: TVar (Map CorrId Request), @@ -78,16 +80,20 @@ data Request = Request responseVar :: TMVar (Either SMPClientError Cmd) } -getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO SMPClient +getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient getSMPClient smpServer@SMPServer {host, port} SMPClientConfig {qSize, defaultPort, tcpTimeout} - msgQ = do + msgQ + disconnected = do c <- atomically mkSMPClient started <- newEmptyTMVarIO - action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c started) + action <- + async $ + runTCPClient host (fromMaybe defaultPort port) (client c started) + `finally` atomically (putTMVar started False) tcpTimeout `timeout` atomically (takeTMVar started) >>= \case - Just _ -> return c {action} + Just True -> return c {action} _ -> throwIO err where err :: IOException @@ -95,18 +101,31 @@ getSMPClient mkSMPClient :: STM SMPClient mkSMPClient = do + connected <- newTVar False clientCorrId <- newTVar 0 sentCommands <- newTVar M.empty sndQ <- newTBQueue qSize rcvQ <- newTBQueue qSize - return SMPClient {action = undefined, smpServer, clientCorrId, sentCommands, sndQ, rcvQ, msgQ} + return + SMPClient + { action = undefined, + connected, + smpServer, + clientCorrId, + sentCommands, + sndQ, + rcvQ, + msgQ + } - client :: SMPClient -> TMVar () -> Handle -> IO () + client :: SMPClient -> TMVar Bool -> Handle -> IO () client c started h = do _ <- getLn h -- "Welcome to SMP" - atomically $ putTMVar started () - -- TODO call continuation on disconnection after raceAny_ exits + atomically $ do + modifyTVar (connected c) (const True) + putTMVar started True raceAny_ [send c h, process c, receive c h] + `finally` disconnected send :: SMPClient -> Handle -> IO () send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h @@ -134,6 +153,9 @@ getSMPClient Right r -> Right r else Left SMPQueueIdError +closeSMPClient :: SMPClient -> IO () +closeSMPClient = uninterruptibleCancel . action + data SMPClientError = SMPServerError ErrorType | SMPResponseError ErrorType diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index a235d9dbc..e330aa980 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -21,12 +21,18 @@ import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E import qualified UnliftIO.IO as IO +runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m () +runTCPServer port server = + E.bracket (liftIO $ startTCPServer port) (liftIO . close) $ \sock -> forever $ do + h <- liftIO $ acceptTCPConn sock + forkFinally (server h) (const $ IO.hClose h) + startTCPServer :: ServiceName -> IO Socket startTCPServer port = withSocketsDo $ resolve >>= open where - resolve = do + resolve = let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream} - head <$> getAddrInfo (Just hints) Nothing (Just port) + in head <$> getAddrInfo (Just hints) Nothing (Just port) open addr = do sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) setSocketOption sock ReuseAddr 1 @@ -35,33 +41,30 @@ startTCPServer port = withSocketsDo $ resolve >>= open listen sock 1024 return sock -runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m () -runTCPServer port server = - E.bracket (liftIO $ startTCPServer port) (liftIO . close) $ \sock -> forever $ do - h <- liftIO $ acceptTCPConn sock - forkFinally (server h) (const $ IO.hClose h) - acceptTCPConn :: Socket -> IO Handle -acceptTCPConn sock = do - (conn, _) <- accept sock - getSocketHandle conn +acceptTCPConn sock = accept sock >>= getSocketHandle . fst + +runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a +runTCPClient host port client = do + h <- liftIO $ startTCPClient host port + client h `E.finally` IO.hClose h startTCPClient :: HostName -> ServiceName -> IO Handle startTCPClient host port = withSocketsDo $ - resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return + resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion where err :: IOException err = mkIOError NoSuchThing "no address" Nothing Nothing resolve :: IO [AddrInfo] - resolve = do + resolve = let hints = defaultHints {addrSocketType = Stream} - getAddrInfo (Just hints) (Just host) (Just port) + in getAddrInfo (Just hints) (Just host) (Just port) tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle) - tryOpen h@(Right _) _ = return h tryOpen (Left _) addr = E.try $ open addr + tryOpen h _ = return h open :: AddrInfo -> IO Handle open addr = do @@ -69,11 +72,6 @@ startTCPClient host port = connect sock $ addrAddress addr getSocketHandle sock -runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a -runTCPClient host port client = do - h <- liftIO $ startTCPClient host port - client h `E.finally` IO.hClose h - getSocketHandle :: Socket -> IO Handle getSocketHandle conn = do h <- socketToHandle conn ReadWriteMode diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index d1580b1fb..19b6c010a 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -30,21 +30,28 @@ agentTests = do it "should connect via one server and one agent" $ smpAgentTest3_1 testSubscription +-- | simple test for one command with the expected response (>#>) :: ARawTransmission -> ARawTransmission -> Expectation command >#> response = smpAgentTest command `shouldReturn` response +-- | simple test for one command with a predicate for the expected response (>#>=) :: ARawTransmission -> ((ByteString, ByteString, [ByteString]) -> Bool) -> Expectation command >#>= p = smpAgentTest command >>= (`shouldSatisfy` p . \(cId, cAlias, cmd) -> (cId, cAlias, B.words cmd)) +-- | send transmission `t` to handle `h` and get response (#:) :: Handle -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent) h #: t = tPutRaw h t >> tGet SAgent h +-- | action and expected response +-- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r` (#>) :: IO (ATransmissionOrError 'Agent) -> ATransmission 'Agent -> Expectation action #> (corrId, cAlias, cmd) = action `shouldReturn` (corrId, cAlias, Right cmd) +-- | receive message to handle `h` and validate that it is the expected one (<#) :: Handle -> ATransmission 'Agent -> Expectation h <# (corrId, cAlias, cmd) = tGet SAgent h `shouldReturn` (corrId, cAlias, Right cmd) +-- | receive message to handle `h` and validate it using predicate `p` (<#=) :: Handle -> (ATransmissionOrError 'Agent -> Bool) -> Expectation h <#= p = tGet SAgent h >>= (`shouldSatisfy` p)