mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-16 11:35:21 +00:00
SECURE command, tests
This commit is contained in:
+5
-7
@@ -7,8 +7,6 @@ module ConnStore where
|
||||
import Data.Singletons
|
||||
import Transmission
|
||||
|
||||
type SMPResult a = Either ErrorType a
|
||||
|
||||
data Connection = Connection
|
||||
{ recipientId :: ConnId,
|
||||
recipientKey :: PublicKey,
|
||||
@@ -18,12 +16,12 @@ data Connection = Connection
|
||||
}
|
||||
|
||||
class MonadConnStore s m where
|
||||
createConn :: s -> RecipientKey -> m (SMPResult Connection)
|
||||
getConn :: s -> Sing (a :: Party) -> ConnId -> m (SMPResult Connection)
|
||||
createConn :: s -> RecipientKey -> m (Either ErrorType Connection)
|
||||
getConn :: s -> Sing (a :: Party) -> ConnId -> m (Either ErrorType Connection)
|
||||
secureConn :: s -> RecipientId -> SenderKey -> m (Either ErrorType ())
|
||||
|
||||
-- secureConn :: RecipientId -> SenderKey -> m (SMPResult ())
|
||||
-- suspendConn :: RecipientId -> m (SMPResult ())
|
||||
-- deleteConn :: RecipientId -> m (SMPResult ())
|
||||
-- suspendConn :: RecipientId -> m (Either ErrorType ())
|
||||
-- deleteConn :: RecipientId -> m (Either ErrorType ())
|
||||
|
||||
newConnection :: RecipientKey -> Connection
|
||||
newConnection rKey =
|
||||
|
||||
+24
-13
@@ -1,7 +1,7 @@
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
@@ -26,27 +26,38 @@ newConnStore :: STM STMConnStore
|
||||
newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty}
|
||||
|
||||
instance MonadUnliftIO m => MonadConnStore STMConnStore m where
|
||||
createConn store rKey = atomically $ do
|
||||
createConn store rKey = atomically do
|
||||
db <- readTVar store
|
||||
let conn@Connection {senderId, recipientId} = newConnection rKey
|
||||
let c@Connection {recipientId = rId, senderId = sId} = newConnection rKey
|
||||
db' =
|
||||
ConnStoreData
|
||||
{ connections = M.insert recipientId conn (connections db),
|
||||
senders = M.insert senderId recipientId (senders db)
|
||||
{ connections = M.insert rId c (connections db),
|
||||
senders = M.insert sId rId (senders db)
|
||||
}
|
||||
writeTVar store db'
|
||||
return $ Right conn
|
||||
getConn store SRecipient rId = atomically $ do
|
||||
return $ Right c
|
||||
|
||||
-- TODO do not return suspended connections
|
||||
getConn store SRecipient rId = atomically do
|
||||
db <- readTVar store
|
||||
return $ getRcpConn db rId
|
||||
getConn store SSender sId = atomically $ do
|
||||
getConn store SSender sId = atomically do
|
||||
db <- readTVar store
|
||||
return $ maybeAuth (getRcpConn db) $ M.lookup sId $ senders db
|
||||
getConn _ SBroker _ = atomically $ do
|
||||
return $ maybe (Left AUTH) (getRcpConn db) $ M.lookup sId $ senders db
|
||||
getConn _ SBroker _ = atomically do
|
||||
return $ Left INTERNAL
|
||||
|
||||
maybeAuth :: (a -> Either ErrorType b) -> Maybe a -> Either ErrorType b
|
||||
maybeAuth = maybe (Left AUTH)
|
||||
secureConn store rId sKey = atomically do
|
||||
db <- readTVar store
|
||||
let conn = getRcpConn db rId
|
||||
either (return . Left) (updateConn db) conn
|
||||
where
|
||||
updateConn db c = case senderKey c of
|
||||
Just _ -> return $ Left AUTH
|
||||
Nothing -> do
|
||||
let db' = db {connections = M.insert rId c {senderKey = Just sKey} (connections db)}
|
||||
writeTVar store db'
|
||||
return $ Right ()
|
||||
|
||||
getRcpConn :: ConnStoreData -> RecipientId -> Either ErrorType Connection
|
||||
getRcpConn db rId = maybeAuth Right $ M.lookup rId $ connections db
|
||||
getRcpConn db rId = maybe (Left AUTH) Right . M.lookup rId $ connections db
|
||||
|
||||
+31
-14
@@ -26,12 +26,12 @@ import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM
|
||||
|
||||
runSMPServer :: ServiceName -> Natural -> IO ()
|
||||
runSMPServer :: MonadUnliftIO m => ServiceName -> Natural -> m ()
|
||||
runSMPServer port queueSize = do
|
||||
env <- atomically $ newEnv port queueSize
|
||||
runReaderT (runTCPServer runClient) env
|
||||
|
||||
runTCPServer :: (MonadReader Env m, MonadUnliftIO m) => (Handle -> m ()) -> m ()
|
||||
runTCPServer :: (MonadUnliftIO m, MonadReader Env m) => (Handle -> m ()) -> m ()
|
||||
runTCPServer server =
|
||||
E.bracket startTCPServer (liftIO . close) $ \sock -> forever $ do
|
||||
h <- acceptTCPConn sock
|
||||
@@ -58,7 +58,7 @@ receive h Client {queue} = forever $ do
|
||||
verifyTransmission :: forall m. (MonadUnliftIO m, MonadReader Env m) => Signature -> ConnId -> Cmd -> m Signed
|
||||
verifyTransmission signature connId cmd = do
|
||||
(connId,) <$> case cmd of
|
||||
Cmd SBroker _ -> return $ smpErr INTERNAL
|
||||
Cmd SBroker _ -> return $ smpErr INTERNAL -- it can only be client command, because `fromClient` was used
|
||||
Cmd SRecipient (CREATE _) -> return cmd
|
||||
Cmd SRecipient _ -> withConnection SRecipient $ verifySignature . recipientKey
|
||||
Cmd SSender (SEND _) -> withConnection SSender $ verifySend . senderKey
|
||||
@@ -80,18 +80,35 @@ verifyTransmission signature connId cmd = do
|
||||
smpErr e = Cmd SBroker $ ERROR e
|
||||
authErr = smpErr AUTH
|
||||
|
||||
client :: (MonadUnliftIO m, MonadReader Env m) => Handle -> Client -> m ()
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Handle -> Client -> m ()
|
||||
client h Client {queue} = loop
|
||||
where
|
||||
loop = forever $ do
|
||||
(connId, cmd) <- atomically $ readTBQueue queue
|
||||
response <- case cmd of
|
||||
Cmd SRecipient (CREATE recipientKey) -> do
|
||||
store <- asks connStore
|
||||
conn <- createConn store recipientKey
|
||||
return . Cmd SBroker $ case conn of
|
||||
Right Connection {recipientId, senderId} -> CONN recipientId senderId
|
||||
Left e -> ERROR e
|
||||
Cmd SBroker _ -> return cmd
|
||||
Cmd _ _ -> return $ Cmd SBroker OK
|
||||
tPut h ("", (connId, response)) -- empty signature
|
||||
signed <- processCommand connId cmd
|
||||
tPut h ("", signed)
|
||||
|
||||
processCommand :: ConnId -> Cmd -> m Signed
|
||||
processCommand connId cmd = do
|
||||
st <- asks connStore
|
||||
case cmd of
|
||||
Cmd SRecipient (CREATE recipientKey) ->
|
||||
either (mkSigned "" . ERROR) connResponce
|
||||
<$> createConn st recipientKey
|
||||
Cmd SRecipient SUB -> do
|
||||
-- TODO message subscription
|
||||
return ok
|
||||
Cmd SRecipient (SECURE senderKey) -> do
|
||||
mkSigned connId . either ERROR (const OK)
|
||||
<$> secureConn st connId senderKey
|
||||
Cmd SBroker _ -> return (connId, cmd)
|
||||
Cmd _ _ -> return ok
|
||||
where
|
||||
ok :: Signed
|
||||
ok = (connId, Cmd SBroker OK)
|
||||
|
||||
mkSigned :: ConnId -> Command 'Broker -> Signed
|
||||
mkSigned cId command = (cId, Cmd SBroker command)
|
||||
|
||||
connResponce :: Connection -> Signed
|
||||
connResponce Connection {recipientId = rId, senderId = sId} = mkSigned rId $ CONN rId sId
|
||||
|
||||
+27
-15
@@ -2,24 +2,22 @@
|
||||
|
||||
module SMPClient where
|
||||
|
||||
import Control.Concurrent
|
||||
import qualified Control.Exception as E
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Server
|
||||
import System.IO
|
||||
import Transmission
|
||||
import Transport
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.IO
|
||||
|
||||
runSMPClient :: HostName -> ServiceName -> (Handle -> IO a) -> IO a
|
||||
runSMPClient host port client = withSocketsDo $ do
|
||||
startTCPClient :: MonadIO m => HostName -> ServiceName -> m Handle
|
||||
startTCPClient host port = liftIO . withSocketsDo $ do
|
||||
threadDelay 1 -- TODO hack: thread delay for SMP server to start
|
||||
addr <- resolve
|
||||
E.bracket (open addr) hClose $ \h -> do
|
||||
line <- getLn h
|
||||
if line == "Welcome to SMP"
|
||||
then client h
|
||||
else error "not connected"
|
||||
open addr
|
||||
where
|
||||
resolve = do
|
||||
let hints = defaultHints {addrSocketType = Stream}
|
||||
@@ -29,6 +27,18 @@ runSMPClient host port client = withSocketsDo $ do
|
||||
connect sock $ addrAddress addr
|
||||
getSocketHandle sock
|
||||
|
||||
runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a
|
||||
runTCPClient host port client = do
|
||||
E.bracket (startTCPClient host port) (liftIO . hClose) client
|
||||
|
||||
runSMPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a
|
||||
runSMPClient host port client =
|
||||
runTCPClient host port $ \h -> do
|
||||
line <- getLn h
|
||||
if line == "Welcome to SMP"
|
||||
then client h
|
||||
else error "not connected"
|
||||
|
||||
testPort :: ServiceName
|
||||
testPort = "5000"
|
||||
|
||||
@@ -40,13 +50,15 @@ queueSize = 2
|
||||
|
||||
type TestTransmission = (Signature, ConnId, String)
|
||||
|
||||
smpServerTest :: [TestTransmission] -> IO [TestTransmission]
|
||||
smpServerTest commands =
|
||||
runSmpTest :: MonadUnliftIO m => (Handle -> m a) -> m a
|
||||
runSmpTest test =
|
||||
E.bracket
|
||||
(forkIO $ runSMPServer testPort queueSize)
|
||||
killThread
|
||||
\_ -> runSMPClient "localhost" testPort $
|
||||
\h -> mapM (sendReceive h) commands
|
||||
(liftIO . killThread)
|
||||
\_ -> runSMPClient "localhost" testPort test
|
||||
|
||||
smpServerTest :: [TestTransmission] -> IO [TestTransmission]
|
||||
smpServerTest commands = runSmpTest \h -> mapM (sendReceive h) commands
|
||||
where
|
||||
sendReceive :: Handle -> TestTransmission -> IO TestTransmission
|
||||
sendReceive h t = tPutRaw h t >> tGetRaw h
|
||||
|
||||
+78
-36
@@ -1,47 +1,89 @@
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
import SMPClient
|
||||
import System.IO (Handle)
|
||||
import Test.Hspec
|
||||
import Transmission
|
||||
import Transport
|
||||
|
||||
(>#>) :: [TestTransmission] -> [TestTransmission] -> Expectation
|
||||
commands >#> responses = smpServerTest commands `shouldReturn` responses
|
||||
commands >#> responses = smpServerTest2 commands `shouldReturn` responses
|
||||
|
||||
main :: IO ()
|
||||
main = hspec do
|
||||
describe "SMP syntax" do
|
||||
it "unknown command" $ [("", "123", "HELLO")] >#> [("", "123", "ERROR UNKNOWN")]
|
||||
describe "CREATE" do
|
||||
it "no parameters" $ [("", "", "CREATE")] >#> [("", "", "ERROR SYNTAX 2")]
|
||||
it "many parameters" $ [("", "", "CREATE 1 2")] >#> [("", "", "ERROR SYNTAX 2")]
|
||||
it "has signature" $ [("123", "", "CREATE 123")] >#> [("", "", "ERROR SYNTAX 4")]
|
||||
it "connection ID" $ [("", "1", "CREATE 123")] >#> [("", "1", "ERROR SYNTAX 4")]
|
||||
noParamsSyntaxTest "SUB"
|
||||
oneParamSyntaxTest "SECURE"
|
||||
oneParamSyntaxTest "DELMSG"
|
||||
noParamsSyntaxTest "SUSPEND"
|
||||
noParamsSyntaxTest "DELETE"
|
||||
describe "SEND" do
|
||||
it "valid syntax 1" $ [("123", "1", "SEND :hello")] >#> [("", "1", "ERROR AUTH")]
|
||||
it "valid syntax 2" $ [("123", "1", "SEND 11\nhello there\n")] >#> [("", "1", "ERROR AUTH")]
|
||||
it "no parameters" $ [("123", "1", "SEND")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "many parameters" $ [("123", "1", "SEND 11 hello")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "no connection ID" $ [("123", "", "SEND :hello")] >#> [("", "", "ERROR SYNTAX 5")]
|
||||
it "bad message body" $ [("123", "1", "SEND hello")] >#> [("", "1", "ERROR SYNTAX 6")]
|
||||
it "bigger body" $ [("123", "1", "SEND 4\nhello\n")] >#> [("", "1", "ERROR SYNTAX 7")]
|
||||
describe "broker response not allowed" do
|
||||
it "OK" $ [("123", "1", "OK")] >#> [("", "1", "ERROR PROHIBITED")]
|
||||
describe "SMP syntax" syntaxTests
|
||||
fdescribe "SMP connections" connectionTests
|
||||
|
||||
noParamsSyntaxTest :: String -> SpecWith ()
|
||||
noParamsSyntaxTest cmd = describe cmd do
|
||||
it "valid syntax" $ [("123", "1", cmd)] >#> [("", "1", "ERROR AUTH")]
|
||||
it "parameters" $ [("123", "1", cmd ++ " 1")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "no signature" $ [("", "1", cmd)] >#> [("", "1", "ERROR SYNTAX 3")]
|
||||
it "no connection ID" $ [("123", "", cmd)] >#> [("", "", "ERROR SYNTAX 3")]
|
||||
pattern Resp :: ConnId -> Command 'Broker -> TransmissionOrError
|
||||
pattern Resp connId command = ("", (connId, Right (Cmd SBroker command)))
|
||||
|
||||
oneParamSyntaxTest :: String -> SpecWith ()
|
||||
oneParamSyntaxTest cmd = describe cmd do
|
||||
it "valid syntax" $ [("123", "1", cmd ++ " 456")] >#> [("", "1", "ERROR AUTH")]
|
||||
it "no parameters" $ [("123", "1", cmd)] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "many parameters" $ [("123", "1", cmd ++ " 1 2")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "no signature" $ [("", "1", cmd ++ " 456")] >#> [("", "1", "ERROR SYNTAX 3")]
|
||||
it "no connection ID" $ [("123", "", cmd ++ " 456")] >#> [("", "", "ERROR SYNTAX 3")]
|
||||
smpExpect :: (Show a, Eq a) => a -> (Handle -> IO a) -> Expectation
|
||||
smpExpect result test = runSmpTest test `shouldReturn` result
|
||||
|
||||
sendRecv :: Handle -> RawTransmission -> IO TransmissionOrError
|
||||
sendRecv h t = tPutRaw h t >> tGet fromServer h
|
||||
|
||||
connectionTests :: SpecWith ()
|
||||
connectionTests = do
|
||||
it "CREATE and SECURE connection, SEND messages (no delivery yet)" $
|
||||
smpExpect True \h -> do
|
||||
Resp rId (CONN rId' sId) <- sendRecv h ("", "", "CREATE 123")
|
||||
-- should allow unsigned
|
||||
Resp sId' OK <- sendRecv h ("", sId, "SEND :hello")
|
||||
-- should not allow signed
|
||||
Resp sId'' (ERROR AUTH) <- sendRecv h ("456", sId, "SEND :hello")
|
||||
-- shoud not secure with wrong signature (password atm)
|
||||
Resp _ (ERROR AUTH) <- sendRecv h ("1234", rId, "SECURE 456")
|
||||
-- shoud not secure with sender's ID
|
||||
Resp _ (ERROR AUTH) <- sendRecv h ("123", sId, "SECURE 456")
|
||||
-- secure connection
|
||||
Resp rId'' OK <- sendRecv h ("123", rId, "SECURE 456")
|
||||
-- should not secure if already secured
|
||||
Resp _ (ERROR AUTH) <- sendRecv h ("123", rId, "SECURE 456")
|
||||
-- should allow signed
|
||||
Resp _ OK <- sendRecv h ("456", sId, "SEND :hello")
|
||||
-- should not allow unsigned
|
||||
Resp _ (ERROR AUTH) <- sendRecv h ("", sId, "SEND :hello")
|
||||
return $ rId == rId' && rId == rId'' && sId == sId' && sId == sId''
|
||||
|
||||
syntaxTests :: SpecWith ()
|
||||
syntaxTests = do
|
||||
it "unknown command" $ [("", "123", "HELLO")] >#> [("", "123", "ERROR UNKNOWN")]
|
||||
describe "CREATE" do
|
||||
it "no parameters" $ [("", "", "CREATE")] >#> [("", "", "ERROR SYNTAX 2")]
|
||||
it "many parameters" $ [("", "", "CREATE 1 2")] >#> [("", "", "ERROR SYNTAX 2")]
|
||||
it "has signature" $ [("123", "", "CREATE 123")] >#> [("", "", "ERROR SYNTAX 4")]
|
||||
it "connection ID" $ [("", "1", "CREATE 123")] >#> [("", "1", "ERROR SYNTAX 4")]
|
||||
noParamsSyntaxTest "SUB"
|
||||
oneParamSyntaxTest "SECURE"
|
||||
oneParamSyntaxTest "DELMSG"
|
||||
noParamsSyntaxTest "SUSPEND"
|
||||
noParamsSyntaxTest "DELETE"
|
||||
describe "SEND" do
|
||||
it "valid syntax 1" $ [("123", "1", "SEND :hello")] >#> [("", "1", "ERROR AUTH")]
|
||||
it "valid syntax 2" $ [("123", "1", "SEND 11\nhello there\n")] >#> [("", "1", "ERROR AUTH")]
|
||||
it "no parameters" $ [("123", "1", "SEND")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "many parameters" $ [("123", "1", "SEND 11 hello")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "no connection ID" $ [("123", "", "SEND :hello")] >#> [("", "", "ERROR SYNTAX 5")]
|
||||
it "bad message body" $ [("123", "1", "SEND hello")] >#> [("", "1", "ERROR SYNTAX 6")]
|
||||
it "bigger body" $ [("123", "1", "SEND 4\nhello\n")] >#> [("", "1", "ERROR SYNTAX 7")]
|
||||
describe "broker response not allowed" do
|
||||
it "OK" $ [("123", "1", "OK")] >#> [("", "1", "ERROR PROHIBITED")]
|
||||
where
|
||||
noParamsSyntaxTest :: String -> SpecWith ()
|
||||
noParamsSyntaxTest cmd = describe cmd do
|
||||
it "valid syntax" $ [("123", "1", cmd)] >#> [("", "1", "ERROR AUTH")]
|
||||
it "parameters" $ [("123", "1", cmd ++ " 1")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "no signature" $ [("", "1", cmd)] >#> [("", "1", "ERROR SYNTAX 3")]
|
||||
it "no connection ID" $ [("123", "", cmd)] >#> [("", "", "ERROR SYNTAX 3")]
|
||||
|
||||
oneParamSyntaxTest :: String -> SpecWith ()
|
||||
oneParamSyntaxTest cmd = describe cmd do
|
||||
it "valid syntax" $ [("123", "1", cmd ++ " 456")] >#> [("", "1", "ERROR AUTH")]
|
||||
it "no parameters" $ [("123", "1", cmd)] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "many parameters" $ [("123", "1", cmd ++ " 1 2")] >#> [("", "1", "ERROR SYNTAX 2")]
|
||||
it "no signature" $ [("", "1", cmd ++ " 456")] >#> [("", "1", "ERROR SYNTAX 3")]
|
||||
it "no connection ID" $ [("123", "", cmd ++ " 456")] >#> [("", "", "ERROR SYNTAX 3")]
|
||||
|
||||
Reference in New Issue
Block a user