SMP client library (#9)

* functions to send SMP commands and receive responses

* refactor agent: use SMPClient

* fix tests, remove ServerClient.hs

* refactor processCommand

* fix Agent.hs

* fix SMPClient, tests

* "forever" to SMPClient process
This commit is contained in:
Evgeny Poberezkin
2021-01-13 19:32:21 +00:00
committed by Efim Poberezkin
parent b02bd42f84
commit 2e6ba85308
11 changed files with 224 additions and 231 deletions
+2 -7
View File
@@ -4,7 +4,7 @@ module Main where
import Simplex.Messaging.Agent (runSMPAgent)
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.ServerClient
import Simplex.Messaging.Client (smpDefaultConfig)
cfg :: AgentConfig
cfg =
@@ -13,12 +13,7 @@ cfg =
tbqSize = 16,
connIdBytes = 12,
dbFile = "smp-agent.db",
smpTcpPort = "5223",
smpConfig =
ServerClientConfig
{ tbqSize = 16,
corrIdBytes = 4
}
smpCfg = smpDefaultConfig
}
main :: IO ()
+1
View File
@@ -12,6 +12,7 @@ extra-source-files:
- README.md
dependencies:
- async == 2.2.*
- base >= 4.7 && < 5
- base64-bytestring >= 1.0 && < 1.3
- bytestring == 0.10.*
+70 -127
View File
@@ -15,16 +15,14 @@ import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
import qualified Data.Map as M
import Data.Maybe
import Network.Socket
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.ServerClient (ServerClient (..), newServerClient)
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite
import Simplex.Messaging.Agent.Store.Types
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Client
import Simplex.Messaging.Server (randomBytes)
import Simplex.Messaging.Server.Transmission (Cmd (..), CorrId (..), PublicKey, SParty (..))
import Simplex.Messaging.Server.Transmission (PublicKey)
import qualified Simplex.Messaging.Server.Transmission as SMP
import Simplex.Messaging.Transport
import UnliftIO.Async
@@ -87,61 +85,43 @@ processCommand ::
AgentClient ->
ATransmission 'Client ->
m ()
processCommand AgentClient {respQ, servers, commands} t@(_, connAlias, cmd) =
processCommand AgentClient {sndQ, smpClients} (corrId, connAlias, cmd) =
case cmd of
NEW smpServer -> do
srv <- getSMPServer smpServer
smpT <- mkSmpNEW smpServer
atomically $ writeTBQueue (smpSndQ srv) smpT
return ()
JOIN (SMPQueueInfo smpServer senderId encKey) _ -> do
srv <- getSMPServer smpServer
smpT <- mkConfSEND smpServer senderId encKey
atomically $ writeTBQueue (smpSndQ srv) smpT
return ()
NEW smpServer -> createNewConnection smpServer
JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode
_ -> throwError PROHIBITED
where
replyError :: ErrorType -> SomeException -> m a
replyError err e = do
liftIO . putStrLn $ "Exception: " ++ show e -- TODO remove
throwError err
getSMPServer :: SMPServer -> m ServerClient
getSMPServer SMPServer {host, port} = do
defPort <- asks $ smpTcpPort . config
let p = fromMaybe defPort port
atomically (M.lookup (host, p) <$> readTVar servers)
>>= maybe (newSMPServer host p) return
newSMPServer :: HostName -> ServiceName -> m ServerClient
newSMPServer host port = do
cfg <- asks $ smpConfig . config
-- store <- asks db
-- _serverId <- withStore (addServer store s) `E.catch` replyError INTERNAL
srv <- newServerClient cfg respQ host port `E.catch` replyError (BROKER smpErrTCPConnection)
atomically . modifyTVar servers $ M.insert (host, port) srv
return srv
mkSmpNEW :: SMPServer -> m SMP.Transmission
mkSmpNEW smpServer = do
createNewConnection :: SMPServer -> m ()
createNewConnection smpServer = do
c <- getSMPServerClient smpServer
g <- asks idsDrg
smpCorrId <- atomically $ CorrId <$> randomBytes 4 g
recipientKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
let rcvPrivateKey = recipientKey
toSMP = ("", (smpCorrId, "", Cmd SRecipient $ SMP.NEW recipientKey))
req =
Request
{ fromClient = t,
toSMP,
state = NEWRequestState {connAlias, smpServer, rcvPrivateKey}
}
atomically . modifyTVar commands $ M.insert smpCorrId req -- TODO check ID collision
return toSMP
(recipientId, senderId) <-
liftIO (createSMPQueue c recipientKey)
`E.catch` smpClientError
`E.catch` replyError INTERNAL
encryptKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
let decryptKey = encryptKey
withStore $ \st ->
createRcvConn st connAlias $
ReceiveQueue
{ server = smpServer,
rcvId = recipientId,
rcvPrivateKey,
sndId = Just senderId,
sndKey = Nothing,
decryptKey,
verifyKey = Nothing,
status = New,
ackMode = AckMode On
}
respond . INV $ SMPQueueInfo smpServer senderId encryptKey
mkConfSEND :: SMPServer -> SMP.SenderId -> PublicKey -> m SMP.Transmission
mkConfSEND smpServer senderId encryptKey = do
joinConnection :: SMPQueueInfo -> ReplyMode -> m ()
joinConnection (SMPQueueInfo smpServer senderId encryptKey) _ = do
c <- getSMPServerClient smpServer
g <- asks idsDrg
smpCorrId <- atomically $ CorrId <$> randomBytes 4 g
senderKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
verifyKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
-- TODO create connection with NEW status, it will be upgraded to CONFIRMED status once SMP server replies OK to SEND
@@ -160,15 +140,38 @@ processCommand AgentClient {respQ, servers, commands} t@(_, connAlias, cmd) =
status = New,
ackMode = AckMode On
}
let toSMP = ("", (smpCorrId, senderId, Cmd SSender $ SMP.SEND msg))
req =
Request
{ fromClient = t,
toSMP,
state = ConfSENDRequestState {connAlias, smpServer, senderId, sndPrivateKey, encryptKey}
}
atomically . modifyTVar commands $ M.insert smpCorrId req -- TODO check ID collision
return toSMP
liftIO (sendSMPMessage c "" senderId msg)
`E.catch` smpClientError
-- `E.catch` replyError INTERNAL
-- TODO the problem here is that while the intention of the 2nd catch was to catch
-- all other exceptions, because smpClientError "throwError" via left channel
-- and of how ExceptT instance of UnliftIO is implemented, the second `catch` catches
-- Left channel... The only solution is to use runtime exceptions and not ExceptT
withStore $ \st -> updateQueueStatus st connAlias SND Confirmed
respond OK
smpClientError :: SMPClientError -> m a
smpClientError = \case
SMPServerError e -> throwError $ SMP e
_ -> throwError INTERNAL
-- TODO
replyError :: ErrorType -> SomeException -> m a
replyError err e = do
liftIO . putStrLn $ "Exception: " ++ show e -- TODO remove
throwError err
getSMPServerClient :: SMPServer -> m SMPClient
getSMPServerClient srv =
atomically (M.lookup srv <$> readTVar smpClients)
>>= maybe newSMPClient return
where
newSMPClient :: m SMPClient
newSMPClient = do
cfg <- asks $ smpCfg . config
c <- liftIO (getSMPClient srv cfg) `E.catch` replyError (BROKER smpErrTCPConnection)
atomically . modifyTVar smpClients $ M.insert srv c
return c
mkConfirmation :: PublicKey -> PublicKey -> m SMP.MsgBody
mkConfirmation _encKey senderKey = do
@@ -176,71 +179,11 @@ processCommand AgentClient {respQ, servers, commands} t@(_, connAlias, cmd) =
-- TODO encryption
return msg
processSmp :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
processSmp c@AgentClient {respQ, sndQ, commands} = forever $ do
(_, (smpCorrId, qId, cmdOrErr)) <- atomically $ readTBQueue respQ
liftIO $ putStrLn "received from server" -- TODO remove
liftIO $ print (smpCorrId, qId, cmdOrErr)
req <- atomically $ M.lookup smpCorrId <$> readTVar commands
case req of -- TODO empty correlation ID is ok - it can be a message
Nothing -> atomically $ writeTBQueue sndQ ("", "", ERR $ BROKER smpErrCorrelationId)
Just r@Request {fromClient = (corrId, cAlias, _)} ->
-- TODO remove matched correlation ID
runExceptT (processResponse c r cmdOrErr) >>= \case
Left e -> atomically $ writeTBQueue sndQ (corrId, cAlias, ERR e)
Right _ -> return ()
respond :: ACommand 'Agent -> m ()
respond c = atomically $ writeTBQueue sndQ (corrId, connAlias, c)
processResponse ::
forall m.
(MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m) =>
AgentClient ->
Request ->
Either SMP.ErrorType SMP.Cmd ->
m ()
processResponse
AgentClient {sndQ}
Request {fromClient = (corrId, cAlias, cmd), toSMP = (_, (_, _, smpCmd)), state}
cmdOrErr = do
case cmdOrErr of
Left e -> throwError $ SMP e
Right resp -> case resp of
Cmd SBroker (SMP.IDS recipientId senderId) -> case smpCmd of
Cmd SRecipient (SMP.NEW _) -> case (cmd, state) of
(NEW _, NEWRequestState {connAlias, smpServer, rcvPrivateKey}) -> do
g <- asks idsDrg
encryptKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
let decryptKey = encryptKey
withStore $ \st ->
createRcvConn st connAlias $
ReceiveQueue
{ server = smpServer,
rcvId = recipientId,
rcvPrivateKey,
sndId = Just senderId,
sndKey = Nothing,
decryptKey,
verifyKey = Nothing,
status = New,
ackMode = AckMode On
}
respond . INV $ SMPQueueInfo smpServer senderId encryptKey
_ -> throwError INTERNAL
_ -> throwError $ BROKER smpUnexpectedResponse
Cmd SBroker (SMP.OK) -> case smpCmd of
Cmd SSender (SMP.SEND _) -> case (cmd, state) of
(JOIN _ _, ConfSENDRequestState {connAlias}) -> do
withStore $ \st -> updateQueueStatus st connAlias SND Confirmed
respond OK
_ -> throwError INTERNAL
_ -> throwError $ BROKER smpUnexpectedResponse
Cmd SBroker (SMP.ERR e) -> case smpCmd of
Cmd SSender (SMP.SEND _) -> case (cmd, state) of
(JOIN _ _, ConfSENDRequestState {connAlias}) -> do
withStore $ \st -> deleteConn st connAlias
respond . ERR $ SMP e
_ -> throwError INTERNAL
_ -> throwError $ BROKER smpUnexpectedResponse
_ -> throwError UNSUPPORTED
where
respond :: ACommand 'Agent -> m ()
respond c = atomically $ writeTBQueue sndQ (corrId, cAlias, c)
processSmp :: MonadUnliftIO m => AgentClient -> m ()
processSmp AgentClient {respQ} = forever $ do
-- TODO this will only process messages and notifications
(_, (_smpCorrId, _qId, _cmdOrErr)) <- atomically $ readTBQueue respQ
return ()
+7 -31
View File
@@ -9,13 +9,11 @@ import Control.Monad.IO.Unlift
import Crypto.Random
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Network.Socket (HostName, ServiceName)
import Network.Socket
import Numeric.Natural
import Simplex.Messaging.Agent.ServerClient
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Server.Transmission (PublicKey, SenderId)
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Server.Transmission as SMP
import UnliftIO.STM
@@ -24,8 +22,7 @@ data AgentConfig = AgentConfig
tbqSize :: Natural,
connIdBytes :: Int,
dbFile :: String,
smpTcpPort :: ServiceName,
smpConfig :: ServerClientConfig
smpCfg :: SMPClientConfig
}
data Env = Env
@@ -37,39 +34,18 @@ data Env = Env
data AgentClient = AgentClient
{ rcvQ :: TBQueue (ATransmission Client),
sndQ :: TBQueue (ATransmission Agent),
-- TODO rename, respQ is only for messages and notifications, not for responses
respQ :: TBQueue SMP.TransmissionOrError,
servers :: TVar (Map (HostName, ServiceName) ServerClient),
commands :: TVar (Map SMP.CorrId Request)
smpClients :: TVar (Map SMPServer SMPClient)
}
data Request = Request
{ fromClient :: ATransmission Client,
toSMP :: SMP.Transmission,
state :: RequestState
}
data RequestState
= NEWRequestState
{ connAlias :: ConnAlias,
smpServer :: SMPServer,
rcvPrivateKey :: PrivateKey
}
| ConfSENDRequestState
{ connAlias :: ConnAlias,
smpServer :: SMPServer,
senderId :: SenderId,
sndPrivateKey :: PrivateKey,
encryptKey :: PublicKey
}
newAgentClient :: Natural -> STM AgentClient
newAgentClient qSize = do
rcvQ <- newTBQueue qSize
sndQ <- newTBQueue qSize
respQ <- newTBQueue qSize
servers <- newTVar M.empty
commands <- newTVar M.empty
return AgentClient {rcvQ, sndQ, respQ, servers, commands}
smpClients <- newTVar M.empty
return AgentClient {rcvQ, sndQ, respQ, smpClients}
newEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env
newEnv config = do
@@ -1,55 +0,0 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Agent.ServerClient where
import Control.Monad
import Control.Monad.IO.Unlift
import Network.Socket (HostName, ServiceName)
import Numeric.Natural
import qualified Simplex.Messaging.Server.Transmission as SMP
import Simplex.Messaging.Transport
import UnliftIO.Async
import UnliftIO.IO
import UnliftIO.STM
data ServerClientConfig = ServerClientConfig
{ tbqSize :: Natural,
corrIdBytes :: Natural
}
data ServerClient = ServerClient
{ smpSndQ :: TBQueue SMP.Transmission,
smpRcvQ :: TBQueue SMP.TransmissionOrError
-- srvA :: Async ()
}
newServerClient ::
forall m.
MonadUnliftIO m =>
ServerClientConfig ->
TBQueue SMP.TransmissionOrError ->
HostName ->
ServiceName ->
m ServerClient
newServerClient cfg smpRcvQ host port = do
smpSndQ <- atomically . newTBQueue $ tbqSize cfg
let c = ServerClient {smpSndQ, smpRcvQ}
_srvA <- async $ runTCPClient host port (client c)
-- TODO because exception can be thrown inside async it is not caught by newSMPServer
-- there possibly needs to be another channel to communicate with ServerClient if it fails
-- alternatively, there may be just timeout on sent commands -
-- in this case late responses should be just ignored rather than result in smpErrCorrelationId
return c
where
client :: ServerClient -> Handle -> m ()
client c h = do
_line <- getLn h -- "Welcome to SMP"
-- TODO test connection failure
send c h `race_` receive h
send :: ServerClient -> Handle -> m ()
send ServerClient {smpSndQ} h = forever $ atomically (readTBQueue smpSndQ) >>= SMP.tPut h
receive :: Handle -> m ()
receive h = forever $ SMP.tGet SMP.fromServer h >>= atomically . writeTBQueue smpRcvQ
+1 -3
View File
@@ -15,7 +15,7 @@ import Data.Time.Clock (UTCTime)
import Data.Type.Equality
import Simplex.Messaging.Agent.Store.Types
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Server.Transmission (Encoded, PublicKey, QueueId)
import Simplex.Messaging.Server.Transmission (PrivateKey, PublicKey, QueueId)
data ReceiveQueue = ReceiveQueue
{ server :: SMPServer,
@@ -85,8 +85,6 @@ data MessageDelivery = MessageDelivery
msgStatus :: DeliveryStatus
}
type PrivateKey = Encoded
data DeliveryStatus
= MDTransmitted -- SMP: SEND sent / MSG received
| MDConfirmed -- SMP: OK received / ACK sent
+1 -1
View File
@@ -104,7 +104,7 @@ data SMPServer = SMPServer
port :: Maybe ServiceName,
keyHash :: Maybe KeyHash
}
deriving (Eq, Show)
deriving (Eq, Ord, Show)
type KeyHash = Encoded
+134
View File
@@ -0,0 +1,134 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Client
( SMPClient,
getSMPClient,
createSMPQueue,
sendSMPMessage,
sendSMPCommand,
SMPClientError (..),
SMPClientConfig (..),
smpDefaultConfig,
)
where
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import qualified Data.ByteString.Char8 as B
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe
import Network.Socket (ServiceName)
import Numeric.Natural
import Simplex.Messaging.Agent.Transmission (SMPServer (..))
import Simplex.Messaging.Server.Transmission
import Simplex.Messaging.Transport
import Simplex.Messaging.Util
import System.IO
data SMPClient = SMPClient
{ action :: Async (),
clientCorrId :: TVar Natural,
sentCommands :: TVar (Map CorrId Request),
sndQ :: TBQueue Transmission,
rcvQ :: TBQueue TransmissionOrError
}
data SMPClientConfig = SMPClientConfig
{ qSize :: Natural,
defaultPort :: ServiceName
}
smpDefaultConfig :: SMPClientConfig
smpDefaultConfig = SMPClientConfig 16 "5223"
data Request = Request
{ queueId :: QueueId,
responseVar :: TMVar (Either SMPClientError Cmd)
}
getSMPClient :: SMPServer -> SMPClientConfig -> IO SMPClient
getSMPClient SMPServer {host, port} SMPClientConfig {qSize, defaultPort} = do
c <-
atomically $
SMPClient undefined <$> newTVar 0 <*> newTVar M.empty <*> newTBQueue qSize <*> newTBQueue qSize
action <- async $ runTCPClient host (fromMaybe defaultPort port) (client c)
return c {action}
where
client :: SMPClient -> Handle -> IO ()
client c h = do
_line <- getLn h -- "Welcome to SMP"
-- TODO test connection failure
raceAny_ [send c h, process c, receive c h]
send :: SMPClient -> Handle -> IO ()
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
receive :: SMPClient -> Handle -> IO ()
receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ
process :: SMPClient -> IO ()
process SMPClient {rcvQ, sentCommands} = forever . atomically $ do
(_, (corrId, qId, respOrErr)) <- readTBQueue rcvQ
cs <- readTVar sentCommands
case M.lookup corrId cs of
Nothing -> return () -- TODO send to message channel or error channel
Just Request {queueId, responseVar} -> do
modifyTVar sentCommands $ M.delete corrId
putTMVar responseVar $
if queueId == qId
then either (Left . SMPResponseError) Right respOrErr
else Left SMPQueueIdError
data SMPClientError
= SMPServerError ErrorType
| SMPResponseError ErrorType
| SMPQueueIdError
| SMPUnexpectedResponse
| SMPResponseTimeout
| SMPClientError
deriving (Eq, Show, Exception)
createSMPQueue :: SMPClient -> RecipientKey -> IO (RecipientId, SenderId)
createSMPQueue c rKey = do
sendSMPCommand c "" "" (Cmd SRecipient $ NEW rKey) >>= \case
Cmd _ (IDS rId sId) -> return (rId, sId)
Cmd _ (ERR e) -> throwIO $ SMPServerError e
_ -> throwIO SMPUnexpectedResponse
sendSMPMessage :: SMPClient -> SenderKey -> QueueId -> MsgBody -> IO ()
sendSMPMessage c sKey qId msg = do
sendSMPCommand c sKey qId (Cmd SSender $ SEND msg) >>= \case
Cmd _ OK -> return ()
Cmd _ (ERR e) -> throwIO $ SMPServerError e
_ -> throwIO SMPUnexpectedResponse
sendSMPCommand :: SMPClient -> PrivateKey -> QueueId -> Cmd -> IO Cmd
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do
corrId <- atomically getNextCorrId
t <- signTransmission (corrId, qId, cmd)
atomically (send corrId t) >>= atomically . takeTMVar >>= either throwIO return
where
getNextCorrId :: STM CorrId
getNextCorrId = do
i <- (+ 1) <$> readTVar clientCorrId
writeTVar clientCorrId i
return . CorrId . B.pack $ show i
-- TODO this is a stub - to replace with cryptographic signature
signTransmission :: Signed -> IO Transmission
signTransmission signed = return (pKey, signed)
send :: CorrId -> Transmission -> STM (TMVar (Either SMPClientError Cmd))
send corrId t = do
r <- newEmptyTMVar
modifyTVar sentCommands . M.insert corrId $ Request qId r
writeTBQueue sndQ t
return r
@@ -153,6 +153,8 @@ instance IsString CorrId where
type PublicKey = Encoded
type PrivateKey = Encoded
type Signature = Encoded
type RecipientKey = PublicKey
+1 -1
View File
@@ -14,7 +14,7 @@ import Test.Hspec
agentTests :: Spec
agentTests = do
describe "SQLite store" storeTests
fdescribe "SMP agent protocol syntax" syntaxTests
describe "SMP agent protocol syntax" syntaxTests
(>#>) :: ARawTransmission -> ARawTransmission -> Expectation
command >#> response = smpAgentTest command `shouldReturn` response
+5 -6
View File
@@ -10,8 +10,8 @@ import Network.Socket
import SMPClient (testPort, withSmpServer)
import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.ServerClient
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Client (SMPClientConfig (..))
import Simplex.Messaging.Transport
import UnliftIO.Concurrent
import UnliftIO.Directory
@@ -40,11 +40,10 @@ cfg =
tbqSize = 1,
connIdBytes = 12,
dbFile = testDB,
smpTcpPort = testPort,
smpConfig =
ServerClientConfig
{ tbqSize = 1,
corrIdBytes = 4
smpCfg =
SMPClientConfig
{ qSize = 1,
defaultPort = testPort
}
}