mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-25 20:42:15 +00:00
subscriptions (#27)
* subscribe connection and track subscriptions * notify client when subscription ENDs * tcp connection timeout * move types
This commit is contained in:
committed by
GitHub
parent
2372abf2e9
commit
d0b56034a7
@@ -25,10 +25,10 @@ import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client (SMPServerTransmission)
|
||||
import Simplex.Messaging.Types (CorrId (..), MsgBody, PrivateKey, SenderKey)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server (randomBytes)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Types (CorrId (..), MsgBody, PrivateKey, SenderKey)
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Exception (SomeException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
@@ -104,15 +104,15 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) =
|
||||
case cmd of
|
||||
NEW smpServer -> createNewConnection smpServer
|
||||
JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode
|
||||
SUB -> subscribeConnection
|
||||
SEND msgBody -> sendMessage msgBody
|
||||
ACK aMsgId -> ackMessage aMsgId
|
||||
_ -> throwError PROHIBITED
|
||||
where
|
||||
createNewConnection :: SMPServer -> m ()
|
||||
createNewConnection server = do
|
||||
-- TODO create connection alias if not passed
|
||||
-- make connAlias Maybe?
|
||||
(rq, qInfo) <- newReceiveQueue c server
|
||||
(rq, qInfo) <- newReceiveQueue c server connAlias
|
||||
withStore $ \st -> createRcvConn st connAlias rq
|
||||
respond $ INV qInfo
|
||||
|
||||
@@ -129,6 +129,16 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) =
|
||||
ReplyOff -> return ()
|
||||
respond CON
|
||||
|
||||
subscribeConnection :: m ()
|
||||
subscribeConnection =
|
||||
withStore (`getConn` connAlias) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> subscribe rq
|
||||
SomeConn _ (ReceiveConnection _ rq) -> subscribe rq
|
||||
-- TODO possibly there should be a separate error type trying to send the message to the connection without ReceiveQueue
|
||||
_ -> throwError PROHIBITED
|
||||
where
|
||||
subscribe rq = subscribeQueue c rq connAlias >> respond OK
|
||||
|
||||
sendMessage :: MsgBody -> m ()
|
||||
sendMessage msgBody =
|
||||
withStore (`getConn` connAlias) >>= \case
|
||||
@@ -147,7 +157,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) =
|
||||
withStore (`getConn` connAlias) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> ackMsg rq
|
||||
SomeConn _ (ReceiveConnection _ rq) -> ackMsg rq
|
||||
-- TODO possibly there should be a separate error type trying to send the message to the connection without SendQueue
|
||||
-- TODO possibly there should be a separate error type trying to send the message to the connection without ReceiveQueue
|
||||
-- NOT_READY ?
|
||||
_ -> throwError PROHIBITED
|
||||
where
|
||||
@@ -155,7 +165,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) =
|
||||
|
||||
sendReplyQInfo :: SMPServer -> SendQueue -> m ()
|
||||
sendReplyQInfo srv sq = do
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
(rq, qInfo) <- newReceiveQueue c srv connAlias
|
||||
withStore $ \st -> addRcvQueue st connAlias rq
|
||||
sendAgentMessage c sq $ REPLY qInfo
|
||||
|
||||
@@ -172,10 +182,10 @@ subscriber c@AgentClient {msgQ} = forever $ do
|
||||
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SMPServerTransmission -> m ()
|
||||
processSMPTransmission c@AgentClient {sndQ} (srv, rId, cmd) = do
|
||||
(connAlias, rq@ReceiveQueue {decryptKey, status}) <- withStore $ \st -> getReceiveQueue st srv rId
|
||||
case cmd of
|
||||
SMP.MSG _ srvTs msgBody -> do
|
||||
-- TODO deduplicate with previously received
|
||||
(connAlias, rq@ReceiveQueue {decryptKey, status}) <- withStore $ \st -> getReceiveQueue st srv rId
|
||||
agentMsg <- liftEither . parseSMPMessage =<< decryptMessage decryptKey msgBody
|
||||
case agentMsg of
|
||||
SMPConfirmation senderKey -> do
|
||||
@@ -215,8 +225,9 @@ processSMPTransmission c@AgentClient {sndQ} (srv, rId, cmd) = do
|
||||
notify connAlias $ MSG agentMsgId agentTimestamp srvTs MsgOk body
|
||||
return ()
|
||||
SMP.END -> do
|
||||
removeSubscription c connAlias
|
||||
logServer "<--" c srv rId "END"
|
||||
return ()
|
||||
notify connAlias END
|
||||
_ -> logServer "<--" c srv rId $ "unexpected:" <> (B.pack . show) cmd
|
||||
where
|
||||
notify :: ConnAlias -> ACommand 'Agent -> m ()
|
||||
|
||||
@@ -13,12 +13,14 @@ module Simplex.Messaging.Agent.Client
|
||||
AgentMonad,
|
||||
getSMPServerClient,
|
||||
newReceiveQueue,
|
||||
subscribeQueue,
|
||||
sendConfirmation,
|
||||
sendHello,
|
||||
secureQueue,
|
||||
sendAgentMessage,
|
||||
sendAck,
|
||||
logServer,
|
||||
removeSubscription,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -40,10 +42,11 @@ import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Protocol (QueueId, RecipientId)
|
||||
import Simplex.Messaging.Server (randomBytes)
|
||||
import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, QueueId, SenderKey)
|
||||
import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, SenderKey)
|
||||
import UnliftIO.Concurrent
|
||||
import UnliftIO.Exception (SomeException)
|
||||
import UnliftIO.Exception (IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -52,6 +55,7 @@ data AgentClient = AgentClient
|
||||
sndQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
smpClients :: TVar (Map SMPServer SMPClient),
|
||||
subscribed :: TVar (Map ConnAlias (SMPServer, RecipientId)),
|
||||
clientId :: Int
|
||||
}
|
||||
|
||||
@@ -61,28 +65,31 @@ newAgentClient cc qSize = do
|
||||
sndQ <- newTBQueue qSize
|
||||
msgQ <- newTBQueue qSize
|
||||
smpClients <- newTVar M.empty
|
||||
subscribed <- newTVar M.empty
|
||||
clientId <- (+ 1) <$> readTVar cc
|
||||
writeTVar cc clientId
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, clientId}
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscribed, clientId}
|
||||
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient AgentClient {smpClients, msgQ} srv =
|
||||
atomically (M.lookup srv <$> readTVar smpClients)
|
||||
>>= maybe newSMPClient return
|
||||
readTVarIO smpClients
|
||||
>>= maybe newSMPClient return . M.lookup srv
|
||||
where
|
||||
newSMPClient :: m SMPClient
|
||||
newSMPClient = do
|
||||
cfg <- asks $ smpCfg . config
|
||||
c <- liftIO (getSMPClient srv cfg msgQ) `E.catch` throwErr (BROKER smpErrTCPConnection)
|
||||
c <- connectClient
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
-- TODO how can agent know client lost the connection?
|
||||
atomically . modifyTVar smpClients $ M.insert srv c
|
||||
return c
|
||||
|
||||
throwErr :: AgentErrorType -> SomeException -> m a
|
||||
throwErr err e = do
|
||||
liftIO . putStrLn $ "Exception: " ++ show e -- TODO remove
|
||||
throwError err
|
||||
connectClient :: m SMPClient
|
||||
connectClient = do
|
||||
cfg <- asks $ smpCfg . config
|
||||
liftIO (getSMPClient srv cfg msgQ)
|
||||
`E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection)
|
||||
|
||||
withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMP c srv action =
|
||||
@@ -111,8 +118,8 @@ withLogSMP c srv qId cmdStr action = do
|
||||
logServer "<--" c srv qId "OK"
|
||||
return res
|
||||
|
||||
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> m (ReceiveQueue, SMPQueueInfo)
|
||||
newReceiveQueue c srv = do
|
||||
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (ReceiveQueue, SMPQueueInfo)
|
||||
newReceiveQueue c srv connAlias = do
|
||||
g <- asks idsDrg
|
||||
recipientKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
|
||||
let rcvPrivateKey = recipientKey
|
||||
@@ -121,7 +128,7 @@ newReceiveQueue c srv = do
|
||||
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sId]
|
||||
encryptKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
|
||||
let decryptKey = encryptKey
|
||||
rcvQueue =
|
||||
rq =
|
||||
ReceiveQueue
|
||||
{ server = srv,
|
||||
rcvId,
|
||||
@@ -133,13 +140,29 @@ newReceiveQueue c srv = do
|
||||
status = New,
|
||||
ackMode = AckMode On
|
||||
}
|
||||
return (rcvQueue, SMPQueueInfo srv sId encryptKey)
|
||||
addSubscription c rq connAlias
|
||||
return (rq, SMPQueueInfo srv sId encryptKey)
|
||||
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> ReceiveQueue -> ConnAlias -> m ()
|
||||
subscribeQueue c rq@ReceiveQueue {server, rcvPrivateKey, rcvId} connAlias = do
|
||||
withLogSMP c server rcvId "SUB" $ \smp ->
|
||||
subscribeSMPQueue smp rcvPrivateKey rcvId
|
||||
addSubscription c rq connAlias
|
||||
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> ReceiveQueue -> ConnAlias -> m ()
|
||||
addSubscription c ReceiveQueue {server, rcvId} connAlias =
|
||||
atomically . modifyTVar (subscribed c) $ M.insert connAlias (server, rcvId)
|
||||
|
||||
removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m ()
|
||||
removeSubscription c connAlias =
|
||||
atomically . modifyTVar (subscribed c) $ M.delete connAlias
|
||||
|
||||
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
|
||||
logServer dir AgentClient {clientId} SMPServer {host, port} qId cmdStr =
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, server, ":", logSecret qId, cmdStr]
|
||||
where
|
||||
server = B.pack $ host <> maybe "" (":" <>) port
|
||||
logServer dir AgentClient {clientId} srv qId cmdStr =
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
|
||||
|
||||
showServer :: SMPServer -> ByteString
|
||||
showServer srv = B.pack $ host srv <> maybe "" (":" <>) (port srv)
|
||||
|
||||
logSecret :: ByteString -> ByteString
|
||||
logSecret bs = encode $ B.take 3 bs
|
||||
|
||||
@@ -15,13 +15,14 @@ import Data.Time.Clock (UTCTime)
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Types
|
||||
|
||||
data ReceiveQueue = ReceiveQueue
|
||||
{ server :: SMPServer,
|
||||
rcvId :: RecipientId,
|
||||
rcvId :: SMP.RecipientId,
|
||||
rcvPrivateKey :: PrivateKey,
|
||||
sndId :: Maybe SenderId,
|
||||
sndId :: Maybe SMP.SenderId,
|
||||
sndKey :: Maybe PublicKey,
|
||||
decryptKey :: PrivateKey,
|
||||
verifyKey :: Maybe PublicKey,
|
||||
@@ -32,7 +33,7 @@ data ReceiveQueue = ReceiveQueue
|
||||
|
||||
data SendQueue = SendQueue
|
||||
{ server :: SMPServer,
|
||||
sndId :: SenderId,
|
||||
sndId :: SMP.SenderId,
|
||||
sndPrivateKey :: PrivateKey,
|
||||
encryptKey :: PublicKey,
|
||||
signKey :: PrivateKey,
|
||||
@@ -98,7 +99,7 @@ class Monad m => MonadAgentStore s m where
|
||||
createRcvConn :: s -> ConnAlias -> ReceiveQueue -> m ()
|
||||
createSndConn :: s -> ConnAlias -> SendQueue -> m ()
|
||||
getConn :: s -> ConnAlias -> m SomeConn
|
||||
getReceiveQueue :: s -> SMPServer -> RecipientId -> m (ConnAlias, ReceiveQueue)
|
||||
getReceiveQueue :: s -> SMPServer -> SMP.RecipientId -> m (ConnAlias, ReceiveQueue)
|
||||
deleteConn :: s -> ConnAlias -> m ()
|
||||
addSndQueue :: s -> ConnAlias -> SendQueue -> m ()
|
||||
addRcvQueue :: s -> ConnAlias -> ReceiveQueue -> m ()
|
||||
|
||||
@@ -25,7 +25,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Util
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Types
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import UnliftIO.STM
|
||||
|
||||
newSQLiteStore :: MonadUnliftIO m => String -> m SQLiteStore
|
||||
@@ -81,7 +81,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
return $ SomeConn SCSend (SendConnection connAlias sndQ)
|
||||
_ -> throwError SEBadConn
|
||||
|
||||
getReceiveQueue :: SQLiteStore -> SMPServer -> RecipientId -> m (ConnAlias, ReceiveQueue)
|
||||
getReceiveQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m (ConnAlias, ReceiveQueue)
|
||||
getReceiveQueue st SMPServer {host, port} recipientId = do
|
||||
rcvQueue <- getRcvQueueByRecipientId st recipientId host port
|
||||
connAlias <- getConnAliasByRcvQueue st recipientId
|
||||
|
||||
@@ -34,7 +34,7 @@ import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util
|
||||
import Text.Read
|
||||
import qualified UnliftIO.Exception as E
|
||||
@@ -335,7 +335,7 @@ updateReceiveQueueStatus store rcvQueueId host port status =
|
||||
|]
|
||||
(Only status :. Only rcvQueueId :. Only host :. Only port)
|
||||
|
||||
updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
|
||||
updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SMP.SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
|
||||
updateSendQueueStatus store sndQueueId host port status =
|
||||
executeWithLock
|
||||
store
|
||||
@@ -357,7 +357,7 @@ instance ToField QueueDirection where toField = toField . show
|
||||
-- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus?
|
||||
insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m ()
|
||||
insertMsg store connAlias qDirection agentMsgId msg = do
|
||||
tstamp <- liftIO getCurrentTime
|
||||
ts <- liftIO getCurrentTime
|
||||
void $
|
||||
insertWithLock
|
||||
store
|
||||
@@ -366,4 +366,4 @@ insertMsg store connAlias qDirection agentMsgId msg = do
|
||||
INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status)
|
||||
VALUES (?,?,?,?,?,"MDTransmitted");
|
||||
|]
|
||||
(Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg)
|
||||
(Only connAlias :. Only agentMsgId :. Only ts :. Only qDirection :. Only msg)
|
||||
|
||||
@@ -30,8 +30,9 @@ import Data.Typeable ()
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Types (CorrId (..), Encoded, ErrorType, MsgBody, PublicKey, SenderId, errBadParameters, errMessageBody)
|
||||
import Simplex.Messaging.Types (CorrId (..), Encoded, ErrorType, MsgBody, PublicKey, errBadParameters, errMessageBody)
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
import Text.Read
|
||||
@@ -73,7 +74,7 @@ data ACommand (p :: AParty) where
|
||||
READY :: ACommand Agent
|
||||
-- CONF :: OtherPartyId -> ACommand Agent
|
||||
-- LET :: OtherPartyId -> ACommand Client
|
||||
SUB :: SubMode -> ACommand Client
|
||||
SUB :: ACommand Client
|
||||
END :: ACommand Agent
|
||||
-- QST :: QueueDirection -> ACommand Client
|
||||
-- STAT :: QueueDirection -> Maybe QueueStatus -> Maybe SubMode -> ACommand Agent
|
||||
@@ -208,9 +209,7 @@ data Mode = On | Off deriving (Eq, Show, Read)
|
||||
|
||||
newtype AckMode = AckMode Mode deriving (Eq, Show)
|
||||
|
||||
newtype SubMode = SubMode Mode deriving (Show)
|
||||
|
||||
data SMPQueueInfo = SMPQueueInfo SMPServer SenderId EncryptionKey
|
||||
data SMPQueueInfo = SMPQueueInfo SMPServer SMP.SenderId EncryptionKey
|
||||
deriving (Show)
|
||||
|
||||
data ReplyMode = ReplyOff | ReplyOn | ReplyVia SMPServer deriving (Show)
|
||||
@@ -281,6 +280,8 @@ parseCommandP =
|
||||
"NEW " *> newCmd
|
||||
<|> "INV " *> invResp
|
||||
<|> "JOIN " *> joinCmd
|
||||
<|> "SUB" $> ACmd SClient SUB
|
||||
<|> "END" $> ACmd SAgent END
|
||||
<|> "SEND " *> sendCmd
|
||||
<|> "MSG " *> message
|
||||
<|> "ACK " *> acknowledge
|
||||
@@ -314,6 +315,8 @@ serializeCommand = \case
|
||||
NEW srv -> "NEW " <> serializeServer srv
|
||||
INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo
|
||||
JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode
|
||||
SUB -> "SUB"
|
||||
END -> "END"
|
||||
SEND msgBody -> "SEND " <> serializeMsg msgBody
|
||||
MSG aMsgId aTs ts st body ->
|
||||
B.unwords ["MSG", B.pack $ show aMsgId, B.pack $ formatISO8601Millis aTs, B.pack $ formatISO8601Millis ts, msgStatus st, serializeMsg body]
|
||||
@@ -349,11 +352,7 @@ tPutRaw h (corrId, connAlias, command) = do
|
||||
putLn h command
|
||||
|
||||
tGetRaw :: Handle -> IO ARawTransmission
|
||||
tGetRaw h = do
|
||||
corrId <- getLn h
|
||||
connAlias <- getLn h
|
||||
command <- getLn h
|
||||
return (corrId, connAlias, command)
|
||||
tGetRaw h = (,,) <$> getLn h <*> getLn h <*> getLn h
|
||||
|
||||
tPut :: MonadIO m => Handle -> ATransmission p -> m ()
|
||||
tPut h (CorrId corrId, connAlias, command) =
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
@@ -34,6 +35,7 @@ import qualified Data.ByteString.Char8 as B
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe
|
||||
import GHC.IO.Exception (IOErrorType (..))
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Transmission (SMPServer (..))
|
||||
@@ -42,6 +44,8 @@ import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
import System.IO.Error
|
||||
import System.Timeout
|
||||
|
||||
data SMPClient = SMPClient
|
||||
{ action :: Async (),
|
||||
@@ -57,11 +61,17 @@ type SMPServerTransmission = (SMPServer, RecipientId, Command 'Broker)
|
||||
|
||||
data SMPClientConfig = SMPClientConfig
|
||||
{ qSize :: Natural,
|
||||
defaultPort :: ServiceName
|
||||
defaultPort :: ServiceName,
|
||||
tcpTimeout :: Int
|
||||
}
|
||||
|
||||
smpDefaultConfig :: SMPClientConfig
|
||||
smpDefaultConfig = SMPClientConfig 16 "5223"
|
||||
smpDefaultConfig =
|
||||
SMPClientConfig
|
||||
{ qSize = 16,
|
||||
defaultPort = "5223",
|
||||
tcpTimeout = 2_000_000
|
||||
}
|
||||
|
||||
data Request = Request
|
||||
{ queueId :: QueueId,
|
||||
@@ -69,49 +79,60 @@ data Request = Request
|
||||
}
|
||||
|
||||
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO SMPClient
|
||||
getSMPClient smpServer@SMPServer {host, port} SMPClientConfig {qSize, defaultPort} msgQ = do
|
||||
c <- atomically mkSMPClient
|
||||
action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c)
|
||||
return c {action}
|
||||
where
|
||||
mkSMPClient :: STM SMPClient
|
||||
mkSMPClient = do
|
||||
clientCorrId <- newTVar 0
|
||||
sentCommands <- newTVar M.empty
|
||||
sndQ <- newTBQueue qSize
|
||||
rcvQ <- newTBQueue qSize
|
||||
return SMPClient {action = undefined, smpServer, clientCorrId, sentCommands, sndQ, rcvQ, msgQ}
|
||||
getSMPClient
|
||||
smpServer@SMPServer {host, port}
|
||||
SMPClientConfig {qSize, defaultPort, tcpTimeout}
|
||||
msgQ = do
|
||||
c <- atomically mkSMPClient
|
||||
started <- newEmptyTMVarIO
|
||||
action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c started)
|
||||
tcpTimeout `timeout` atomically (takeTMVar started) >>= \case
|
||||
Just _ -> return c {action}
|
||||
_ -> throwIO err
|
||||
where
|
||||
err :: IOException
|
||||
err = mkIOError TimeExpired "connection timeout" Nothing Nothing
|
||||
|
||||
client :: SMPClient -> Handle -> IO ()
|
||||
client c h = do
|
||||
_line <- getLn h -- "Welcome to SMP"
|
||||
-- TODO test connection failure
|
||||
raceAny_ [send c h, process c, receive c h]
|
||||
mkSMPClient :: STM SMPClient
|
||||
mkSMPClient = do
|
||||
clientCorrId <- newTVar 0
|
||||
sentCommands <- newTVar M.empty
|
||||
sndQ <- newTBQueue qSize
|
||||
rcvQ <- newTBQueue qSize
|
||||
return SMPClient {action = undefined, smpServer, clientCorrId, sentCommands, sndQ, rcvQ, msgQ}
|
||||
|
||||
send :: SMPClient -> Handle -> IO ()
|
||||
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
client :: SMPClient -> TMVar () -> Handle -> IO ()
|
||||
client c started h = do
|
||||
_ <- getLn h -- "Welcome to SMP"
|
||||
atomically $ putTMVar started ()
|
||||
-- TODO call continuation on disconnection after raceAny_ exits
|
||||
raceAny_ [send c h, process c, receive c h]
|
||||
|
||||
receive :: SMPClient -> Handle -> IO ()
|
||||
receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ
|
||||
send :: SMPClient -> Handle -> IO ()
|
||||
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
|
||||
process :: SMPClient -> IO ()
|
||||
process SMPClient {rcvQ, sentCommands} = forever $ do
|
||||
(_, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ
|
||||
cs <- readTVarIO sentCommands
|
||||
case M.lookup corrId cs of
|
||||
Nothing -> do
|
||||
case respOrErr of
|
||||
Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd)
|
||||
_ -> return ()
|
||||
Just Request {queueId, responseVar} -> atomically $ do
|
||||
modifyTVar sentCommands $ M.delete corrId
|
||||
putTMVar responseVar $
|
||||
if queueId == qId
|
||||
then case respOrErr of
|
||||
Left e -> Left $ SMPResponseError e
|
||||
Right (Cmd _ (ERR e)) -> Left $ SMPServerError e
|
||||
Right r -> Right r
|
||||
else Left SMPQueueIdError
|
||||
receive :: SMPClient -> Handle -> IO ()
|
||||
receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
process :: SMPClient -> IO ()
|
||||
process SMPClient {rcvQ, sentCommands} = forever $ do
|
||||
(_, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ
|
||||
cs <- readTVarIO sentCommands
|
||||
case M.lookup corrId cs of
|
||||
Nothing -> do
|
||||
case respOrErr of
|
||||
Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd)
|
||||
-- TODO send everything else to errQ and log in agent
|
||||
_ -> return ()
|
||||
Just Request {queueId, responseVar} -> atomically $ do
|
||||
modifyTVar sentCommands $ M.delete corrId
|
||||
putTMVar responseVar $
|
||||
if queueId == qId
|
||||
then case respOrErr of
|
||||
Left e -> Left $ SMPResponseError e
|
||||
Right (Cmd _ (ERR e)) -> Left $ SMPServerError e
|
||||
Right r -> Right r
|
||||
else Left SMPQueueIdError
|
||||
|
||||
data SMPClientError
|
||||
= SMPServerError ErrorType
|
||||
|
||||
@@ -20,8 +20,8 @@ import Data.Char (ord)
|
||||
import Data.Kind
|
||||
import Data.Time.Clock
|
||||
import Data.Time.ISO8601
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
import Text.Read
|
||||
@@ -51,6 +51,12 @@ type TransmissionOrError = (Signature, SignedOrError)
|
||||
|
||||
type RawTransmission = (ByteString, ByteString, ByteString, ByteString)
|
||||
|
||||
type RecipientId = QueueId
|
||||
|
||||
type SenderId = QueueId
|
||||
|
||||
type QueueId = Encoded
|
||||
|
||||
data Command (a :: Party) where
|
||||
NEW :: RecipientKey -> Command Recipient
|
||||
SUB :: Command Recipient
|
||||
|
||||
@@ -10,7 +10,6 @@ import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.MsgStore.STM
|
||||
import Simplex.Messaging.Server.QueueStore.STM
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
module Simplex.Messaging.Server.MsgStore where
|
||||
|
||||
import Data.Time.Clock
|
||||
import Simplex.Messaging.Protocol (RecipientId)
|
||||
import Simplex.Messaging.Types
|
||||
|
||||
data Message = Message
|
||||
|
||||
@@ -8,7 +8,7 @@ module Simplex.Messaging.Server.MsgStore.STM where
|
||||
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Protocol (RecipientId)
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import UnliftIO.STM
|
||||
|
||||
|
||||
@@ -86,4 +86,7 @@ putLn :: Handle -> ByteString -> IO ()
|
||||
putLn h = B.hPut h . (<> "\r\n")
|
||||
|
||||
getLn :: Handle -> IO ByteString
|
||||
getLn h = B.pack <$> hGetLine h
|
||||
getLn h = trim_cr <$> B.hGetLine h
|
||||
where
|
||||
trim_cr "" = ""
|
||||
trim_cr s = if B.last s == '\r' then B.init s else s
|
||||
|
||||
@@ -29,12 +29,6 @@ type RecipientKey = PublicKey
|
||||
|
||||
type SenderKey = PublicKey
|
||||
|
||||
type RecipientId = QueueId
|
||||
|
||||
type SenderId = QueueId
|
||||
|
||||
type QueueId = Encoded
|
||||
|
||||
type MsgId = Encoded
|
||||
|
||||
type MsgBody = ByteString
|
||||
|
||||
@@ -85,7 +85,8 @@ cfg =
|
||||
smpCfg =
|
||||
SMPClientConfig
|
||||
{ qSize = 1,
|
||||
defaultPort = testPort
|
||||
defaultPort = testPort,
|
||||
tcpTimeout = 500_000
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -233,7 +233,7 @@ testSwitchSub =
|
||||
Resp "bcda" _ ok3 <- sendRecv rh2 ("1234", "bcda", rId, "ACK")
|
||||
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
|
||||
|
||||
timeout 1000 (tGet fromServer rh1) >>= \case
|
||||
1000 `timeout` tGet fromServer rh1 >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else is delivered to the 1st TCP connection"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user