mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-24 19:05:24 +00:00
handle TCP disconnections (WIP) (#29)
* handle TCP disconnections (WIP) * agent: handle SMP server disconnections * agent: notify client about lost subscriptions when SMP server disconnects * comments for testing functions * remove test apps * chore: reorder functions in Transport * add comment Co-authored-by: Efim Poberezkin <efim.poberezkin@gmail.com>
This commit is contained in:
committed by
Efim Poberezkin
parent
e09d3bae99
commit
17b429afe7
@@ -17,6 +17,7 @@ import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -46,13 +47,17 @@ runSMPAgent cfg@AgentConfig {tcpPort} = do
|
||||
q <- asks $ tbqSize . config
|
||||
n <- asks clientCounter
|
||||
c <- atomically $ newAgentClient n q
|
||||
logInfo $ "client " <> showText (clientId c) <> " connected to Agent"
|
||||
logConnection c True
|
||||
race_ (connectClient h c) (runClient c)
|
||||
`E.finally` (closeSMPServerClients c >> logConnection c False)
|
||||
|
||||
connectClient :: MonadUnliftIO m => Handle -> AgentClient -> m ()
|
||||
connectClient h c = do
|
||||
race_ (send h c) (receive h c)
|
||||
logInfo $ "client " <> showText (clientId c) <> " disconnected from Agent"
|
||||
connectClient h c = race_ (send h c) (receive h c)
|
||||
|
||||
logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m ()
|
||||
logConnection c connected =
|
||||
let event = if connected then "connected to" else "disconnected from"
|
||||
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
|
||||
|
||||
runClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runClient c = race_ (subscriber c) (client c)
|
||||
@@ -183,6 +188,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) =
|
||||
where
|
||||
delete rq = do
|
||||
deleteQueue c rq
|
||||
removeSubscription c connAlias
|
||||
withStore (`deleteConn` connAlias)
|
||||
respond OK
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ module Simplex.Messaging.Agent.Client
|
||||
newAgentClient,
|
||||
AgentMonad,
|
||||
getSMPServerClient,
|
||||
closeSMPServerClients,
|
||||
newReceiveQueue,
|
||||
subscribeQueue,
|
||||
sendConfirmation,
|
||||
@@ -37,6 +38,8 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text.Encoding
|
||||
import Data.Time.Clock
|
||||
import Numeric.Natural
|
||||
@@ -44,7 +47,7 @@ import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Protocol (QueueId, RecipientId)
|
||||
import Simplex.Messaging.Protocol (QueueId)
|
||||
import Simplex.Messaging.Server (randomBytes)
|
||||
import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, SenderKey)
|
||||
import UnliftIO.Concurrent
|
||||
@@ -57,7 +60,8 @@ data AgentClient = AgentClient
|
||||
sndQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
smpClients :: TVar (Map SMPServer SMPClient),
|
||||
subscribed :: TVar (Map ConnAlias (SMPServer, RecipientId)),
|
||||
subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)),
|
||||
subscrConns :: TVar (Map ConnAlias SMPServer),
|
||||
clientId :: Int
|
||||
}
|
||||
|
||||
@@ -67,32 +71,53 @@ newAgentClient cc qSize = do
|
||||
sndQ <- newTBQueue qSize
|
||||
msgQ <- newTBQueue qSize
|
||||
smpClients <- newTVar M.empty
|
||||
subscribed <- newTVar M.empty
|
||||
subscrSrvrs <- newTVar M.empty
|
||||
subscrConns <- newTVar M.empty
|
||||
clientId <- (+ 1) <$> readTVar cc
|
||||
writeTVar cc clientId
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscribed, clientId}
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId}
|
||||
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient AgentClient {smpClients, msgQ} srv =
|
||||
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
readTVarIO smpClients
|
||||
>>= maybe newSMPClient return . M.lookup srv
|
||||
where
|
||||
newSMPClient :: m SMPClient
|
||||
newSMPClient = do
|
||||
c <- connectClient
|
||||
smp <- connectClient
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
-- TODO how can agent know client lost the connection?
|
||||
atomically . modifyTVar smpClients $ M.insert srv c
|
||||
return c
|
||||
atomically . modifyTVar smpClients $ M.insert srv smp
|
||||
return smp
|
||||
|
||||
connectClient :: m SMPClient
|
||||
connectClient = do
|
||||
cfg <- asks $ smpCfg . config
|
||||
liftIO (getSMPClient srv cfg msgQ)
|
||||
liftIO (getSMPClient srv cfg msgQ clientDisconnected)
|
||||
`E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection)
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
removeServerSubs >>= mapM_ (mapM_ removeNotifySub)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeServerSubs :: IO (Maybe (Set ConnAlias))
|
||||
removeServerSubs = atomically $ do
|
||||
modifyTVar smpClients $ M.delete srv
|
||||
ss <- readTVar (subscrSrvrs c)
|
||||
writeTVar (subscrSrvrs c) $ M.delete srv ss
|
||||
return $ M.lookup srv ss
|
||||
|
||||
removeNotifySub :: ConnAlias -> IO ()
|
||||
removeNotifySub connAlias = atomically $ do
|
||||
modifyTVar (subscrConns c) $ M.delete connAlias
|
||||
writeTBQueue (sndQ c) ("", connAlias, END)
|
||||
|
||||
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
|
||||
closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient
|
||||
|
||||
withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMP c srv action =
|
||||
(getSMPServerClient c srv >>= runAction) `catchError` logServerError
|
||||
@@ -152,12 +177,26 @@ subscribeQueue c rq@ReceiveQueue {server, rcvPrivateKey, rcvId} connAlias = do
|
||||
addSubscription c rq connAlias
|
||||
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> ReceiveQueue -> ConnAlias -> m ()
|
||||
addSubscription c ReceiveQueue {server, rcvId} connAlias =
|
||||
atomically . modifyTVar (subscribed c) $ M.insert connAlias (server, rcvId)
|
||||
addSubscription c ReceiveQueue {server} connAlias = atomically $ do
|
||||
modifyTVar (subscrConns c) $ M.insert connAlias server
|
||||
modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server
|
||||
where
|
||||
addSub :: Maybe (Set ConnAlias) -> Set ConnAlias
|
||||
addSub (Just cs) = S.insert connAlias cs
|
||||
addSub _ = S.singleton connAlias
|
||||
|
||||
removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m ()
|
||||
removeSubscription c connAlias =
|
||||
atomically . modifyTVar (subscribed c) $ M.delete connAlias
|
||||
removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do
|
||||
cs <- readTVar subscrConns
|
||||
writeTVar subscrConns $ M.delete connAlias cs
|
||||
mapM_
|
||||
(modifyTVar subscrSrvrs . M.alter (>>= delSub))
|
||||
(M.lookup connAlias cs)
|
||||
where
|
||||
delSub :: Set ConnAlias -> Maybe (Set ConnAlias)
|
||||
delSub cs =
|
||||
let cs' = S.delete connAlias cs
|
||||
in if S.null cs' then Nothing else Just cs'
|
||||
|
||||
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
|
||||
logServer dir AgentClient {clientId} srv qId cmdStr =
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
module Simplex.Messaging.Client
|
||||
( SMPClient,
|
||||
getSMPClient,
|
||||
closeSMPClient,
|
||||
createSMPQueue,
|
||||
subscribeSMPQueue,
|
||||
secureSMPQueue,
|
||||
@@ -49,6 +50,7 @@ import System.Timeout
|
||||
|
||||
data SMPClient = SMPClient
|
||||
{ action :: Async (),
|
||||
connected :: TVar Bool,
|
||||
smpServer :: SMPServer,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TVar (Map CorrId Request),
|
||||
@@ -78,16 +80,20 @@ data Request = Request
|
||||
responseVar :: TMVar (Either SMPClientError Cmd)
|
||||
}
|
||||
|
||||
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO SMPClient
|
||||
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient
|
||||
getSMPClient
|
||||
smpServer@SMPServer {host, port}
|
||||
SMPClientConfig {qSize, defaultPort, tcpTimeout}
|
||||
msgQ = do
|
||||
msgQ
|
||||
disconnected = do
|
||||
c <- atomically mkSMPClient
|
||||
started <- newEmptyTMVarIO
|
||||
action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c started)
|
||||
action <-
|
||||
async $
|
||||
runTCPClient host (fromMaybe defaultPort port) (client c started)
|
||||
`finally` atomically (putTMVar started False)
|
||||
tcpTimeout `timeout` atomically (takeTMVar started) >>= \case
|
||||
Just _ -> return c {action}
|
||||
Just True -> return c {action}
|
||||
_ -> throwIO err
|
||||
where
|
||||
err :: IOException
|
||||
@@ -95,18 +101,31 @@ getSMPClient
|
||||
|
||||
mkSMPClient :: STM SMPClient
|
||||
mkSMPClient = do
|
||||
connected <- newTVar False
|
||||
clientCorrId <- newTVar 0
|
||||
sentCommands <- newTVar M.empty
|
||||
sndQ <- newTBQueue qSize
|
||||
rcvQ <- newTBQueue qSize
|
||||
return SMPClient {action = undefined, smpServer, clientCorrId, sentCommands, sndQ, rcvQ, msgQ}
|
||||
return
|
||||
SMPClient
|
||||
{ action = undefined,
|
||||
connected,
|
||||
smpServer,
|
||||
clientCorrId,
|
||||
sentCommands,
|
||||
sndQ,
|
||||
rcvQ,
|
||||
msgQ
|
||||
}
|
||||
|
||||
client :: SMPClient -> TMVar () -> Handle -> IO ()
|
||||
client :: SMPClient -> TMVar Bool -> Handle -> IO ()
|
||||
client c started h = do
|
||||
_ <- getLn h -- "Welcome to SMP"
|
||||
atomically $ putTMVar started ()
|
||||
-- TODO call continuation on disconnection after raceAny_ exits
|
||||
atomically $ do
|
||||
modifyTVar (connected c) (const True)
|
||||
putTMVar started True
|
||||
raceAny_ [send c h, process c, receive c h]
|
||||
`finally` disconnected
|
||||
|
||||
send :: SMPClient -> Handle -> IO ()
|
||||
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
@@ -134,6 +153,9 @@ getSMPClient
|
||||
Right r -> Right r
|
||||
else Left SMPQueueIdError
|
||||
|
||||
closeSMPClient :: SMPClient -> IO ()
|
||||
closeSMPClient = uninterruptibleCancel . action
|
||||
|
||||
data SMPClientError
|
||||
= SMPServerError ErrorType
|
||||
| SMPResponseError ErrorType
|
||||
|
||||
@@ -21,12 +21,18 @@ import UnliftIO.Exception (Exception, IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import qualified UnliftIO.IO as IO
|
||||
|
||||
runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m ()
|
||||
runTCPServer port server =
|
||||
E.bracket (liftIO $ startTCPServer port) (liftIO . close) $ \sock -> forever $ do
|
||||
h <- liftIO $ acceptTCPConn sock
|
||||
forkFinally (server h) (const $ IO.hClose h)
|
||||
|
||||
startTCPServer :: ServiceName -> IO Socket
|
||||
startTCPServer port = withSocketsDo $ resolve >>= open
|
||||
where
|
||||
resolve = do
|
||||
resolve =
|
||||
let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream}
|
||||
head <$> getAddrInfo (Just hints) Nothing (Just port)
|
||||
in head <$> getAddrInfo (Just hints) Nothing (Just port)
|
||||
open addr = do
|
||||
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
|
||||
setSocketOption sock ReuseAddr 1
|
||||
@@ -35,33 +41,30 @@ startTCPServer port = withSocketsDo $ resolve >>= open
|
||||
listen sock 1024
|
||||
return sock
|
||||
|
||||
runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m ()
|
||||
runTCPServer port server =
|
||||
E.bracket (liftIO $ startTCPServer port) (liftIO . close) $ \sock -> forever $ do
|
||||
h <- liftIO $ acceptTCPConn sock
|
||||
forkFinally (server h) (const $ IO.hClose h)
|
||||
|
||||
acceptTCPConn :: Socket -> IO Handle
|
||||
acceptTCPConn sock = do
|
||||
(conn, _) <- accept sock
|
||||
getSocketHandle conn
|
||||
acceptTCPConn sock = accept sock >>= getSocketHandle . fst
|
||||
|
||||
runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a
|
||||
runTCPClient host port client = do
|
||||
h <- liftIO $ startTCPClient host port
|
||||
client h `E.finally` IO.hClose h
|
||||
|
||||
startTCPClient :: HostName -> ServiceName -> IO Handle
|
||||
startTCPClient host port =
|
||||
withSocketsDo $
|
||||
resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return
|
||||
resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion
|
||||
where
|
||||
err :: IOException
|
||||
err = mkIOError NoSuchThing "no address" Nothing Nothing
|
||||
|
||||
resolve :: IO [AddrInfo]
|
||||
resolve = do
|
||||
resolve =
|
||||
let hints = defaultHints {addrSocketType = Stream}
|
||||
getAddrInfo (Just hints) (Just host) (Just port)
|
||||
in getAddrInfo (Just hints) (Just host) (Just port)
|
||||
|
||||
tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle)
|
||||
tryOpen h@(Right _) _ = return h
|
||||
tryOpen (Left _) addr = E.try $ open addr
|
||||
tryOpen h _ = return h
|
||||
|
||||
open :: AddrInfo -> IO Handle
|
||||
open addr = do
|
||||
@@ -69,11 +72,6 @@ startTCPClient host port =
|
||||
connect sock $ addrAddress addr
|
||||
getSocketHandle sock
|
||||
|
||||
runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a
|
||||
runTCPClient host port client = do
|
||||
h <- liftIO $ startTCPClient host port
|
||||
client h `E.finally` IO.hClose h
|
||||
|
||||
getSocketHandle :: Socket -> IO Handle
|
||||
getSocketHandle conn = do
|
||||
h <- socketToHandle conn ReadWriteMode
|
||||
|
||||
@@ -30,21 +30,28 @@ agentTests = do
|
||||
it "should connect via one server and one agent" $
|
||||
smpAgentTest3_1 testSubscription
|
||||
|
||||
-- | simple test for one command with the expected response
|
||||
(>#>) :: ARawTransmission -> ARawTransmission -> Expectation
|
||||
command >#> response = smpAgentTest command `shouldReturn` response
|
||||
|
||||
-- | simple test for one command with a predicate for the expected response
|
||||
(>#>=) :: ARawTransmission -> ((ByteString, ByteString, [ByteString]) -> Bool) -> Expectation
|
||||
command >#>= p = smpAgentTest command >>= (`shouldSatisfy` p . \(cId, cAlias, cmd) -> (cId, cAlias, B.words cmd))
|
||||
|
||||
-- | send transmission `t` to handle `h` and get response
|
||||
(#:) :: Handle -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent)
|
||||
h #: t = tPutRaw h t >> tGet SAgent h
|
||||
|
||||
-- | action and expected response
|
||||
-- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r`
|
||||
(#>) :: IO (ATransmissionOrError 'Agent) -> ATransmission 'Agent -> Expectation
|
||||
action #> (corrId, cAlias, cmd) = action `shouldReturn` (corrId, cAlias, Right cmd)
|
||||
|
||||
-- | receive message to handle `h` and validate that it is the expected one
|
||||
(<#) :: Handle -> ATransmission 'Agent -> Expectation
|
||||
h <# (corrId, cAlias, cmd) = tGet SAgent h `shouldReturn` (corrId, cAlias, Right cmd)
|
||||
|
||||
-- | receive message to handle `h` and validate it using predicate `p`
|
||||
(<#=) :: Handle -> (ATransmissionOrError 'Agent -> Bool) -> Expectation
|
||||
h <#= p = tGet SAgent h >>= (`shouldSatisfy` p)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user