mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-15 16:05:24 +00:00
test client (WIP)
This commit is contained in:
@@ -29,6 +29,15 @@ executables:
|
||||
source-dirs: src
|
||||
main: Main.hs
|
||||
|
||||
library:
|
||||
source-dirs: src
|
||||
|
||||
tests:
|
||||
smp-server-test:
|
||||
source-dirs: tests
|
||||
main: Test.hs
|
||||
dependencies: simplex-messaging
|
||||
|
||||
ghc-options:
|
||||
- -haddock
|
||||
- -O2
|
||||
|
||||
+3
-113
@@ -1,122 +1,12 @@
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Main where
|
||||
|
||||
-- import Polysemy
|
||||
import ConnStore
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Env.STM
|
||||
import Network.Socket
|
||||
import Text.Read
|
||||
import Transmission
|
||||
import Transport
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM
|
||||
import Network.Socket (ServiceName)
|
||||
import Server (runSMPServer)
|
||||
|
||||
port :: ServiceName
|
||||
port = "5223"
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
env <- atomically $ newEnv port
|
||||
putStrLn $ "Listening on port " ++ port
|
||||
runReaderT (runTCPServer runClient) env
|
||||
|
||||
runTCPServer :: (MonadReader Env m, MonadUnliftIO m) => (Handle -> m ()) -> m ()
|
||||
runTCPServer server =
|
||||
E.bracket startTCPServer (liftIO . close) $ \sock -> forever $ do
|
||||
h <- acceptTCPConn sock
|
||||
putLn h "Welcome"
|
||||
forkFinally (server h) (const $ hClose h)
|
||||
|
||||
runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m ()
|
||||
runClient h = do
|
||||
c <- atomically $ newClient h
|
||||
void $ race (client c) (receive c)
|
||||
|
||||
receive :: (MonadUnliftIO m, MonadReader Env m) => Client -> m ()
|
||||
receive Client {handle, channel} = forever $ do
|
||||
signature <- getLn handle
|
||||
connId <- getLn handle
|
||||
command <- getLn handle
|
||||
cmdOrError <- parseReadVerifyTransmission handle signature connId command
|
||||
atomically $ writeTChan channel cmdOrError
|
||||
|
||||
parseReadVerifyTransmission :: forall m. (MonadUnliftIO m, MonadReader Env m) => Handle -> String -> String -> String -> m SomeSigned
|
||||
parseReadVerifyTransmission h signature connId command = do
|
||||
let cmd = parseCommand command
|
||||
cmd' <- case cmd of
|
||||
Cmd SBroker _ -> return cmd
|
||||
Cmd _ (CREATE _) -> signed False cmd errHasCredentials
|
||||
Cmd _ (SEND msgBody) -> getSendMsgBody msgBody
|
||||
Cmd _ _ -> verifyConnSignature cmd -- signed True cmd errNoCredentials
|
||||
return (Just connId, cmd')
|
||||
where
|
||||
signed :: Bool -> Cmd -> Int -> m Cmd
|
||||
signed isSigned cmd errCode =
|
||||
return
|
||||
if isSigned == (signature /= "") && isSigned == (connId /= "")
|
||||
then cmd
|
||||
else syntaxError errCode
|
||||
getSendMsgBody :: MsgBody -> m Cmd
|
||||
getSendMsgBody msgBody =
|
||||
if connId == ""
|
||||
then return $ syntaxError errNoConnectionId
|
||||
else case B.unpack msgBody of
|
||||
':' : body -> return . smpSend $ B.pack body
|
||||
sizeStr -> case readMaybe sizeStr :: Maybe Int of
|
||||
Just size -> do
|
||||
body <- getBytes h size
|
||||
s <- getLn h
|
||||
return if s == "" then smpSend body else syntaxError errMessageBodySize
|
||||
Nothing -> return $ syntaxError errMessageBody
|
||||
verifyConnSignature :: Cmd -> m Cmd
|
||||
verifyConnSignature cmd@(Cmd party _) =
|
||||
if null signature || null connId
|
||||
then return $ syntaxError errNoCredentials
|
||||
else do
|
||||
store <- asks connStore
|
||||
getConn store party connId >>= \case
|
||||
Right Connection {recipientKey, senderKey} -> do
|
||||
res <- case party of
|
||||
SRecipient -> verifySignature recipientKey
|
||||
SSender -> case senderKey of
|
||||
Just key -> verifySignature key
|
||||
Nothing -> return False
|
||||
SBroker -> return False
|
||||
if res then return cmd else return $ smpError AUTH
|
||||
Left err -> return $ smpError err
|
||||
verifySignature :: Encoded -> m Bool
|
||||
verifySignature key = return $ signature == key
|
||||
|
||||
client :: (MonadUnliftIO m, MonadReader Env m) => Client -> m ()
|
||||
client Client {handle, channel} = loop
|
||||
where
|
||||
loop = forever $ do
|
||||
(_, cmdOrErr) <- atomically $ readTChan channel
|
||||
response <- case cmdOrErr of
|
||||
Cmd SRecipient (CREATE recipientKey) -> do
|
||||
store <- asks connStore
|
||||
conn <- createConn store recipientKey
|
||||
case conn of
|
||||
Right Connection {recipientId, senderId} -> return $ "CONN " ++ recipientId ++ " " ++ senderId
|
||||
Left e -> return $ "ERROR " ++ show e
|
||||
Cmd SRecipient _ -> return "OK"
|
||||
Cmd SSender _ -> return "OK"
|
||||
Cmd SBroker (ERROR e) -> return $ "ERROR " ++ show e
|
||||
_ -> return "ERROR INTERNAL"
|
||||
putLn handle response
|
||||
liftIO $ print cmdOrErr
|
||||
runSMPServer port
|
||||
|
||||
+117
@@ -0,0 +1,117 @@
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Server (runSMPServer) where
|
||||
|
||||
import ConnStore
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Env.STM
|
||||
import Network.Socket
|
||||
import Text.Read
|
||||
import Transmission
|
||||
import Transport
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM
|
||||
|
||||
runSMPServer :: ServiceName -> IO ()
|
||||
runSMPServer port = do
|
||||
env <- atomically $ newEnv port
|
||||
runReaderT (runTCPServer runClient) env
|
||||
|
||||
runTCPServer :: (MonadReader Env m, MonadUnliftIO m) => (Handle -> m ()) -> m ()
|
||||
runTCPServer server =
|
||||
E.bracket startTCPServer (liftIO . close) $ \sock -> forever $ do
|
||||
h <- acceptTCPConn sock
|
||||
putLn h "Welcome to SMP"
|
||||
forkFinally (server h) (const $ hClose h)
|
||||
|
||||
runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m ()
|
||||
runClient h = do
|
||||
c <- atomically $ newClient h
|
||||
void $ race (client c) (receive c)
|
||||
|
||||
receive :: (MonadUnliftIO m, MonadReader Env m) => Client -> m ()
|
||||
receive Client {handle, channel} = forever $ do
|
||||
signature <- getLn handle
|
||||
connId <- getLn handle
|
||||
command <- getLn handle
|
||||
cmdOrError <- parseReadVerifyTransmission handle signature connId command
|
||||
atomically $ writeTChan channel cmdOrError
|
||||
|
||||
parseReadVerifyTransmission :: forall m. (MonadUnliftIO m, MonadReader Env m) => Handle -> String -> String -> String -> m SomeSigned
|
||||
parseReadVerifyTransmission h signature connId command = do
|
||||
let cmd = parseCommand command
|
||||
cmd' <- case cmd of
|
||||
Cmd SBroker _ -> return cmd
|
||||
Cmd _ (CREATE _) -> signed False cmd errHasCredentials
|
||||
Cmd _ (SEND msgBody) -> getSendMsgBody msgBody
|
||||
Cmd _ _ -> verifyConnSignature cmd -- signed True cmd errNoCredentials
|
||||
return (connId, cmd')
|
||||
where
|
||||
signed :: Bool -> Cmd -> Int -> m Cmd
|
||||
signed isSigned cmd errCode =
|
||||
return
|
||||
if isSigned == (signature /= "") && isSigned == (connId /= "")
|
||||
then cmd
|
||||
else syntaxError errCode
|
||||
getSendMsgBody :: MsgBody -> m Cmd
|
||||
getSendMsgBody msgBody =
|
||||
if connId == ""
|
||||
then return $ syntaxError errNoConnectionId
|
||||
else case B.unpack msgBody of
|
||||
':' : body -> return . smpSend $ B.pack body
|
||||
sizeStr -> case readMaybe sizeStr :: Maybe Int of
|
||||
Just size -> do
|
||||
body <- getBytes h size
|
||||
s <- getLn h
|
||||
return if s == "" then smpSend body else syntaxError errMessageBodySize
|
||||
Nothing -> return $ syntaxError errMessageBody
|
||||
verifyConnSignature :: Cmd -> m Cmd
|
||||
verifyConnSignature cmd@(Cmd party _) =
|
||||
if null signature || null connId
|
||||
then return $ syntaxError errNoCredentials
|
||||
else do
|
||||
store <- asks connStore
|
||||
getConn store party connId >>= \case
|
||||
Right Connection {recipientKey, senderKey} -> do
|
||||
res <- case party of
|
||||
SRecipient -> verifySignature recipientKey
|
||||
SSender -> case senderKey of
|
||||
Just key -> verifySignature key
|
||||
Nothing -> return False
|
||||
SBroker -> return False
|
||||
if res then return cmd else return $ smpError AUTH
|
||||
Left err -> return $ smpError err
|
||||
verifySignature :: Encoded -> m Bool
|
||||
verifySignature key = return $ signature == key
|
||||
|
||||
client :: (MonadUnliftIO m, MonadReader Env m) => Client -> m ()
|
||||
client Client {handle, channel} = loop
|
||||
where
|
||||
loop = forever $ do
|
||||
(_, cmdOrErr) <- atomically $ readTChan channel
|
||||
response <- case cmdOrErr of
|
||||
Cmd SRecipient (CREATE recipientKey) -> do
|
||||
store <- asks connStore
|
||||
conn <- createConn store recipientKey
|
||||
case conn of
|
||||
Right Connection {recipientId, senderId} -> return $ "CONN " ++ recipientId ++ " " ++ senderId
|
||||
Left e -> return $ "ERROR " ++ show e
|
||||
Cmd SRecipient _ -> return "OK"
|
||||
Cmd SSender _ -> return "OK"
|
||||
Cmd SBroker (ERROR e) -> return $ "ERROR " ++ show e
|
||||
_ -> return "ERROR INTERNAL"
|
||||
putLn handle response
|
||||
liftIO $ print cmdOrErr
|
||||
+7
-4
@@ -21,16 +21,16 @@ $( singletons
|
||||
|]
|
||||
)
|
||||
|
||||
type Transmission (a :: Party) = (Signed a, Maybe Signature)
|
||||
|
||||
type Signed (a :: Party) = (Maybe ConnId, Command a)
|
||||
type Signed (a :: Party) = (ConnId, Command a)
|
||||
|
||||
data Cmd where
|
||||
Cmd :: Sing a -> Command a -> Cmd
|
||||
|
||||
deriving instance Show Cmd
|
||||
|
||||
type SomeSigned = (Maybe ConnId, Cmd)
|
||||
type SomeSigned = (ConnId, Cmd)
|
||||
|
||||
type Transmission = (Signature, SomeSigned)
|
||||
|
||||
data Command (a :: Party) where
|
||||
CREATE :: RecipientKey -> Command Recipient
|
||||
@@ -68,6 +68,9 @@ parseCommand command = case words command of
|
||||
err = syntaxError errBadParameters
|
||||
rCmd = Cmd SRecipient
|
||||
|
||||
serializeCommand :: Cmd -> String
|
||||
serializeCommand _ = "TODO"
|
||||
|
||||
syntaxError :: Int -> Cmd
|
||||
syntaxError err = smpError $ SYNTAX err
|
||||
|
||||
|
||||
@@ -27,6 +27,10 @@ acceptTCPConn :: MonadIO m => Socket -> m Handle
|
||||
acceptTCPConn sock = liftIO $ do
|
||||
(conn, peer) <- accept sock
|
||||
putStrLn $ "Accepted connection from " ++ show peer
|
||||
getSocketHandle conn
|
||||
|
||||
getSocketHandle :: MonadIO m => Socket -> m Handle
|
||||
getSocketHandle conn = liftIO $ do
|
||||
h <- socketToHandle conn ReadWriteMode
|
||||
hSetBinaryMode h True
|
||||
hSetNewlineMode h universalNewlineMode
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
module SMPClient where
|
||||
|
||||
import Control.Concurrent
|
||||
import qualified Control.Exception as E
|
||||
import Network.Socket
|
||||
import Server
|
||||
import System.IO
|
||||
import Transmission
|
||||
import Transport
|
||||
|
||||
runSMPClient :: HostName -> ServiceName -> (Handle -> IO a) -> IO a
|
||||
runSMPClient host port client = withSocketsDo $ do
|
||||
addr <- resolve
|
||||
E.bracket (open addr) hClose $ \h -> do
|
||||
line <- getLn h
|
||||
if line == "Welcome to SMP"
|
||||
then client h
|
||||
else error "not connected"
|
||||
where
|
||||
resolve = do
|
||||
let hints = defaultHints {addrSocketType = Stream}
|
||||
head <$> getAddrInfo (Just hints) (Just host) (Just port)
|
||||
open addr = do
|
||||
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
|
||||
connect sock $ addrAddress addr
|
||||
getSocketHandle sock
|
||||
|
||||
smpServerTest :: [Transmission] -> IO [Transmission]
|
||||
smpServerTest toSend =
|
||||
E.bracket
|
||||
(forkIO $ runSMPServer "5000")
|
||||
killThread
|
||||
( const $
|
||||
runSMPClient "localhost" "5000" $ \h ->
|
||||
mapM (sendReceive h) toSend
|
||||
)
|
||||
where
|
||||
sendReceive :: Handle -> Transmission -> IO Transmission
|
||||
sendReceive h (signature, (connId, cmd)) = do
|
||||
putLn h signature
|
||||
putLn h connId
|
||||
putLn h $ serializeCommand cmd
|
||||
signature' <- getLn h
|
||||
connId' <- getLn h
|
||||
cmd' <- parseCommand <$> getLn h
|
||||
return (signature', (connId', cmd'))
|
||||
@@ -0,0 +1,2 @@
|
||||
main :: IO ()
|
||||
main = print "hi"
|
||||
Reference in New Issue
Block a user