From e75a3c44dfeb16a8387546d2d8d13236499b6540 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Tue, 13 Oct 2020 12:43:44 +0100 Subject: [PATCH] test client (WIP) --- package.yaml | 9 ++++ src/Main.hs | 116 ++----------------------------------------- src/Server.hs | 117 ++++++++++++++++++++++++++++++++++++++++++++ src/Transmission.hs | 11 +++-- src/Transport.hs | 4 ++ tests/SMPClient.hs | 46 +++++++++++++++++ tests/Test.hs | 2 + 7 files changed, 188 insertions(+), 117 deletions(-) create mode 100644 src/Server.hs create mode 100644 tests/SMPClient.hs create mode 100644 tests/Test.hs diff --git a/package.yaml b/package.yaml index de7cce1e7..0b66f2d39 100644 --- a/package.yaml +++ b/package.yaml @@ -29,6 +29,15 @@ executables: source-dirs: src main: Main.hs +library: + source-dirs: src + +tests: + smp-server-test: + source-dirs: tests + main: Test.hs + dependencies: simplex-messaging + ghc-options: - -haddock - -O2 diff --git a/src/Main.hs b/src/Main.hs index 10c0a1c21..0a7611ebb 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,122 +1,12 @@ -{-# LANGUAGE BlockArguments #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE ScopedTypeVariables #-} - module Main where --- import Polysemy -import ConnStore -import Control.Monad -import Control.Monad.IO.Unlift -import Control.Monad.Reader -import qualified Data.ByteString.Char8 as B -import Env.STM -import Network.Socket -import Text.Read -import Transmission -import Transport -import UnliftIO.Async -import UnliftIO.Concurrent -import qualified UnliftIO.Exception as E -import UnliftIO.IO -import UnliftIO.STM +import Network.Socket (ServiceName) +import Server (runSMPServer) port :: ServiceName port = "5223" main :: IO () main = do - env <- atomically $ newEnv port putStrLn $ "Listening on port " ++ port - runReaderT (runTCPServer runClient) env - -runTCPServer :: (MonadReader Env m, MonadUnliftIO m) => (Handle -> m ()) -> m () -runTCPServer server = - E.bracket startTCPServer (liftIO . close) $ \sock -> forever $ do - h <- acceptTCPConn sock - putLn h "Welcome" - forkFinally (server h) (const $ hClose h) - -runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () -runClient h = do - c <- atomically $ newClient h - void $ race (client c) (receive c) - -receive :: (MonadUnliftIO m, MonadReader Env m) => Client -> m () -receive Client {handle, channel} = forever $ do - signature <- getLn handle - connId <- getLn handle - command <- getLn handle - cmdOrError <- parseReadVerifyTransmission handle signature connId command - atomically $ writeTChan channel cmdOrError - -parseReadVerifyTransmission :: forall m. (MonadUnliftIO m, MonadReader Env m) => Handle -> String -> String -> String -> m SomeSigned -parseReadVerifyTransmission h signature connId command = do - let cmd = parseCommand command - cmd' <- case cmd of - Cmd SBroker _ -> return cmd - Cmd _ (CREATE _) -> signed False cmd errHasCredentials - Cmd _ (SEND msgBody) -> getSendMsgBody msgBody - Cmd _ _ -> verifyConnSignature cmd -- signed True cmd errNoCredentials - return (Just connId, cmd') - where - signed :: Bool -> Cmd -> Int -> m Cmd - signed isSigned cmd errCode = - return - if isSigned == (signature /= "") && isSigned == (connId /= "") - then cmd - else syntaxError errCode - getSendMsgBody :: MsgBody -> m Cmd - getSendMsgBody msgBody = - if connId == "" - then return $ syntaxError errNoConnectionId - else case B.unpack msgBody of - ':' : body -> return . smpSend $ B.pack body - sizeStr -> case readMaybe sizeStr :: Maybe Int of - Just size -> do - body <- getBytes h size - s <- getLn h - return if s == "" then smpSend body else syntaxError errMessageBodySize - Nothing -> return $ syntaxError errMessageBody - verifyConnSignature :: Cmd -> m Cmd - verifyConnSignature cmd@(Cmd party _) = - if null signature || null connId - then return $ syntaxError errNoCredentials - else do - store <- asks connStore - getConn store party connId >>= \case - Right Connection {recipientKey, senderKey} -> do - res <- case party of - SRecipient -> verifySignature recipientKey - SSender -> case senderKey of - Just key -> verifySignature key - Nothing -> return False - SBroker -> return False - if res then return cmd else return $ smpError AUTH - Left err -> return $ smpError err - verifySignature :: Encoded -> m Bool - verifySignature key = return $ signature == key - -client :: (MonadUnliftIO m, MonadReader Env m) => Client -> m () -client Client {handle, channel} = loop - where - loop = forever $ do - (_, cmdOrErr) <- atomically $ readTChan channel - response <- case cmdOrErr of - Cmd SRecipient (CREATE recipientKey) -> do - store <- asks connStore - conn <- createConn store recipientKey - case conn of - Right Connection {recipientId, senderId} -> return $ "CONN " ++ recipientId ++ " " ++ senderId - Left e -> return $ "ERROR " ++ show e - Cmd SRecipient _ -> return "OK" - Cmd SSender _ -> return "OK" - Cmd SBroker (ERROR e) -> return $ "ERROR " ++ show e - _ -> return "ERROR INTERNAL" - putLn handle response - liftIO $ print cmdOrErr + runSMPServer port diff --git a/src/Server.hs b/src/Server.hs new file mode 100644 index 000000000..6db116fe5 --- /dev/null +++ b/src/Server.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Server (runSMPServer) where + +import ConnStore +import Control.Monad +import Control.Monad.IO.Unlift +import Control.Monad.Reader +import qualified Data.ByteString.Char8 as B +import Env.STM +import Network.Socket +import Text.Read +import Transmission +import Transport +import UnliftIO.Async +import UnliftIO.Concurrent +import qualified UnliftIO.Exception as E +import UnliftIO.IO +import UnliftIO.STM + +runSMPServer :: ServiceName -> IO () +runSMPServer port = do + env <- atomically $ newEnv port + runReaderT (runTCPServer runClient) env + +runTCPServer :: (MonadReader Env m, MonadUnliftIO m) => (Handle -> m ()) -> m () +runTCPServer server = + E.bracket startTCPServer (liftIO . close) $ \sock -> forever $ do + h <- acceptTCPConn sock + putLn h "Welcome to SMP" + forkFinally (server h) (const $ hClose h) + +runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () +runClient h = do + c <- atomically $ newClient h + void $ race (client c) (receive c) + +receive :: (MonadUnliftIO m, MonadReader Env m) => Client -> m () +receive Client {handle, channel} = forever $ do + signature <- getLn handle + connId <- getLn handle + command <- getLn handle + cmdOrError <- parseReadVerifyTransmission handle signature connId command + atomically $ writeTChan channel cmdOrError + +parseReadVerifyTransmission :: forall m. (MonadUnliftIO m, MonadReader Env m) => Handle -> String -> String -> String -> m SomeSigned +parseReadVerifyTransmission h signature connId command = do + let cmd = parseCommand command + cmd' <- case cmd of + Cmd SBroker _ -> return cmd + Cmd _ (CREATE _) -> signed False cmd errHasCredentials + Cmd _ (SEND msgBody) -> getSendMsgBody msgBody + Cmd _ _ -> verifyConnSignature cmd -- signed True cmd errNoCredentials + return (connId, cmd') + where + signed :: Bool -> Cmd -> Int -> m Cmd + signed isSigned cmd errCode = + return + if isSigned == (signature /= "") && isSigned == (connId /= "") + then cmd + else syntaxError errCode + getSendMsgBody :: MsgBody -> m Cmd + getSendMsgBody msgBody = + if connId == "" + then return $ syntaxError errNoConnectionId + else case B.unpack msgBody of + ':' : body -> return . smpSend $ B.pack body + sizeStr -> case readMaybe sizeStr :: Maybe Int of + Just size -> do + body <- getBytes h size + s <- getLn h + return if s == "" then smpSend body else syntaxError errMessageBodySize + Nothing -> return $ syntaxError errMessageBody + verifyConnSignature :: Cmd -> m Cmd + verifyConnSignature cmd@(Cmd party _) = + if null signature || null connId + then return $ syntaxError errNoCredentials + else do + store <- asks connStore + getConn store party connId >>= \case + Right Connection {recipientKey, senderKey} -> do + res <- case party of + SRecipient -> verifySignature recipientKey + SSender -> case senderKey of + Just key -> verifySignature key + Nothing -> return False + SBroker -> return False + if res then return cmd else return $ smpError AUTH + Left err -> return $ smpError err + verifySignature :: Encoded -> m Bool + verifySignature key = return $ signature == key + +client :: (MonadUnliftIO m, MonadReader Env m) => Client -> m () +client Client {handle, channel} = loop + where + loop = forever $ do + (_, cmdOrErr) <- atomically $ readTChan channel + response <- case cmdOrErr of + Cmd SRecipient (CREATE recipientKey) -> do + store <- asks connStore + conn <- createConn store recipientKey + case conn of + Right Connection {recipientId, senderId} -> return $ "CONN " ++ recipientId ++ " " ++ senderId + Left e -> return $ "ERROR " ++ show e + Cmd SRecipient _ -> return "OK" + Cmd SSender _ -> return "OK" + Cmd SBroker (ERROR e) -> return $ "ERROR " ++ show e + _ -> return "ERROR INTERNAL" + putLn handle response + liftIO $ print cmdOrErr diff --git a/src/Transmission.hs b/src/Transmission.hs index 1e9fa925c..72986136f 100644 --- a/src/Transmission.hs +++ b/src/Transmission.hs @@ -21,16 +21,16 @@ $( singletons |] ) -type Transmission (a :: Party) = (Signed a, Maybe Signature) - -type Signed (a :: Party) = (Maybe ConnId, Command a) +type Signed (a :: Party) = (ConnId, Command a) data Cmd where Cmd :: Sing a -> Command a -> Cmd deriving instance Show Cmd -type SomeSigned = (Maybe ConnId, Cmd) +type SomeSigned = (ConnId, Cmd) + +type Transmission = (Signature, SomeSigned) data Command (a :: Party) where CREATE :: RecipientKey -> Command Recipient @@ -68,6 +68,9 @@ parseCommand command = case words command of err = syntaxError errBadParameters rCmd = Cmd SRecipient +serializeCommand :: Cmd -> String +serializeCommand _ = "TODO" + syntaxError :: Int -> Cmd syntaxError err = smpError $ SYNTAX err diff --git a/src/Transport.hs b/src/Transport.hs index c9230cb25..e2749b1ce 100644 --- a/src/Transport.hs +++ b/src/Transport.hs @@ -27,6 +27,10 @@ acceptTCPConn :: MonadIO m => Socket -> m Handle acceptTCPConn sock = liftIO $ do (conn, peer) <- accept sock putStrLn $ "Accepted connection from " ++ show peer + getSocketHandle conn + +getSocketHandle :: MonadIO m => Socket -> m Handle +getSocketHandle conn = liftIO $ do h <- socketToHandle conn ReadWriteMode hSetBinaryMode h True hSetNewlineMode h universalNewlineMode diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs new file mode 100644 index 000000000..82b02eb9d --- /dev/null +++ b/tests/SMPClient.hs @@ -0,0 +1,46 @@ +module SMPClient where + +import Control.Concurrent +import qualified Control.Exception as E +import Network.Socket +import Server +import System.IO +import Transmission +import Transport + +runSMPClient :: HostName -> ServiceName -> (Handle -> IO a) -> IO a +runSMPClient host port client = withSocketsDo $ do + addr <- resolve + E.bracket (open addr) hClose $ \h -> do + line <- getLn h + if line == "Welcome to SMP" + then client h + else error "not connected" + where + resolve = do + let hints = defaultHints {addrSocketType = Stream} + head <$> getAddrInfo (Just hints) (Just host) (Just port) + open addr = do + sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + connect sock $ addrAddress addr + getSocketHandle sock + +smpServerTest :: [Transmission] -> IO [Transmission] +smpServerTest toSend = + E.bracket + (forkIO $ runSMPServer "5000") + killThread + ( const $ + runSMPClient "localhost" "5000" $ \h -> + mapM (sendReceive h) toSend + ) + where + sendReceive :: Handle -> Transmission -> IO Transmission + sendReceive h (signature, (connId, cmd)) = do + putLn h signature + putLn h connId + putLn h $ serializeCommand cmd + signature' <- getLn h + connId' <- getLn h + cmd' <- parseCommand <$> getLn h + return (signature', (connId', cmd')) diff --git a/tests/Test.hs b/tests/Test.hs new file mode 100644 index 000000000..09048bdb7 --- /dev/null +++ b/tests/Test.hs @@ -0,0 +1,2 @@ +main :: IO () +main = print "hi"