From 366e208ae0f93532c4e1ebdd38f5424cfae1f02d Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 16 Oct 2020 18:38:01 +0100 Subject: [PATCH] server thread to track client smp connection subscriptions and notify clients when they should unsubscribe --- src/ConnStore.hs | 6 ++-- src/ConnStore/STM.hs | 6 ++-- src/Env/STM.hs | 5 ++-- src/Server.hs | 65 ++++++++++++++++++++++++++++++++------------ src/Transmission.hs | 8 +++--- 5 files changed, 61 insertions(+), 29 deletions(-) diff --git a/src/ConnStore.hs b/src/ConnStore.hs index 2c2016411..d1b302b8a 100644 --- a/src/ConnStore.hs +++ b/src/ConnStore.hs @@ -18,15 +18,15 @@ data Connection = Connection data ConnStatus = ConnActive | ConnOff class MonadConnStore s m where - createConn :: s -> RecipientKey -> m (Either ErrorType Connection) + addConn :: s -> RecipientKey -> m (Either ErrorType Connection) getConn :: s -> Sing (a :: Party) -> ConnId -> m (Either ErrorType Connection) secureConn :: s -> RecipientId -> SenderKey -> m (Either ErrorType ()) suspendConn :: s -> RecipientId -> m (Either ErrorType ()) deleteConn :: s -> RecipientId -> m (Either ErrorType ()) -- TODO stub -newConnection :: RecipientKey -> Connection -newConnection rKey = +mkConnection :: RecipientKey -> Connection +mkConnection rKey = Connection { recipientId = "1", recipientKey = rKey, diff --git a/src/ConnStore/STM.hs b/src/ConnStore/STM.hs index 0dd9ab135..7bfa41ec9 100644 --- a/src/ConnStore/STM.hs +++ b/src/ConnStore/STM.hs @@ -29,10 +29,10 @@ newConnStore :: STM STMConnStore newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty} instance MonadUnliftIO m => MonadConnStore STMConnStore m where - createConn :: STMConnStore -> RecipientKey -> m (Either ErrorType Connection) - createConn store rKey = atomically $ do + addConn :: STMConnStore -> RecipientKey -> m (Either ErrorType Connection) + addConn store rKey = atomically $ do db <- readTVar store - let c@Connection {recipientId = rId, senderId = sId} = newConnection rKey + let c@Connection {recipientId = rId, senderId = sId} = mkConnection rKey db' = db { connections = M.insert rId c (connections db), diff --git a/src/Env/STM.hs b/src/Env/STM.hs index 488e2b000..a59685790 100644 --- a/src/Env/STM.hs +++ b/src/Env/STM.hs @@ -23,7 +23,7 @@ data Env = Env data Server = Server { subscribedQ :: TBQueue (RecipientId, Client), - connections :: Map RecipientId Client + connections :: TVar (Map RecipientId Client) } data Client = Client @@ -35,7 +35,8 @@ data Client = Client newServer :: Natural -> STM Server newServer qSize = do subscribedQ <- newTBQueue qSize - return Server {subscribedQ, connections = M.empty} + connections <- newTVar M.empty + return Server {subscribedQ, connections} newClient :: Natural -> STM Client newClient qSize = do diff --git a/src/Server.hs b/src/Server.hs index 958d7a31e..4fbca73bf 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -32,7 +32,21 @@ import UnliftIO.STM runSMPServer :: MonadUnliftIO m => ServiceName -> Natural -> m () runSMPServer port queueSize = do env <- atomically $ newEnv port queueSize - runReaderT (runTCPServer port runClient) env + runReaderT smpServer env + where + smpServer :: (MonadUnliftIO m, MonadReader Env m) => m () + smpServer = do + s <- asks server + race_ (runTCPServer port runClient) (serverThread s) + + serverThread :: MonadUnliftIO m => Server -> m () + serverThread Server {subscribedQ, connections} = forever . atomically $ do + (rId, clnt) <- readTBQueue subscribedQ + cs <- readTVar connections + case M.lookup rId cs of + Just Client {rcvQ} -> writeTBQueue rcvQ (rId, Cmd SBroker END) + Nothing -> return () + writeTVar connections $ M.insert rId clnt cs runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () runClient h = do @@ -97,14 +111,15 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = processCommand (connId, cmd) = do st <- asks connStore case cmd of + Cmd SBroker END -> unsubscribeConn >> return (connId, cmd) Cmd SBroker _ -> return (connId, cmd) - Cmd SSender (SEND msgBody) -> do + Cmd SSender (SEND msgBody) -> getConn st SSender connId >>= fmap (mkSigned connId) . either (return . ERR) (storeMessage msgBody) Cmd SRecipient command -> case command of - CONN rKey -> idsResponce <$> createConn st rKey - SUB -> subscribeConnection >> deliverMessage tryPeekMsg - ACK -> deliverMessage tryDelPeekMsg + CONN rKey -> createConn st rKey + SUB -> subscribeConn connId + ACK -> deliverMessage tryDelPeekMsg -- TODO? sending ACK without message loses the message KEY sKey -> okResponse <$> secureConn st connId sKey OFF -> okResponse <$> suspendConn st connId DEL -> okResponse <$> deleteConn st connId @@ -115,20 +130,35 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = mkSigned :: ConnId -> Command 'Broker -> Signed mkSigned cId command = (cId, Cmd SBroker command) - idsResponce :: Either ErrorType Connection -> Signed - idsResponce = either (mkSigned "" . ERR) $ - \Connection {recipientId = rId, senderId = sId} -> - mkSigned rId $ IDS rId sId - okResponse :: Either ErrorType () -> Signed okResponse = mkSigned connId . either ERR (const OK) - subscribeConnection :: m () - subscribeConnection = atomically $ do - cs <- readTVar connections - when (M.notMember connId cs) $ do - writeTBQueue subscribedQ (connId, clnt) - writeTVar connections $ M.insert connId (Left ()) cs + createConn :: MonadConnStore s m => s -> RecipientKey -> m Signed + createConn st rKey = + addConn st rKey >>= \case + Right Connection {recipientId = rId, senderId = sId} -> do + void $ subscribeConn rId + return . mkSigned rId $ IDS rId sId + Left e -> return . mkSigned "" $ ERR e + + subscribeConn :: RecipientId -> m Signed + subscribeConn rId = do + atomically $ do + cs <- readTVar connections + when (M.notMember rId cs) $ do + writeTBQueue subscribedQ (rId, clnt) + writeTVar connections $ M.insert rId (Left ()) cs + deliverMessage tryPeekMsg + + unsubscribeConn :: m () + unsubscribeConn = do + cs <- readTVarIO connections + case M.lookup connId cs of + Nothing -> return () + Just (Left ()) -> atomically $ writeTVar connections $ M.delete connId cs + Just (Right threadId) -> do + killThread threadId + atomically $ writeTVar connections $ M.delete connId cs storeMessage :: MsgBody -> Connection -> m (Command 'Broker) storeMessage msgBody c = case status c of @@ -153,7 +183,8 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = Nothing -> return ok Just (Right _) -> return ok Just (Left ()) -> do - void . forkIO $ subscriber q + threadId <- forkIO $ subscriber q + atomically . writeTVar connections $ M.insert connId (Right threadId) cs return ok subscriber :: MsgQueue -> m () diff --git a/src/Transmission.hs b/src/Transmission.hs index e361ae50e..6ae7786a1 100644 --- a/src/Transmission.hs +++ b/src/Transmission.hs @@ -49,8 +49,8 @@ data Command (a :: Party) where DEL :: Command Recipient SEND :: MsgBody -> Command Sender IDS :: RecipientId -> SenderId -> Command Broker - END :: RecipientId -> Command Broker MSG :: MsgId -> UTCTime -> MsgBody -> Command Broker + END :: Command Broker OK :: Command Broker ERR :: ErrorType -> Command Broker @@ -69,10 +69,10 @@ parseCommand command = case words command of ["SEND"] -> errParams "SEND" : msgBody -> Right . Cmd SSender . SEND . B.pack $ unwords msgBody ["IDS", rId, sId] -> bCmd $ IDS rId sId - ["END", rId] -> bCmd $ END rId ["MSG", msgId, ts, msgBody] -> case parseISO8601 ts of Just utc -> bCmd $ MSG msgId utc (B.pack msgBody) _ -> errParams + ["END"] -> bCmd END ["OK"] -> bCmd OK "ERR" : err -> case err of ["UNKNOWN"] -> bErr UNKNOWN @@ -90,6 +90,7 @@ parseCommand command = case words command of "DEL" : _ -> errParams "MSG" : _ -> errParams "IDS" : _ -> errParams + "END" : _ -> errParams "OK" : _ -> errParams _ -> Left UNKNOWN where @@ -107,9 +108,8 @@ serializeCommand = \case Cmd SBroker (MSG msgId ts msgBody) -> unwords ["MSG", msgId, formatISO8601Millis ts] ++ serializeMsg msgBody Cmd SBroker (IDS rId sId) -> unwords ["IDS", rId, sId] - Cmd SBroker (END rId) -> "END " ++ rId Cmd SBroker (ERR err) -> "ERR " ++ show err - Cmd SBroker OK -> "OK" + Cmd SBroker resp -> show resp where serializeMsg msgBody = " " ++ show (B.length msgBody) ++ "\n" ++ B.unpack msgBody