diff --git a/package.yaml b/package.yaml index 68e68ac4c..a469396cb 100644 --- a/package.yaml +++ b/package.yaml @@ -13,10 +13,26 @@ extra-source-files: dependencies: - base >= 4.7 && < 5 + - async - bytestring + - containers - network + - polysemy + - singletons + - stm executables: simplex-messaging: source-dirs: src main: Main.hs + +ghc-options: + - -haddock + - -O2 + - -Wall + - -Wcompat + - -Werror=incomplete-patterns + - -Wredundant-constraints + - -Wincomplete-record-updates + - -Wincomplete-uni-patterns + - -Wunused-type-patterns diff --git a/src/EnvStm.hs b/src/EnvStm.hs new file mode 100644 index 000000000..5e5ef7012 --- /dev/null +++ b/src/EnvStm.hs @@ -0,0 +1,37 @@ +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NamedFieldPuns #-} + +module EnvStm where + +import Control.Concurrent.STM +import qualified Data.Map as M +import qualified Data.Set as S +import Store +import System.IO +import Transmission + +data Env = Env + { port :: String, + server :: TVar Server, + connStore :: TVar ConnStoreData + } + +data Server = Server + { clients :: S.Set Client, + connections :: M.Map RecipientId Client + } + +data Client = Client + { handle :: Handle, + connections :: S.Set RecipientId, + channel :: TChan SomeSigned + } + +newServer :: STM (TVar Server) +newServer = newTVar $ Server {clients = S.empty, connections = M.empty} + +newEnv :: String -> STM Env +newEnv port = do + srv <- newServer + st <- newConnStore + return Env {port, server = srv, connStore = st} diff --git a/src/Main.hs b/src/Main.hs index ae9e7228c..232e3764f 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,45 +1,102 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + module Main where import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM import qualified Control.Exception as E import Control.Monad +import Data.Function ((&)) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Set (Set) +import qualified Data.Set as S +import EnvStm import Network.Socket +import Polysemy +import Polysemy.Embed +import Polysemy.Resource +import Store import System.IO +import Transmission +import Transport +newClient :: Handle -> IO Client +newClient h = do + c <- newTChanIO @SomeSigned + return Client {handle = h, connections = S.empty, channel = c} + +main :: IO () main = do - putStrLn $ "Listening on port " ++ port - runTCPServer Nothing port talk + server <- atomically newServer + putStrLn $ "Listening on port " ++ port' + runTCPServer port' $ runClient server -port :: String -port = "5223" +port' :: String +port' = "5223" -runTCPServer :: Maybe HostName -> ServiceName -> (Handle -> IO ()) -> IO () -runTCPServer mhost port server = withSocketsDo $ do - let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream} - addr : _ <- getAddrInfo (Just hints) mhost (Just port) - E.bracket (open addr) close loop +runTCPServer :: ServiceName -> (Handle -> IO ()) -> IO () +runTCPServer port server = + E.bracket (startTCPServer port) close $ \sock -> forever $ do + h <- acceptTCPConn sock + hPutStrLn h "Welcome\r" + forkFinally (server h) (const $ hClose h) + +runClient :: TVar Server -> Handle -> IO () +runClient server h = do + c <- newClient h + void $ race (client server c) (receive c) + +receive :: Client -> IO () +receive Client {handle, channel} = forever $ do + signature <- hGetLine handle + connId <- hGetLine handle + command <- hGetLine handle + cmdOrError <- parseVerifyTransmission signature connId command + atomically $ writeTChan channel cmdOrError + +parseVerifyTransmission :: String -> String -> String -> IO SomeSigned +parseVerifyTransmission _ connId command = do + return (Just connId, parseCommand command) + +parseCommand :: String -> SomeCom +parseCommand command = case words command of + ["CREATE", recipientKey] -> rCmd $ CREATE recipientKey + ["SUB"] -> rCmd SUB + ["SECURE", senderKey] -> rCmd $ SECURE senderKey + ["DELMSG", msgId] -> rCmd $ DELMSG msgId + ["SUSPEND"] -> rCmd SUSPEND + ["DELETE"] -> rCmd DELETE + ["SEND", msgBody] -> SomeCom SSender $ SEND msgBody + "CREATE" : _ -> error SYNTAX + "SUB" : _ -> error SYNTAX + "SECURE" : _ -> error SYNTAX + "DELMSG" : _ -> error SYNTAX + "SUSPEND" : _ -> error SYNTAX + "DELETE" : _ -> error SYNTAX + "SEND" : _ -> error SYNTAX + _ -> error CMD where - open addr = do - sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) - setSocketOption sock ReuseAddr 1 - withFdSocket sock $ setCloseOnExecIfNeeded - bind sock $ addrAddress addr - listen sock 1024 - return sock - loop sock = forever $ do - (conn, peer) <- accept sock - putStrLn $ "Accepted connection from " ++ show peer - h <- socketToHandle conn ReadWriteMode - hSetBinaryMode h True - hSetBuffering h LineBuffering - hPutStrLn h "Welcome\r" - forkFinally (server h) (const $ hClose h) + rCmd = SomeCom SRecipient + error t = SomeCom SBroker $ ERROR t -talk :: Handle -> IO () -talk h = do - line <- hGetLine h - if line == "end" - then hPutStrLn h "Bye\r" - else do - hPutStrLn h (show (2 * (read line :: Integer)) ++ "\r") - talk h +client :: TVar Server -> Client -> IO () +client server Client {handle, channel} = loop + where + loop = do + (_, cmdOrErr) <- atomically $ readTChan channel + let response = case cmdOrErr of + SomeCom SRecipient _ -> "OK" + SomeCom SSender _ -> "OK" + SomeCom SBroker (ERROR t) -> "ERROR " ++ show t + _ -> "ERROR INTERNAL" + hPutStrLn handle response + loop diff --git a/src/Store.hs b/src/Store.hs new file mode 100644 index 000000000..52d6af751 --- /dev/null +++ b/src/Store.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +-- {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} + +module Store where + +import Control.Concurrent.STM +import Data.Map (Map) +import qualified Data.Map as M +import Polysemy +import Polysemy.Input +import Transmission + +type SMPResult a = Either SMPError a + +data SMPError = CmdError | SyntaxError | AuthError | InternalError + +data Connection = Connection + { recipientId :: ConnId, + recipientKey :: PublicKey, + senderId :: ConnId, + senderKey :: Maybe PublicKey, + active :: Bool + } + +data ConnStore m a where + CreateConn :: RecipientKey -> ConnStore m (SMPResult Connection) + GetConn :: Party -> ConnId -> ConnStore m (SMPResult Connection) + +-- SecureConn :: RecipientId -> SenderKey -> ConnStore m (SMPResult ()) +-- SuspendConn :: RecipientId -> ConnStore m (SMPResult ()) +-- DeleteConn :: RecipientId -> ConnStore m (SMPResult ()) + +makeSem ''ConnStore + +data ConnStoreData = ConnStoreData + { connections :: Map RecipientId Connection, + senders :: Map SenderId RecipientId + } + +newConnStore :: STM (TVar ConnStoreData) +newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty} + +newConnection :: RecipientKey -> Connection +newConnection rKey = + Connection + { recipientId = "1", + recipientKey = rKey, + senderId = "2", + senderKey = Nothing, + active = True + } + +runConnStoreSTM :: Member (Embed STM) r => Sem (ConnStore ': r) a -> Sem (Input (TVar ConnStoreData) ': r) a +runConnStoreSTM = reinterpret $ \case + CreateConn rKey -> do + store <- input + db <- embed $ readTVar store + let conn@Connection {senderId, recipientId} = newConnection rKey + db' = + ConnStoreData + { connections = M.insert recipientId conn (connections db), + senders = M.insert senderId recipientId (senders db) + } + embed $ writeTVar store db' + return $ Right conn + GetConn Recipient rId -> do + db <- input >>= embed . readTVar + return $ getConn db rId + GetConn Sender sId -> do + db <- input >>= embed . readTVar + return $ maybeError (getConn db) $ M.lookup sId $ senders db + GetConn Broker _ -> do + return $ Left InternalError + where + maybeError = maybe (Left AuthError) + getConn db rId = maybeError Right $ M.lookup rId $ connections db diff --git a/src/Transmission.hs b/src/Transmission.hs new file mode 100644 index 000000000..f3ae76f10 --- /dev/null +++ b/src/Transmission.hs @@ -0,0 +1,61 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} + +module Transmission where + +import Data.Singletons.TH + +$( singletons + [d| + data Party = Broker | Recipient | Sender + |] + ) + +type Transmission (a :: Party) = (Signed a, Maybe Signature) + +type Signed (a :: Party) = (Maybe ConnId, Com a) + +data SomeCom where + SomeCom :: Sing a -> Com a -> SomeCom + +type SomeSigned = (Maybe ConnId, SomeCom) + +data Com (a :: Party) where + CREATE :: RecipientKey -> Com Recipient + SECURE :: SenderKey -> Com Recipient + DELMSG :: MsgId -> Com Recipient + SUB :: Com Recipient + SUSPEND :: Com Recipient + DELETE :: Com Recipient + SEND :: MsgBody -> Com Sender + MSG :: MsgId -> Timestamp -> MsgBody -> Com Broker + CONN :: SenderId -> RecipientId -> Com Broker + ERROR :: ErrorType -> Com Broker + OK :: Com Broker + +type Encoded = String + +type PublicKey = Encoded + +type Signature = Encoded + +type RecipientKey = PublicKey + +type SenderKey = PublicKey + +type RecipientId = ConnId + +type SenderId = ConnId + +type ConnId = Encoded + +type MsgId = Encoded + +type Timestamp = Encoded + +type MsgBody = Encoded + +data ErrorType = CMD | SYNTAX | AUTH | INTERNAL deriving (Show) diff --git a/src/Transport.hs b/src/Transport.hs index 9c0a0ca0d..eabd9ab8d 100644 --- a/src/Transport.hs +++ b/src/Transport.hs @@ -1,46 +1,51 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeOperators #-} module Transport where -data Party = Broker | Recipient | Sender +import Network.Socket +import Polysemy +import Polysemy.Reader +import Polysemy.Resource +import System.IO -type Transmission (a :: Party) = (Signed a, Signature) +data Transport h m a where + TReadLn :: h -> Transport h m String + TWriteLn :: h -> String -> Transport h m () -type Signed (a :: Party) = (ConnId, Com a) +makeSem ''Transport -data Com (a :: Party) where - CREATE :: RecipientKey -> Com Recipient - SECURE :: SenderKey -> Com Recipient - DELMSG :: MsgId -> Com Recipient - SUB :: Com Recipient - SUSPEND :: Com Recipient - DELETE :: Com Recipient - SEND :: MsgBody -> Com Sender - MSG :: MsgId -> Timestamp -> MsgBody -> Com Broker - CONN :: SenderId -> RecipientId -> Com Broker - ERROR :: ErrorType -> Com Broker - OK :: Com Broker +type TCPTransport = Transport Handle -type Encoded = String +startTCPServer :: String -> IO Socket +startTCPServer port = withSocketsDo $ do + let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream} + addr <- head <$> getAddrInfo (Just hints) Nothing (Just port) + sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + setSocketOption sock ReuseAddr 1 + withFdSocket sock setCloseOnExecIfNeeded + bind sock $ addrAddress addr + listen sock 1024 + return sock -type Signature = Encoded +acceptTCPConn :: Socket -> IO Handle +acceptTCPConn sock = do + (conn, peer) <- accept sock + putStrLn $ "Accepted connection from " ++ show peer + h <- socketToHandle conn ReadWriteMode + hSetBinaryMode h True + hSetNewlineMode h universalNewlineMode + hSetBuffering h LineBuffering + return h -type RecipientKey = Encoded - -type SenderKey = Encoded - -type ConnId = Encoded - -type SenderId = Encoded - -type RecipientId = Encoded - -type MsgId = Encoded - -type Timestamp = Encoded - -type MsgBody = Encoded - -data ErrorType = CMD | SYNTAX | AUTH | INTERNAL +runTCPTransportIO :: Member (Embed IO) r => Sem (TCPTransport ': r) a -> Sem r a +runTCPTransportIO = interpret $ \case + TReadLn h -> embed $ hGetLine h + TWriteLn h str -> embed $ hPutStr h $ str ++ "\r\n"