From 90a5da41d57bc2692db4b7633ab459beb38aa221 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 25 Jan 2021 20:34:36 +0000 Subject: [PATCH] kill TCP server client threads when the main server thread is killed; test END notification when server connection dies --- src/Simplex/Messaging/Transport.hs | 14 ++++++++++--- tests/AgentTests.hs | 32 ++++++++++++++++++++++++------ tests/SMPAgentClient.hs | 21 +++++++++++++++----- tests/SMPClient.hs | 10 ++++++---- 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index e330aa980..d0e2aa0e9 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -12,6 +12,8 @@ import Control.Monad.IO.Unlift import Control.Monad.Reader import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Set (Set) +import qualified Data.Set as S import GHC.IO.Exception (IOErrorType (..)) import Network.Socket import System.IO @@ -20,12 +22,18 @@ import UnliftIO.Concurrent import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E import qualified UnliftIO.IO as IO +import UnliftIO.STM runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m () -runTCPServer port server = - E.bracket (liftIO $ startTCPServer port) (liftIO . close) $ \sock -> forever $ do +runTCPServer port server = do + clients <- newTVarIO S.empty + E.bracket (liftIO $ startTCPServer port) (liftIO . closeServer clients) $ \sock -> forever $ do h <- liftIO $ acceptTCPConn sock - forkFinally (server h) (const $ IO.hClose h) + tid <- forkFinally (server h) (const $ IO.hClose h) + atomically . modifyTVar clients $ S.insert tid + where + closeServer :: TVar (Set ThreadId) -> Socket -> IO () + closeServer clients sock = readTVarIO clients >>= mapM_ killThread >> close sock startTCPServer :: ServiceName -> IO Socket startTCPServer port = withSocketsDo $ resolve >>= open diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 19b6c010a..0e03643f1 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -8,6 +8,7 @@ module AgentTests where import AgentTests.SQLiteTests (storeTests) +import Control.Concurrent import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPAgentClient @@ -29,6 +30,8 @@ agentTests = do describe "Connection subscriptions" do it "should connect via one server and one agent" $ smpAgentTest3_1 testSubscription + it "should send notifications to client when server disconnects" $ + smpAgentServerTest testSubscrNotification -- | simple test for one command with the expected response (>#>) :: ARawTransmission -> ARawTransmission -> Expectation @@ -47,6 +50,11 @@ h #: t = tPutRaw h t >> tGet SAgent h (#>) :: IO (ATransmissionOrError 'Agent) -> ATransmission 'Agent -> Expectation action #> (corrId, cAlias, cmd) = action `shouldReturn` (corrId, cAlias, Right cmd) +-- | action and predicate for the response +-- `h #:t =#> p` is the test that sends `t` to `h` and validates the response using `p` +(=#>) :: IO (ATransmissionOrError 'Agent) -> (ATransmissionOrError 'Agent -> Bool) -> Expectation +action =#> p = action >>= (`shouldSatisfy` p) + -- | receive message to handle `h` and validate that it is the expected one (<#) :: Handle -> ATransmission 'Agent -> Expectation h <# (corrId, cAlias, cmd) = tGet SAgent h `shouldReturn` (corrId, cAlias, Right cmd) @@ -55,6 +63,15 @@ h <# (corrId, cAlias, cmd) = tGet SAgent h `shouldReturn` (corrId, cAlias, Right (<#=) :: Handle -> (ATransmissionOrError 'Agent -> Bool) -> Expectation h <#= p = tGet SAgent h >>= (`shouldSatisfy` p) +-- | test that nothing is delivered to handle `h` during 10ms +(#:#) :: Handle -> String -> Expectation +h #:# err = tryGet `shouldReturn` () + where + tryGet = + 10000 `timeout` tGet SAgent h >>= \case + Just _ -> error err + _ -> return () + pattern Msg :: MsgBody -> Either AgentErrorType (ACommand 'Agent) pattern Msg msg <- Right (MSG _ _ _ _ msg) @@ -79,9 +96,7 @@ testDuplexConnection alice bob = do alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", ERR (SMP AUTH)) alice #: ("6", "bob", "DEL") #> ("6", "bob", OK) - 10000 `timeout` tGet SAgent alice >>= \case - Nothing -> return () - Just _ -> error "nothing else should be delivered to alice" + alice #:# "nothing else should be delivered to alice" testSubscription :: Handle -> Handle -> Handle -> IO () testSubscription alice1 alice2 bob = do @@ -100,9 +115,14 @@ testSubscription alice1 alice2 bob = do alice2 #: ("22", "bob", "ACK 0") #> ("22", "bob", OK) bob #: ("14", "alice", "SEND 2\nhi") #> ("14", "alice", OK) alice2 <#= \case ("", "bob", Msg "hi") -> True; _ -> False - 10000 `timeout` tGet SAgent alice1 >>= \case - Nothing -> return () - Just _ -> error "nothing else should be delivered to alice" + alice1 #:# "nothing else should be delivered to alice1" + +testSubscrNotification :: (ThreadId, ThreadId) -> Handle -> IO () +testSubscrNotification (server, _) client = do + client #: ("1", "conn1", "NEW localhost:5000") =#> \case ("1", "conn1", Right (INV _)) -> True; _ -> False + client #:# "nothing should be delivered to client before the server is killed" + killThread server + client <# ("", "conn1", END) syntaxTests :: Spec syntaxTests = do diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 15eb48dc6..ee67ef8f9 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -9,7 +9,7 @@ import Control.Monad import Control.Monad.IO.Unlift import Crypto.Random import Network.Socket -import SMPClient (testPort, withSmpServer) +import SMPClient (testPort, withSmpServer, withSmpServerThreadOn) import Simplex.Messaging.Agent import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Transmission @@ -48,6 +48,15 @@ smpAgentTest cmd = runSmpAgentTest $ \h -> tPutRaw h cmd >> tGetRaw h runSmpAgentTest :: (MonadUnliftIO m, MonadRandom m) => (Handle -> m a) -> m a runSmpAgentTest test = withSmpServer . withSmpAgent $ testSMPAgentClient test +runSmpAgentServerTest :: (MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> Handle -> m a) -> m a +runSmpAgentServerTest test = + withSmpServerThreadOn testPort $ + \server -> withSmpAgentThreadOn (agentTestPort, testDB) $ + \agent -> testSMPAgentClient $ test (server, agent) + +smpAgentServerTest :: ((ThreadId, ThreadId) -> Handle -> IO ()) -> Expectation +smpAgentServerTest test' = runSmpAgentServerTest test' `shouldReturn` () + runSmpAgentTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => [(ServiceName, String)] -> ([Handle] -> m a) -> m a runSmpAgentTestN agents test = withSmpServer $ run agents [] where @@ -111,12 +120,14 @@ cfg = } } -withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> m a -> m a -withSmpAgentOn (port', db') = +withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> (ThreadId -> m a) -> m a +withSmpAgentThreadOn (port', db') = E.bracket - (forkIO $ runSMPAgent cfg {tcpPort = port', dbFile = db'}) + (forkIOWithUnmask ($ runSMPAgent cfg {tcpPort = port', dbFile = db'})) (liftIO . killThread >=> const (removeFile db')) - . const + +withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> m a -> m a +withSmpAgentOn (port', db') = withSmpAgentThreadOn (port', db') . const withSmpAgent :: (MonadUnliftIO m, MonadRandom m) => m a -> m a withSmpAgent = withSmpAgentOn (agentTestPort, testDB) diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index d28ec2030..6ca87db6d 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -40,12 +40,14 @@ cfg = msgIdBytes = 6 } -withSmpServerOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> m a -> m a -withSmpServerOn port = +withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a +withSmpServerThreadOn port = E.bracket - (forkIO $ runSMPServer cfg {tcpPort = port}) + (forkIOWithUnmask ($ runSMPServer cfg {tcpPort = port})) (liftIO . killThread) - . const + +withSmpServerOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> m a -> m a +withSmpServerOn port = withSmpServerThreadOn port . const withSmpServer :: (MonadUnliftIO m, MonadRandom m) => m a -> m a withSmpServer = withSmpServerOn testPort