mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-11 08:16:57 +00:00
more syntax validation, read full SEND msgBody
This commit is contained in:
@@ -10,7 +10,7 @@
|
||||
-- {-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
|
||||
module Store where
|
||||
module ConnStore where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Data.Map (Map)
|
||||
@@ -74,12 +74,12 @@ runConnStoreSTM = reinterpret $ \case
|
||||
return $ Right conn
|
||||
GetConn Recipient rId -> do
|
||||
db <- input >>= embed . readTVar
|
||||
return $ getConn db rId
|
||||
return $ getRcpConn db rId
|
||||
GetConn Sender sId -> do
|
||||
db <- input >>= embed . readTVar
|
||||
return $ maybeError (getConn db) $ M.lookup sId $ senders db
|
||||
return $ maybeError (getRcpConn db) $ M.lookup sId $ senders db
|
||||
GetConn Broker _ -> do
|
||||
return $ Left InternalError
|
||||
where
|
||||
maybeError = maybe (Left AuthError)
|
||||
getConn db rId = maybeError Right $ M.lookup rId $ connections db
|
||||
getRcpConn db rId = maybeError Right $ M.lookup rId $ connections db
|
||||
+1
-1
@@ -3,11 +3,11 @@
|
||||
|
||||
module EnvSTM where
|
||||
|
||||
import ConnStore
|
||||
import Control.Concurrent.STM
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Set as S
|
||||
import Network.Socket (ServiceName)
|
||||
import Store
|
||||
import System.IO
|
||||
import Transmission
|
||||
|
||||
|
||||
+37
-31
@@ -8,13 +8,15 @@
|
||||
|
||||
module Main where
|
||||
|
||||
-- import Polysemy
|
||||
import ConnStore
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import EnvSTM
|
||||
import Network.Socket
|
||||
-- import Polysemy
|
||||
import Store
|
||||
import Text.Read
|
||||
import Transmission
|
||||
import Transport
|
||||
import UnliftIO.Async
|
||||
@@ -44,38 +46,41 @@ runClient h = do
|
||||
c <- atomically $ newClient h
|
||||
void $ race (client c) (receive c)
|
||||
|
||||
receive :: MonadIO m => Client -> m ()
|
||||
receive :: MonadUnliftIO m => Client -> m ()
|
||||
receive Client {handle, channel} = forever $ do
|
||||
signature <- getLn handle
|
||||
connId <- getLn handle
|
||||
command <- getLn handle
|
||||
cmdOrError <- parseVerifyTransmission signature connId command
|
||||
cmdOrError <- parseReadVerifyTransmission handle signature connId command
|
||||
atomically $ writeTChan channel cmdOrError
|
||||
|
||||
parseVerifyTransmission :: Monad m => String -> String -> String -> m SomeSigned
|
||||
parseVerifyTransmission _ connId command = do
|
||||
return (Just connId, parseCommand command)
|
||||
|
||||
parseCommand :: String -> SomeCom
|
||||
parseCommand command = case words command of
|
||||
["CREATE", recipientKey] -> rCmd $ CREATE recipientKey
|
||||
["SUB"] -> rCmd SUB
|
||||
["SECURE", senderKey] -> rCmd $ SECURE senderKey
|
||||
["DELMSG", msgId] -> rCmd $ DELMSG msgId
|
||||
["SUSPEND"] -> rCmd SUSPEND
|
||||
["DELETE"] -> rCmd DELETE
|
||||
["SEND", msgBody] -> SomeCom SSender $ SEND msgBody
|
||||
"CREATE" : _ -> err SYNTAX
|
||||
"SUB" : _ -> err SYNTAX
|
||||
"SECURE" : _ -> err SYNTAX
|
||||
"DELMSG" : _ -> err SYNTAX
|
||||
"SUSPEND" : _ -> err SYNTAX
|
||||
"DELETE" : _ -> err SYNTAX
|
||||
"SEND" : _ -> err SYNTAX
|
||||
_ -> err CMD
|
||||
where
|
||||
rCmd = SomeCom SRecipient
|
||||
err t = SomeCom SBroker $ ERROR t
|
||||
parseReadVerifyTransmission :: MonadUnliftIO m => Handle -> String -> String -> String -> m SomeSigned
|
||||
parseReadVerifyTransmission h signature connId command = do
|
||||
let cmd = parseCommand command
|
||||
cmd' <- case cmd of
|
||||
Cmd SBroker _ -> return cmd
|
||||
Cmd _ (CREATE _) ->
|
||||
return
|
||||
if signature == "" && connId == ""
|
||||
then cmd
|
||||
else smpError SYNTAX
|
||||
Cmd _ (SEND msgBody) ->
|
||||
if connId == ""
|
||||
then return $ smpError SYNTAX
|
||||
else case B.unpack msgBody of
|
||||
':' : body -> return . smpSend $ B.pack body
|
||||
sizeStr -> case readMaybe sizeStr :: Maybe Int of
|
||||
Just size -> do
|
||||
body <- getBytes h size
|
||||
s <- getLn h
|
||||
return if s == "" then smpSend body else smpError SYNTAX
|
||||
Nothing -> return $ smpError SYNTAX
|
||||
Cmd _ _ ->
|
||||
return
|
||||
if signature == "" || connId == ""
|
||||
then smpError SYNTAX
|
||||
else cmd
|
||||
return (Just connId, cmd')
|
||||
|
||||
client :: MonadIO m => Client -> m ()
|
||||
client Client {handle, channel} = loop
|
||||
@@ -83,8 +88,9 @@ client Client {handle, channel} = loop
|
||||
loop = forever $ do
|
||||
(_, cmdOrErr) <- atomically $ readTChan channel
|
||||
let response = case cmdOrErr of
|
||||
SomeCom SRecipient _ -> "OK"
|
||||
SomeCom SSender _ -> "OK"
|
||||
SomeCom SBroker (ERROR t) -> "ERROR " ++ show t
|
||||
Cmd SRecipient _ -> "OK"
|
||||
Cmd SSender _ -> "OK"
|
||||
Cmd SBroker (ERROR t) -> "ERROR " ++ show t
|
||||
_ -> "ERROR INTERNAL"
|
||||
putLn handle response
|
||||
liftIO $ print cmdOrErr
|
||||
|
||||
+54
-17
@@ -1,40 +1,77 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module Transmission where
|
||||
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Singletons.TH
|
||||
|
||||
$( singletons
|
||||
[d|
|
||||
data Party = Broker | Recipient | Sender
|
||||
deriving (Show)
|
||||
|]
|
||||
)
|
||||
|
||||
type Transmission (a :: Party) = (Signed a, Maybe Signature)
|
||||
|
||||
type Signed (a :: Party) = (Maybe ConnId, Com a)
|
||||
type Signed (a :: Party) = (Maybe ConnId, Command a)
|
||||
|
||||
data SomeCom where
|
||||
SomeCom :: Sing a -> Com a -> SomeCom
|
||||
data Cmd where
|
||||
Cmd :: Sing a -> Command a -> Cmd
|
||||
|
||||
type SomeSigned = (Maybe ConnId, SomeCom)
|
||||
deriving instance Show Cmd
|
||||
|
||||
data Com (a :: Party) where
|
||||
CREATE :: RecipientKey -> Com Recipient
|
||||
SECURE :: SenderKey -> Com Recipient
|
||||
DELMSG :: MsgId -> Com Recipient
|
||||
SUB :: Com Recipient
|
||||
SUSPEND :: Com Recipient
|
||||
DELETE :: Com Recipient
|
||||
SEND :: MsgBody -> Com Sender
|
||||
MSG :: MsgId -> Timestamp -> MsgBody -> Com Broker
|
||||
CONN :: SenderId -> RecipientId -> Com Broker
|
||||
ERROR :: ErrorType -> Com Broker
|
||||
OK :: Com Broker
|
||||
type SomeSigned = (Maybe ConnId, Cmd)
|
||||
|
||||
data Command (a :: Party) where
|
||||
CREATE :: RecipientKey -> Command Recipient
|
||||
SECURE :: SenderKey -> Command Recipient
|
||||
DELMSG :: MsgId -> Command Recipient
|
||||
SUB :: Command Recipient
|
||||
SUSPEND :: Command Recipient
|
||||
DELETE :: Command Recipient
|
||||
SEND :: MsgBody -> Command Sender
|
||||
MSG :: MsgId -> Timestamp -> MsgBody -> Command Broker
|
||||
CONN :: SenderId -> RecipientId -> Command Broker
|
||||
ERROR :: ErrorType -> Command Broker
|
||||
OK :: Command Broker
|
||||
|
||||
deriving instance Show (Command a)
|
||||
|
||||
parseCommand :: String -> Cmd
|
||||
parseCommand command = case words command of
|
||||
["CREATE", recipientKey] -> rCmd $ CREATE recipientKey
|
||||
["SUB"] -> rCmd SUB
|
||||
["SECURE", senderKey] -> rCmd $ SECURE senderKey
|
||||
["DELMSG", msgId] -> rCmd $ DELMSG msgId
|
||||
["SUSPEND"] -> rCmd SUSPEND
|
||||
["DELETE"] -> rCmd DELETE
|
||||
["SEND", msgBody] -> smpSend $ B.pack msgBody
|
||||
"CREATE" : _ -> smpError SYNTAX
|
||||
"SUB" : _ -> smpError SYNTAX
|
||||
"SECURE" : _ -> smpError SYNTAX
|
||||
"DELMSG" : _ -> smpError SYNTAX
|
||||
"SUSPEND" : _ -> smpError SYNTAX
|
||||
"DELETE" : _ -> smpError SYNTAX
|
||||
"SEND" : _ -> smpError SYNTAX
|
||||
_ -> smpError CMD
|
||||
where
|
||||
rCmd = Cmd SRecipient
|
||||
|
||||
smpError :: ErrorType -> Cmd
|
||||
smpError = Cmd SBroker . ERROR
|
||||
|
||||
smpSend :: MsgBody -> Cmd
|
||||
smpSend = Cmd SSender . SEND
|
||||
|
||||
type Encoded = String
|
||||
|
||||
@@ -56,6 +93,6 @@ type MsgId = Encoded
|
||||
|
||||
type Timestamp = Encoded
|
||||
|
||||
type MsgBody = Encoded
|
||||
type MsgBody = B.ByteString
|
||||
|
||||
data ErrorType = CMD | SYNTAX | AUTH | INTERNAL deriving (Show)
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
module Transport where
|
||||
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import EnvSTM
|
||||
import Network.Socket
|
||||
import System.IO
|
||||
@@ -36,3 +38,6 @@ putLn h = liftIO . hPutStrLn h
|
||||
|
||||
getLn :: MonadIO m => Handle -> m String
|
||||
getLn = liftIO . hGetLine
|
||||
|
||||
getBytes :: MonadUnliftIO m => Handle -> Int -> m B.ByteString
|
||||
getBytes h = liftIO . B.hGet h
|
||||
|
||||
Reference in New Issue
Block a user