From d0b56034a7ab3df909201021eaddfdeb268bfada Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sun, 24 Jan 2021 12:01:44 +0000 Subject: [PATCH] subscriptions (#27) * subscribe connection and track subscriptions * notify client when subscription ENDs * tcp connection timeout * move types --- src/Simplex/Messaging/Agent.hs | 25 +++-- src/Simplex/Messaging/Agent/Client.hs | 61 +++++++---- src/Simplex/Messaging/Agent/Store.hs | 9 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 4 +- .../Messaging/Agent/Store/SQLite/Util.hs | 8 +- src/Simplex/Messaging/Agent/Transmission.hs | 19 ++-- src/Simplex/Messaging/Client.hs | 103 +++++++++++------- src/Simplex/Messaging/Protocol.hs | 8 +- src/Simplex/Messaging/Server/Env/STM.hs | 1 - src/Simplex/Messaging/Server/MsgStore.hs | 1 + src/Simplex/Messaging/Server/MsgStore/STM.hs | 2 +- src/Simplex/Messaging/Transport.hs | 5 +- src/Simplex/Messaging/Types.hs | 6 - tests/SMPAgentClient.hs | 3 +- tests/ServerTests.hs | 2 +- 15 files changed, 158 insertions(+), 99 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index c536d89b3..e68a3c12d 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -25,10 +25,10 @@ import Simplex.Messaging.Agent.Store.SQLite.Types import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client (SMPServerTransmission) -import Simplex.Messaging.Types (CorrId (..), MsgBody, PrivateKey, SenderKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server (randomBytes) import Simplex.Messaging.Transport +import Simplex.Messaging.Types (CorrId (..), MsgBody, PrivateKey, SenderKey) import UnliftIO.Async import UnliftIO.Exception (SomeException) import qualified UnliftIO.Exception as E @@ -104,15 +104,15 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) = case cmd of NEW smpServer -> createNewConnection smpServer JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode + SUB -> subscribeConnection SEND msgBody -> sendMessage msgBody ACK aMsgId -> ackMessage aMsgId - _ -> throwError PROHIBITED where createNewConnection :: SMPServer -> m () createNewConnection server = do -- TODO create connection alias if not passed -- make connAlias Maybe? - (rq, qInfo) <- newReceiveQueue c server + (rq, qInfo) <- newReceiveQueue c server connAlias withStore $ \st -> createRcvConn st connAlias rq respond $ INV qInfo @@ -129,6 +129,16 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) = ReplyOff -> return () respond CON + subscribeConnection :: m () + subscribeConnection = + withStore (`getConn` connAlias) >>= \case + SomeConn _ (DuplexConnection _ rq _) -> subscribe rq + SomeConn _ (ReceiveConnection _ rq) -> subscribe rq + -- TODO possibly there should be a separate error type trying to send the message to the connection without ReceiveQueue + _ -> throwError PROHIBITED + where + subscribe rq = subscribeQueue c rq connAlias >> respond OK + sendMessage :: MsgBody -> m () sendMessage msgBody = withStore (`getConn` connAlias) >>= \case @@ -147,7 +157,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) = withStore (`getConn` connAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> ackMsg rq SomeConn _ (ReceiveConnection _ rq) -> ackMsg rq - -- TODO possibly there should be a separate error type trying to send the message to the connection without SendQueue + -- TODO possibly there should be a separate error type trying to send the message to the connection without ReceiveQueue -- NOT_READY ? _ -> throwError PROHIBITED where @@ -155,7 +165,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) = sendReplyQInfo :: SMPServer -> SendQueue -> m () sendReplyQInfo srv sq = do - (rq, qInfo) <- newReceiveQueue c srv + (rq, qInfo) <- newReceiveQueue c srv connAlias withStore $ \st -> addRcvQueue st connAlias rq sendAgentMessage c sq $ REPLY qInfo @@ -172,10 +182,10 @@ subscriber c@AgentClient {msgQ} = forever $ do processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SMPServerTransmission -> m () processSMPTransmission c@AgentClient {sndQ} (srv, rId, cmd) = do + (connAlias, rq@ReceiveQueue {decryptKey, status}) <- withStore $ \st -> getReceiveQueue st srv rId case cmd of SMP.MSG _ srvTs msgBody -> do -- TODO deduplicate with previously received - (connAlias, rq@ReceiveQueue {decryptKey, status}) <- withStore $ \st -> getReceiveQueue st srv rId agentMsg <- liftEither . parseSMPMessage =<< decryptMessage decryptKey msgBody case agentMsg of SMPConfirmation senderKey -> do @@ -215,8 +225,9 @@ processSMPTransmission c@AgentClient {sndQ} (srv, rId, cmd) = do notify connAlias $ MSG agentMsgId agentTimestamp srvTs MsgOk body return () SMP.END -> do + removeSubscription c connAlias logServer "<--" c srv rId "END" - return () + notify connAlias END _ -> logServer "<--" c srv rId $ "unexpected:" <> (B.pack . show) cmd where notify :: ConnAlias -> ACommand 'Agent -> m () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index ad73f3d6b..b29ae8d6a 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -13,12 +13,14 @@ module Simplex.Messaging.Agent.Client AgentMonad, getSMPServerClient, newReceiveQueue, + subscribeQueue, sendConfirmation, sendHello, secureQueue, sendAgentMessage, sendAck, logServer, + removeSubscription, ) where @@ -40,10 +42,11 @@ import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client +import Simplex.Messaging.Protocol (QueueId, RecipientId) import Simplex.Messaging.Server (randomBytes) -import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, QueueId, SenderKey) +import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, SenderKey) import UnliftIO.Concurrent -import UnliftIO.Exception (SomeException) +import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -52,6 +55,7 @@ data AgentClient = AgentClient sndQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue SMPServerTransmission, smpClients :: TVar (Map SMPServer SMPClient), + subscribed :: TVar (Map ConnAlias (SMPServer, RecipientId)), clientId :: Int } @@ -61,28 +65,31 @@ newAgentClient cc qSize = do sndQ <- newTBQueue qSize msgQ <- newTBQueue qSize smpClients <- newTVar M.empty + subscribed <- newTVar M.empty clientId <- (+ 1) <$> readTVar cc writeTVar cc clientId - return AgentClient {rcvQ, sndQ, msgQ, smpClients, clientId} + return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscribed, clientId} type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient getSMPServerClient AgentClient {smpClients, msgQ} srv = - atomically (M.lookup srv <$> readTVar smpClients) - >>= maybe newSMPClient return + readTVarIO smpClients + >>= maybe newSMPClient return . M.lookup srv where newSMPClient :: m SMPClient newSMPClient = do - cfg <- asks $ smpCfg . config - c <- liftIO (getSMPClient srv cfg msgQ) `E.catch` throwErr (BROKER smpErrTCPConnection) + c <- connectClient + logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv + -- TODO how can agent know client lost the connection? atomically . modifyTVar smpClients $ M.insert srv c return c - throwErr :: AgentErrorType -> SomeException -> m a - throwErr err e = do - liftIO . putStrLn $ "Exception: " ++ show e -- TODO remove - throwError err + connectClient :: m SMPClient + connectClient = do + cfg <- asks $ smpCfg . config + liftIO (getSMPClient srv cfg msgQ) + `E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection) withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a withSMP c srv action = @@ -111,8 +118,8 @@ withLogSMP c srv qId cmdStr action = do logServer "<--" c srv qId "OK" return res -newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> m (ReceiveQueue, SMPQueueInfo) -newReceiveQueue c srv = do +newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (ReceiveQueue, SMPQueueInfo) +newReceiveQueue c srv connAlias = do g <- asks idsDrg recipientKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair let rcvPrivateKey = recipientKey @@ -121,7 +128,7 @@ newReceiveQueue c srv = do logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sId] encryptKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair let decryptKey = encryptKey - rcvQueue = + rq = ReceiveQueue { server = srv, rcvId, @@ -133,13 +140,29 @@ newReceiveQueue c srv = do status = New, ackMode = AckMode On } - return (rcvQueue, SMPQueueInfo srv sId encryptKey) + addSubscription c rq connAlias + return (rq, SMPQueueInfo srv sId encryptKey) + +subscribeQueue :: AgentMonad m => AgentClient -> ReceiveQueue -> ConnAlias -> m () +subscribeQueue c rq@ReceiveQueue {server, rcvPrivateKey, rcvId} connAlias = do + withLogSMP c server rcvId "SUB" $ \smp -> + subscribeSMPQueue smp rcvPrivateKey rcvId + addSubscription c rq connAlias + +addSubscription :: MonadUnliftIO m => AgentClient -> ReceiveQueue -> ConnAlias -> m () +addSubscription c ReceiveQueue {server, rcvId} connAlias = + atomically . modifyTVar (subscribed c) $ M.insert connAlias (server, rcvId) + +removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m () +removeSubscription c connAlias = + atomically . modifyTVar (subscribed c) $ M.delete connAlias logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () -logServer dir AgentClient {clientId} SMPServer {host, port} qId cmdStr = - logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, server, ":", logSecret qId, cmdStr] - where - server = B.pack $ host <> maybe "" (":" <>) port +logServer dir AgentClient {clientId} srv qId cmdStr = + logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] + +showServer :: SMPServer -> ByteString +showServer srv = B.pack $ host srv <> maybe "" (":" <>) (port srv) logSecret :: ByteString -> ByteString logSecret bs = encode $ B.take 3 bs diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index acfa45341..5b23c5d68 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -15,13 +15,14 @@ import Data.Time.Clock (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission +import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Types data ReceiveQueue = ReceiveQueue { server :: SMPServer, - rcvId :: RecipientId, + rcvId :: SMP.RecipientId, rcvPrivateKey :: PrivateKey, - sndId :: Maybe SenderId, + sndId :: Maybe SMP.SenderId, sndKey :: Maybe PublicKey, decryptKey :: PrivateKey, verifyKey :: Maybe PublicKey, @@ -32,7 +33,7 @@ data ReceiveQueue = ReceiveQueue data SendQueue = SendQueue { server :: SMPServer, - sndId :: SenderId, + sndId :: SMP.SenderId, sndPrivateKey :: PrivateKey, encryptKey :: PublicKey, signKey :: PrivateKey, @@ -98,7 +99,7 @@ class Monad m => MonadAgentStore s m where createRcvConn :: s -> ConnAlias -> ReceiveQueue -> m () createSndConn :: s -> ConnAlias -> SendQueue -> m () getConn :: s -> ConnAlias -> m SomeConn - getReceiveQueue :: s -> SMPServer -> RecipientId -> m (ConnAlias, ReceiveQueue) + getReceiveQueue :: s -> SMPServer -> SMP.RecipientId -> m (ConnAlias, ReceiveQueue) deleteConn :: s -> ConnAlias -> m () addSndQueue :: s -> ConnAlias -> SendQueue -> m () addRcvQueue :: s -> ConnAlias -> ReceiveQueue -> m () diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index aa90a99e9..721b0a333 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -25,7 +25,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Types import Simplex.Messaging.Agent.Store.SQLite.Util import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission -import Simplex.Messaging.Types +import qualified Simplex.Messaging.Protocol as SMP import UnliftIO.STM newSQLiteStore :: MonadUnliftIO m => String -> m SQLiteStore @@ -81,7 +81,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto return $ SomeConn SCSend (SendConnection connAlias sndQ) _ -> throwError SEBadConn - getReceiveQueue :: SQLiteStore -> SMPServer -> RecipientId -> m (ConnAlias, ReceiveQueue) + getReceiveQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m (ConnAlias, ReceiveQueue) getReceiveQueue st SMPServer {host, port} recipientId = do rcvQueue <- getRcvQueueByRecipientId st recipientId host port connAlias <- getConnAliasByRcvQueue st recipientId diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs index eca0b3672..a0fb3aef2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs @@ -34,7 +34,7 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Types import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission -import Simplex.Messaging.Types +import Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util import Text.Read import qualified UnliftIO.Exception as E @@ -335,7 +335,7 @@ updateReceiveQueueStatus store rcvQueueId host port status = |] (Only status :. Only rcvQueueId :. Only host :. Only port) -updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m () +updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SMP.SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m () updateSendQueueStatus store sndQueueId host port status = executeWithLock store @@ -357,7 +357,7 @@ instance ToField QueueDirection where toField = toField . show -- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus? insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m () insertMsg store connAlias qDirection agentMsgId msg = do - tstamp <- liftIO getCurrentTime + ts <- liftIO getCurrentTime void $ insertWithLock store @@ -366,4 +366,4 @@ insertMsg store connAlias qDirection agentMsgId msg = do INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status) VALUES (?,?,?,?,?,"MDTransmitted"); |] - (Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg) + (Only connAlias :. Only agentMsgId :. Only ts :. Only qDirection :. Only msg) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 195b04345..c215d2de7 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -30,8 +30,9 @@ import Data.Typeable () import Network.Socket import Numeric.Natural import Simplex.Messaging.Agent.Store.Types +import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport -import Simplex.Messaging.Types (CorrId (..), Encoded, ErrorType, MsgBody, PublicKey, SenderId, errBadParameters, errMessageBody) +import Simplex.Messaging.Types (CorrId (..), Encoded, ErrorType, MsgBody, PublicKey, errBadParameters, errMessageBody) import Simplex.Messaging.Util import System.IO import Text.Read @@ -73,7 +74,7 @@ data ACommand (p :: AParty) where READY :: ACommand Agent -- CONF :: OtherPartyId -> ACommand Agent -- LET :: OtherPartyId -> ACommand Client - SUB :: SubMode -> ACommand Client + SUB :: ACommand Client END :: ACommand Agent -- QST :: QueueDirection -> ACommand Client -- STAT :: QueueDirection -> Maybe QueueStatus -> Maybe SubMode -> ACommand Agent @@ -208,9 +209,7 @@ data Mode = On | Off deriving (Eq, Show, Read) newtype AckMode = AckMode Mode deriving (Eq, Show) -newtype SubMode = SubMode Mode deriving (Show) - -data SMPQueueInfo = SMPQueueInfo SMPServer SenderId EncryptionKey +data SMPQueueInfo = SMPQueueInfo SMPServer SMP.SenderId EncryptionKey deriving (Show) data ReplyMode = ReplyOff | ReplyOn | ReplyVia SMPServer deriving (Show) @@ -281,6 +280,8 @@ parseCommandP = "NEW " *> newCmd <|> "INV " *> invResp <|> "JOIN " *> joinCmd + <|> "SUB" $> ACmd SClient SUB + <|> "END" $> ACmd SAgent END <|> "SEND " *> sendCmd <|> "MSG " *> message <|> "ACK " *> acknowledge @@ -314,6 +315,8 @@ serializeCommand = \case NEW srv -> "NEW " <> serializeServer srv INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode + SUB -> "SUB" + END -> "END" SEND msgBody -> "SEND " <> serializeMsg msgBody MSG aMsgId aTs ts st body -> B.unwords ["MSG", B.pack $ show aMsgId, B.pack $ formatISO8601Millis aTs, B.pack $ formatISO8601Millis ts, msgStatus st, serializeMsg body] @@ -349,11 +352,7 @@ tPutRaw h (corrId, connAlias, command) = do putLn h command tGetRaw :: Handle -> IO ARawTransmission -tGetRaw h = do - corrId <- getLn h - connAlias <- getLn h - command <- getLn h - return (corrId, connAlias, command) +tGetRaw h = (,,) <$> getLn h <*> getLn h <*> getLn h tPut :: MonadIO m => Handle -> ATransmission p -> m () tPut h (CorrId corrId, connAlias, command) = diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index ff885a500..7abd54aa2 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -3,6 +3,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -34,6 +35,7 @@ import qualified Data.ByteString.Char8 as B import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe +import GHC.IO.Exception (IOErrorType (..)) import Network.Socket (ServiceName) import Numeric.Natural import Simplex.Messaging.Agent.Transmission (SMPServer (..)) @@ -42,6 +44,8 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Types import Simplex.Messaging.Util import System.IO +import System.IO.Error +import System.Timeout data SMPClient = SMPClient { action :: Async (), @@ -57,11 +61,17 @@ type SMPServerTransmission = (SMPServer, RecipientId, Command 'Broker) data SMPClientConfig = SMPClientConfig { qSize :: Natural, - defaultPort :: ServiceName + defaultPort :: ServiceName, + tcpTimeout :: Int } smpDefaultConfig :: SMPClientConfig -smpDefaultConfig = SMPClientConfig 16 "5223" +smpDefaultConfig = + SMPClientConfig + { qSize = 16, + defaultPort = "5223", + tcpTimeout = 2_000_000 + } data Request = Request { queueId :: QueueId, @@ -69,49 +79,60 @@ data Request = Request } getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO SMPClient -getSMPClient smpServer@SMPServer {host, port} SMPClientConfig {qSize, defaultPort} msgQ = do - c <- atomically mkSMPClient - action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c) - return c {action} - where - mkSMPClient :: STM SMPClient - mkSMPClient = do - clientCorrId <- newTVar 0 - sentCommands <- newTVar M.empty - sndQ <- newTBQueue qSize - rcvQ <- newTBQueue qSize - return SMPClient {action = undefined, smpServer, clientCorrId, sentCommands, sndQ, rcvQ, msgQ} +getSMPClient + smpServer@SMPServer {host, port} + SMPClientConfig {qSize, defaultPort, tcpTimeout} + msgQ = do + c <- atomically mkSMPClient + started <- newEmptyTMVarIO + action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c started) + tcpTimeout `timeout` atomically (takeTMVar started) >>= \case + Just _ -> return c {action} + _ -> throwIO err + where + err :: IOException + err = mkIOError TimeExpired "connection timeout" Nothing Nothing - client :: SMPClient -> Handle -> IO () - client c h = do - _line <- getLn h -- "Welcome to SMP" - -- TODO test connection failure - raceAny_ [send c h, process c, receive c h] + mkSMPClient :: STM SMPClient + mkSMPClient = do + clientCorrId <- newTVar 0 + sentCommands <- newTVar M.empty + sndQ <- newTBQueue qSize + rcvQ <- newTBQueue qSize + return SMPClient {action = undefined, smpServer, clientCorrId, sentCommands, sndQ, rcvQ, msgQ} - send :: SMPClient -> Handle -> IO () - send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h + client :: SMPClient -> TMVar () -> Handle -> IO () + client c started h = do + _ <- getLn h -- "Welcome to SMP" + atomically $ putTMVar started () + -- TODO call continuation on disconnection after raceAny_ exits + raceAny_ [send c h, process c, receive c h] - receive :: SMPClient -> Handle -> IO () - receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ + send :: SMPClient -> Handle -> IO () + send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h - process :: SMPClient -> IO () - process SMPClient {rcvQ, sentCommands} = forever $ do - (_, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ - cs <- readTVarIO sentCommands - case M.lookup corrId cs of - Nothing -> do - case respOrErr of - Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd) - _ -> return () - Just Request {queueId, responseVar} -> atomically $ do - modifyTVar sentCommands $ M.delete corrId - putTMVar responseVar $ - if queueId == qId - then case respOrErr of - Left e -> Left $ SMPResponseError e - Right (Cmd _ (ERR e)) -> Left $ SMPServerError e - Right r -> Right r - else Left SMPQueueIdError + receive :: SMPClient -> Handle -> IO () + receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ + + process :: SMPClient -> IO () + process SMPClient {rcvQ, sentCommands} = forever $ do + (_, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ + cs <- readTVarIO sentCommands + case M.lookup corrId cs of + Nothing -> do + case respOrErr of + Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd) + -- TODO send everything else to errQ and log in agent + _ -> return () + Just Request {queueId, responseVar} -> atomically $ do + modifyTVar sentCommands $ M.delete corrId + putTMVar responseVar $ + if queueId == qId + then case respOrErr of + Left e -> Left $ SMPResponseError e + Right (Cmd _ (ERR e)) -> Left $ SMPServerError e + Right r -> Right r + else Left SMPQueueIdError data SMPClientError = SMPServerError ErrorType diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 391ac3d40..8c5e23fca 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -20,8 +20,8 @@ import Data.Char (ord) import Data.Kind import Data.Time.Clock import Data.Time.ISO8601 -import Simplex.Messaging.Types import Simplex.Messaging.Transport +import Simplex.Messaging.Types import Simplex.Messaging.Util import System.IO import Text.Read @@ -51,6 +51,12 @@ type TransmissionOrError = (Signature, SignedOrError) type RawTransmission = (ByteString, ByteString, ByteString, ByteString) +type RecipientId = QueueId + +type SenderId = QueueId + +type QueueId = Encoded + data Command (a :: Party) where NEW :: RecipientKey -> Command Recipient SUB :: Command Recipient diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 182542678..39d703abe 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -10,7 +10,6 @@ import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Network.Socket (ServiceName) import Numeric.Natural -import Simplex.Messaging.Types import Simplex.Messaging.Protocol import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.QueueStore.STM diff --git a/src/Simplex/Messaging/Server/MsgStore.hs b/src/Simplex/Messaging/Server/MsgStore.hs index bf2e2d1d6..c3f93b507 100644 --- a/src/Simplex/Messaging/Server/MsgStore.hs +++ b/src/Simplex/Messaging/Server/MsgStore.hs @@ -3,6 +3,7 @@ module Simplex.Messaging.Server.MsgStore where import Data.Time.Clock +import Simplex.Messaging.Protocol (RecipientId) import Simplex.Messaging.Types data Message = Message diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 1c32c7755..f5b0e670f 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -8,7 +8,7 @@ module Simplex.Messaging.Server.MsgStore.STM where import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Simplex.Messaging.Types +import Simplex.Messaging.Protocol (RecipientId) import Simplex.Messaging.Server.MsgStore import UnliftIO.STM diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 48764164f..a235d9dbc 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -86,4 +86,7 @@ putLn :: Handle -> ByteString -> IO () putLn h = B.hPut h . (<> "\r\n") getLn :: Handle -> IO ByteString -getLn h = B.pack <$> hGetLine h +getLn h = trim_cr <$> B.hGetLine h + where + trim_cr "" = "" + trim_cr s = if B.last s == '\r' then B.init s else s diff --git a/src/Simplex/Messaging/Types.hs b/src/Simplex/Messaging/Types.hs index 829992537..b0b19b8a5 100644 --- a/src/Simplex/Messaging/Types.hs +++ b/src/Simplex/Messaging/Types.hs @@ -29,12 +29,6 @@ type RecipientKey = PublicKey type SenderKey = PublicKey -type RecipientId = QueueId - -type SenderId = QueueId - -type QueueId = Encoded - type MsgId = Encoded type MsgBody = ByteString diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index dbb7d8d28..f4ea8c6cd 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -85,7 +85,8 @@ cfg = smpCfg = SMPClientConfig { qSize = 1, - defaultPort = testPort + defaultPort = testPort, + tcpTimeout = 500_000 } } diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index b0f6f6d3d..3911f5add 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -233,7 +233,7 @@ testSwitchSub = Resp "bcda" _ ok3 <- sendRecv rh2 ("1234", "bcda", rId, "ACK") (ok3, OK) #== "accepts ACK from the 2nd TCP connection" - timeout 1000 (tGet fromServer rh1) >>= \case + 1000 `timeout` tGet fromServer rh1 >>= \case Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection"