handle TCP disconnections (WIP) (#29)

* handle TCP disconnections (WIP)

* agent: handle SMP server disconnections

* agent: notify client about lost subscriptions when SMP server disconnects

* comments for testing functions

* remove test apps

* chore: reorder functions in Transport

* add comment

Co-authored-by: Efim Poberezkin <efim.poberezkin@gmail.com>
This commit is contained in:
Evgeny Poberezkin
2021-01-25 19:06:26 +00:00
committed by Efim Poberezkin
parent e09d3bae99
commit 17b429afe7
5 changed files with 117 additions and 45 deletions
+10 -4
View File
@@ -17,6 +17,7 @@ import Control.Monad.Reader
import Crypto.Random
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
import Data.Text.Encoding
import Simplex.Messaging.Agent.Client
import Simplex.Messaging.Agent.Env.SQLite
@@ -46,13 +47,17 @@ runSMPAgent cfg@AgentConfig {tcpPort} = do
q <- asks $ tbqSize . config
n <- asks clientCounter
c <- atomically $ newAgentClient n q
logInfo $ "client " <> showText (clientId c) <> " connected to Agent"
logConnection c True
race_ (connectClient h c) (runClient c)
`E.finally` (closeSMPServerClients c >> logConnection c False)
connectClient :: MonadUnliftIO m => Handle -> AgentClient -> m ()
connectClient h c = do
race_ (send h c) (receive h c)
logInfo $ "client " <> showText (clientId c) <> " disconnected from Agent"
connectClient h c = race_ (send h c) (receive h c)
logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m ()
logConnection c connected =
let event = if connected then "connected to" else "disconnected from"
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
runClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
runClient c = race_ (subscriber c) (client c)
@@ -183,6 +188,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) =
where
delete rq = do
deleteQueue c rq
removeSubscription c connAlias
withStore (`deleteConn` connAlias)
respond OK
+52 -13
View File
@@ -12,6 +12,7 @@ module Simplex.Messaging.Agent.Client
newAgentClient,
AgentMonad,
getSMPServerClient,
closeSMPServerClients,
newReceiveQueue,
subscribeQueue,
sendConfirmation,
@@ -37,6 +38,8 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text.Encoding
import Data.Time.Clock
import Numeric.Natural
@@ -44,7 +47,7 @@ import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Client
import Simplex.Messaging.Protocol (QueueId, RecipientId)
import Simplex.Messaging.Protocol (QueueId)
import Simplex.Messaging.Server (randomBytes)
import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, SenderKey)
import UnliftIO.Concurrent
@@ -57,7 +60,8 @@ data AgentClient = AgentClient
sndQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue SMPServerTransmission,
smpClients :: TVar (Map SMPServer SMPClient),
subscribed :: TVar (Map ConnAlias (SMPServer, RecipientId)),
subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)),
subscrConns :: TVar (Map ConnAlias SMPServer),
clientId :: Int
}
@@ -67,32 +71,53 @@ newAgentClient cc qSize = do
sndQ <- newTBQueue qSize
msgQ <- newTBQueue qSize
smpClients <- newTVar M.empty
subscribed <- newTVar M.empty
subscrSrvrs <- newTVar M.empty
subscrConns <- newTVar M.empty
clientId <- (+ 1) <$> readTVar cc
writeTVar cc clientId
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscribed, clientId}
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId}
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient AgentClient {smpClients, msgQ} srv =
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
readTVarIO smpClients
>>= maybe newSMPClient return . M.lookup srv
where
newSMPClient :: m SMPClient
newSMPClient = do
c <- connectClient
smp <- connectClient
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
-- TODO how can agent know client lost the connection?
atomically . modifyTVar smpClients $ M.insert srv c
return c
atomically . modifyTVar smpClients $ M.insert srv smp
return smp
connectClient :: m SMPClient
connectClient = do
cfg <- asks $ smpCfg . config
liftIO (getSMPClient srv cfg msgQ)
liftIO (getSMPClient srv cfg msgQ clientDisconnected)
`E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection)
clientDisconnected :: IO ()
clientDisconnected = do
removeServerSubs >>= mapM_ (mapM_ removeNotifySub)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeServerSubs :: IO (Maybe (Set ConnAlias))
removeServerSubs = atomically $ do
modifyTVar smpClients $ M.delete srv
ss <- readTVar (subscrSrvrs c)
writeTVar (subscrSrvrs c) $ M.delete srv ss
return $ M.lookup srv ss
removeNotifySub :: ConnAlias -> IO ()
removeNotifySub connAlias = atomically $ do
modifyTVar (subscrConns c) $ M.delete connAlias
writeTBQueue (sndQ c) ("", connAlias, END)
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient
withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMP c srv action =
(getSMPServerClient c srv >>= runAction) `catchError` logServerError
@@ -152,12 +177,26 @@ subscribeQueue c rq@ReceiveQueue {server, rcvPrivateKey, rcvId} connAlias = do
addSubscription c rq connAlias
addSubscription :: MonadUnliftIO m => AgentClient -> ReceiveQueue -> ConnAlias -> m ()
addSubscription c ReceiveQueue {server, rcvId} connAlias =
atomically . modifyTVar (subscribed c) $ M.insert connAlias (server, rcvId)
addSubscription c ReceiveQueue {server} connAlias = atomically $ do
modifyTVar (subscrConns c) $ M.insert connAlias server
modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server
where
addSub :: Maybe (Set ConnAlias) -> Set ConnAlias
addSub (Just cs) = S.insert connAlias cs
addSub _ = S.singleton connAlias
removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m ()
removeSubscription c connAlias =
atomically . modifyTVar (subscribed c) $ M.delete connAlias
removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do
cs <- readTVar subscrConns
writeTVar subscrConns $ M.delete connAlias cs
mapM_
(modifyTVar subscrSrvrs . M.alter (>>= delSub))
(M.lookup connAlias cs)
where
delSub :: Set ConnAlias -> Maybe (Set ConnAlias)
delSub cs =
let cs' = S.delete connAlias cs
in if S.null cs' then Nothing else Just cs'
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
logServer dir AgentClient {clientId} srv qId cmdStr =
+30 -8
View File
@@ -10,6 +10,7 @@
module Simplex.Messaging.Client
( SMPClient,
getSMPClient,
closeSMPClient,
createSMPQueue,
subscribeSMPQueue,
secureSMPQueue,
@@ -49,6 +50,7 @@ import System.Timeout
data SMPClient = SMPClient
{ action :: Async (),
connected :: TVar Bool,
smpServer :: SMPServer,
clientCorrId :: TVar Natural,
sentCommands :: TVar (Map CorrId Request),
@@ -78,16 +80,20 @@ data Request = Request
responseVar :: TMVar (Either SMPClientError Cmd)
}
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO SMPClient
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient
getSMPClient
smpServer@SMPServer {host, port}
SMPClientConfig {qSize, defaultPort, tcpTimeout}
msgQ = do
msgQ
disconnected = do
c <- atomically mkSMPClient
started <- newEmptyTMVarIO
action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c started)
action <-
async $
runTCPClient host (fromMaybe defaultPort port) (client c started)
`finally` atomically (putTMVar started False)
tcpTimeout `timeout` atomically (takeTMVar started) >>= \case
Just _ -> return c {action}
Just True -> return c {action}
_ -> throwIO err
where
err :: IOException
@@ -95,18 +101,31 @@ getSMPClient
mkSMPClient :: STM SMPClient
mkSMPClient = do
connected <- newTVar False
clientCorrId <- newTVar 0
sentCommands <- newTVar M.empty
sndQ <- newTBQueue qSize
rcvQ <- newTBQueue qSize
return SMPClient {action = undefined, smpServer, clientCorrId, sentCommands, sndQ, rcvQ, msgQ}
return
SMPClient
{ action = undefined,
connected,
smpServer,
clientCorrId,
sentCommands,
sndQ,
rcvQ,
msgQ
}
client :: SMPClient -> TMVar () -> Handle -> IO ()
client :: SMPClient -> TMVar Bool -> Handle -> IO ()
client c started h = do
_ <- getLn h -- "Welcome to SMP"
atomically $ putTMVar started ()
-- TODO call continuation on disconnection after raceAny_ exits
atomically $ do
modifyTVar (connected c) (const True)
putTMVar started True
raceAny_ [send c h, process c, receive c h]
`finally` disconnected
send :: SMPClient -> Handle -> IO ()
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
@@ -134,6 +153,9 @@ getSMPClient
Right r -> Right r
else Left SMPQueueIdError
closeSMPClient :: SMPClient -> IO ()
closeSMPClient = uninterruptibleCancel . action
data SMPClientError
= SMPServerError ErrorType
| SMPResponseError ErrorType
+18 -20
View File
@@ -21,12 +21,18 @@ import UnliftIO.Exception (Exception, IOException)
import qualified UnliftIO.Exception as E
import qualified UnliftIO.IO as IO
runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m ()
runTCPServer port server =
E.bracket (liftIO $ startTCPServer port) (liftIO . close) $ \sock -> forever $ do
h <- liftIO $ acceptTCPConn sock
forkFinally (server h) (const $ IO.hClose h)
startTCPServer :: ServiceName -> IO Socket
startTCPServer port = withSocketsDo $ resolve >>= open
where
resolve = do
resolve =
let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream}
head <$> getAddrInfo (Just hints) Nothing (Just port)
in head <$> getAddrInfo (Just hints) Nothing (Just port)
open addr = do
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
setSocketOption sock ReuseAddr 1
@@ -35,33 +41,30 @@ startTCPServer port = withSocketsDo $ resolve >>= open
listen sock 1024
return sock
runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m ()
runTCPServer port server =
E.bracket (liftIO $ startTCPServer port) (liftIO . close) $ \sock -> forever $ do
h <- liftIO $ acceptTCPConn sock
forkFinally (server h) (const $ IO.hClose h)
acceptTCPConn :: Socket -> IO Handle
acceptTCPConn sock = do
(conn, _) <- accept sock
getSocketHandle conn
acceptTCPConn sock = accept sock >>= getSocketHandle . fst
runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a
runTCPClient host port client = do
h <- liftIO $ startTCPClient host port
client h `E.finally` IO.hClose h
startTCPClient :: HostName -> ServiceName -> IO Handle
startTCPClient host port =
withSocketsDo $
resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return
resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion
where
err :: IOException
err = mkIOError NoSuchThing "no address" Nothing Nothing
resolve :: IO [AddrInfo]
resolve = do
resolve =
let hints = defaultHints {addrSocketType = Stream}
getAddrInfo (Just hints) (Just host) (Just port)
in getAddrInfo (Just hints) (Just host) (Just port)
tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle)
tryOpen h@(Right _) _ = return h
tryOpen (Left _) addr = E.try $ open addr
tryOpen h _ = return h
open :: AddrInfo -> IO Handle
open addr = do
@@ -69,11 +72,6 @@ startTCPClient host port =
connect sock $ addrAddress addr
getSocketHandle sock
runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a
runTCPClient host port client = do
h <- liftIO $ startTCPClient host port
client h `E.finally` IO.hClose h
getSocketHandle :: Socket -> IO Handle
getSocketHandle conn = do
h <- socketToHandle conn ReadWriteMode
+7
View File
@@ -30,21 +30,28 @@ agentTests = do
it "should connect via one server and one agent" $
smpAgentTest3_1 testSubscription
-- | simple test for one command with the expected response
(>#>) :: ARawTransmission -> ARawTransmission -> Expectation
command >#> response = smpAgentTest command `shouldReturn` response
-- | simple test for one command with a predicate for the expected response
(>#>=) :: ARawTransmission -> ((ByteString, ByteString, [ByteString]) -> Bool) -> Expectation
command >#>= p = smpAgentTest command >>= (`shouldSatisfy` p . \(cId, cAlias, cmd) -> (cId, cAlias, B.words cmd))
-- | send transmission `t` to handle `h` and get response
(#:) :: Handle -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent)
h #: t = tPutRaw h t >> tGet SAgent h
-- | action and expected response
-- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r`
(#>) :: IO (ATransmissionOrError 'Agent) -> ATransmission 'Agent -> Expectation
action #> (corrId, cAlias, cmd) = action `shouldReturn` (corrId, cAlias, Right cmd)
-- | receive message to handle `h` and validate that it is the expected one
(<#) :: Handle -> ATransmission 'Agent -> Expectation
h <# (corrId, cAlias, cmd) = tGet SAgent h `shouldReturn` (corrId, cAlias, Right cmd)
-- | receive message to handle `h` and validate it using predicate `p`
(<#=) :: Handle -> (ATransmissionOrError 'Agent -> Bool) -> Expectation
h <#= p = tGet SAgent h >>= (`shouldSatisfy` p)