From ee4092750669f525f90da89752a1eb577b48fa20 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sat, 17 Oct 2020 11:03:38 +0100 Subject: [PATCH] random connection and message IDs --- package.yaml | 2 ++ src/ConnStore.hs | 15 ++++----- src/ConnStore/STM.hs | 74 ++++++++++++++++++++++++++------------------ src/Env/STM.hs | 20 +++++++----- src/MsgStore.hs | 12 ------- src/Server.hs | 31 +++++++++++++++---- tests/SMPClient.hs | 5 +-- tests/Test.hs | 4 +-- 8 files changed, 96 insertions(+), 67 deletions(-) diff --git a/package.yaml b/package.yaml index d6d75deed..88fa27f50 100644 --- a/package.yaml +++ b/package.yaml @@ -13,8 +13,10 @@ extra-source-files: dependencies: - base >= 4.7 && < 5 + - base64-bytestring >= 1.0 && < 1.3 - bytestring - containers + - cryptonite == 0.26.* - iso8601-time == 0.1.* - mtl - network diff --git a/src/ConnStore.hs b/src/ConnStore.hs index d1b302b8a..f8cfa3fd7 100644 --- a/src/ConnStore.hs +++ b/src/ConnStore.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} module ConnStore where @@ -9,8 +10,8 @@ import Transmission data Connection = Connection { recipientId :: ConnId, - recipientKey :: PublicKey, senderId :: ConnId, + recipientKey :: PublicKey, senderKey :: Maybe PublicKey, status :: ConnStatus } @@ -18,19 +19,19 @@ data Connection = Connection data ConnStatus = ConnActive | ConnOff class MonadConnStore s m where - addConn :: s -> RecipientKey -> m (Either ErrorType Connection) + addConn :: s -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection) getConn :: s -> Sing (a :: Party) -> ConnId -> m (Either ErrorType Connection) secureConn :: s -> RecipientId -> SenderKey -> m (Either ErrorType ()) suspendConn :: s -> RecipientId -> m (Either ErrorType ()) deleteConn :: s -> RecipientId -> m (Either ErrorType ()) -- TODO stub -mkConnection :: RecipientKey -> Connection -mkConnection rKey = +mkConnection :: (RecipientId, SenderId) -> RecipientKey -> Connection +mkConnection (recipientId, senderId) recipientKey = Connection - { recipientId = "1", - recipientKey = rKey, - senderId = "2", + { recipientId, + senderId, + recipientKey, senderKey = Nothing, status = ConnActive } diff --git a/src/ConnStore/STM.hs b/src/ConnStore/STM.hs index e313267e6..5328c5ca7 100644 --- a/src/ConnStore/STM.hs +++ b/src/ConnStore/STM.hs @@ -3,7 +3,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE UndecidableInstances #-} @@ -29,45 +31,57 @@ newConnStore :: STM STMConnStore newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty} instance MonadUnliftIO m => MonadConnStore STMConnStore m where - addConn :: STMConnStore -> RecipientKey -> m (Either ErrorType Connection) - addConn store rKey = atomically $ do - let c@Connection {recipientId = rId, senderId = sId} = mkConnection rKey - modifyTVar store $ \db -> - db - { connections = M.insert rId c (connections db), - senders = M.insert sId rId (senders db) - } - return $ Right c + addConn :: STMConnStore -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection) + addConn = _addConn (3 :: Int) + where + _addConn 0 _ _ _ = return $ Left INTERNAL + _addConn retry store getIds rKey = do + getIds >>= atomically . insertConn >>= \case + Nothing -> _addConn (retry - 1) store getIds rKey + Just c -> return $ Right c + where + insertConn ids@(rId, sId) = do + cs@ConnStoreData {connections, senders} <- readTVar store + if M.member rId connections || M.member sId senders + then return Nothing + else do + let c = mkConnection ids rKey + writeTVar store $ + cs + { connections = M.insert rId c connections, + senders = M.insert sId rId senders + } + return $ Just c getConn :: STMConnStore -> Sing (p :: Party) -> ConnId -> m (Either ErrorType Connection) getConn store SRecipient rId = atomically $ do - db <- readTVar store - return $ getRcpConn db rId + cs <- readTVar store + return $ getRcpConn cs rId getConn store SSender sId = atomically $ do - db <- readTVar store - let rId = M.lookup sId $ senders db - return $ maybe (Left AUTH) (getRcpConn db) rId + cs <- readTVar store + let rId = M.lookup sId $ senders cs + return $ maybe (Left AUTH) (getRcpConn cs) rId getConn _ SBroker _ = return $ Left INTERNAL secureConn store rId sKey = - updateConnections store rId $ \db c -> + updateConnections store rId $ \cs c -> case senderKey c of - Just _ -> (Left AUTH, db) - _ -> (Right (), db {connections = M.insert rId c {senderKey = Just sKey} (connections db)}) + Just _ -> (Left AUTH, cs) + _ -> (Right (), cs {connections = M.insert rId c {senderKey = Just sKey} (connections cs)}) suspendConn :: STMConnStore -> RecipientId -> m (Either ErrorType ()) suspendConn store rId = - updateConnections store rId $ \db c -> - (Right (), db {connections = M.insert rId c {status = ConnOff} (connections db)}) + updateConnections store rId $ \cs c -> + (Right (), cs {connections = M.insert rId c {status = ConnOff} (connections cs)}) deleteConn :: STMConnStore -> RecipientId -> m (Either ErrorType ()) deleteConn store rId = - updateConnections store rId $ \db c -> + updateConnections store rId $ \cs c -> ( Right (), - db - { connections = M.delete rId (connections db), - senders = M.delete (senderId c) (senders db) + cs + { connections = M.delete rId (connections cs), + senders = M.delete (senderId c) (senders cs) } ) @@ -78,14 +92,14 @@ updateConnections :: (ConnStoreData -> Connection -> (Either ErrorType (), ConnStoreData)) -> m (Either ErrorType ()) updateConnections store rId update = atomically $ do - db <- readTVar store - let conn = getRcpConn db rId - either (return . Left) (_update db) conn + cs <- readTVar store + let conn = getRcpConn cs rId + either (return . Left) (_update cs) conn where - _update db c = do - let (res, db') = update db c - writeTVar store db' + _update cs c = do + let (res, cs') = update cs c + writeTVar store cs' return res getRcpConn :: ConnStoreData -> RecipientId -> Either ErrorType Connection -getRcpConn db rId = maybe (Left AUTH) Right . M.lookup rId $ connections db +getRcpConn cs rId = maybe (Left AUTH) Right . M.lookup rId $ connections cs diff --git a/src/Env/STM.hs b/src/Env/STM.hs index a59685790..3d3ed522f 100644 --- a/src/Env/STM.hs +++ b/src/Env/STM.hs @@ -4,21 +4,24 @@ module Env.STM where import ConnStore.STM -import Control.Concurrent -import Control.Concurrent.STM +import Control.Concurrent (ThreadId) +import Control.Monad.IO.Unlift +import Crypto.Random import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import MsgStore.STM import Network.Socket (ServiceName) import Numeric.Natural import Transmission +import UnliftIO.STM data Env = Env { tcpPort :: ServiceName, queueSize :: Natural, server :: Server, connStore :: STMConnStore, - msgStore :: STMMsgStore + msgStore :: STMMsgStore, + idsDrg :: TVar ChaChaDRG } data Server = Server @@ -45,9 +48,10 @@ newClient qSize = do sndQ <- newTBQueue qSize return Client {connections, rcvQ, sndQ} -newEnv :: String -> Natural -> STM Env +newEnv :: (MonadUnliftIO m, MonadRandom m) => String -> Natural -> m Env newEnv tcpPort queueSize = do - server <- newServer queueSize - connStore <- newConnStore - msgStore <- newMsgStore - return Env {tcpPort, queueSize, server, connStore, msgStore} + server <- atomically $ newServer queueSize + connStore <- atomically newConnStore + msgStore <- atomically newMsgStore + idsDrg <- drgNew >>= newTVarIO + return Env {tcpPort, queueSize, server, connStore, msgStore, idsDrg} diff --git a/src/MsgStore.hs b/src/MsgStore.hs index 3c2186447..18d2860f5 100644 --- a/src/MsgStore.hs +++ b/src/MsgStore.hs @@ -1,9 +1,7 @@ {-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE NamedFieldPuns #-} module MsgStore where -import Control.Monad.IO.Class import Data.Time.Clock import Transmission @@ -22,13 +20,3 @@ class MonadMsgQueue q m where tryPeekMsg :: q -> m (Maybe Message) -- non blocking peekMsg :: q -> m Message -- blocking tryDelPeekMsg :: q -> m (Maybe Message) -- atomic delete (== read) last and peek next message, if available - -newMessage :: MonadIO m => MsgBody -> m Message -newMessage msgBody = do - ts <- liftIO getCurrentTime - return - Message - { msgId = "1", - ts, - msgBody - } diff --git a/src/Server.hs b/src/Server.hs index b86e0d339..404635363 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -15,8 +15,13 @@ import ConnStore import Control.Monad import Control.Monad.IO.Unlift import Control.Monad.Reader +import Crypto.Random +import Data.ByteString.Base64 +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B import qualified Data.Map.Strict as M import Data.Singletons +import Data.Time.Clock import Env.STM import MsgStore import MsgStore.STM (MsgQueue) @@ -29,9 +34,9 @@ import UnliftIO.Concurrent import UnliftIO.IO import UnliftIO.STM -runSMPServer :: MonadUnliftIO m => ServiceName -> Natural -> m () +runSMPServer :: (MonadRandom m, MonadUnliftIO m) => ServiceName -> Natural -> m () runSMPServer port queueSize = do - env <- atomically $ newEnv port queueSize + env <- newEnv port queueSize runReaderT smpServer env where smpServer :: (MonadUnliftIO m, MonadReader Env m) => m () @@ -131,12 +136,13 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = createConn :: MonadConnStore s m => s -> RecipientKey -> m Signed createConn st rKey = mkSigned "" <$> addSubscribe where - addSubscribe = - addConn st rKey >>= \case + addSubscribe = do + addConn st getIds rKey >>= \case Right Connection {recipientId = rId, senderId = sId} -> do void $ subscribeConn rId return $ IDS rId sId Left e -> return $ ERR e + getIds = liftM2 (,) (randomId 16) (randomId 16) subscribeConn :: RecipientId -> m Signed subscribeConn rId = do @@ -165,8 +171,9 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = ConnActive -> do ms <- asks msgStore q <- getMsgQueue ms (recipientId c) - msg <- newMessage msgBody - writeMsg q msg + msgId <- randomId 8 + ts <- liftIO getCurrentTime + writeMsg q $ Message {msgId, ts, msgBody} return OK ConnOff -> return $ ERR AUTH @@ -199,3 +206,15 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = msgResponse :: RecipientId -> Message -> Signed msgResponse rId Message {msgId, ts, msgBody} = mkSigned rId $ MSG msgId ts msgBody + +randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m Encoded +randomId n = do + gVar <- asks idsDrg + B.unpack . encode <$> atomically (randomBytes n gVar) + +randomBytes :: Int -> TVar ChaChaDRG -> STM ByteString +randomBytes n gVar = do + g <- readTVar gVar + let (bytes, g') = randomBytesGenerate n g + writeTVar gVar g' + return bytes diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 726be399f..9c18f719f 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -3,6 +3,7 @@ module SMPClient where import Control.Monad.IO.Unlift +import Crypto.Random import Network.Socket import Numeric.Natural import Server @@ -14,7 +15,7 @@ import UnliftIO.IO testSMPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a testSMPClient host port client = do - threadDelay 1 -- TODO hack: thread delay for SMP server to start + threadDelay 100 -- TODO hack: thread delay for SMP server to start runTCPClient host port $ \h -> do line <- getLn h if line == "Welcome to SMP" @@ -32,7 +33,7 @@ queueSize = 2 type TestTransmission = (Signature, ConnId, String) -runSmpTest :: MonadUnliftIO m => (Handle -> m a) -> m a +runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (Handle -> m a) -> m a runSmpTest test = E.bracket (forkIO $ runSMPServer testPort queueSize) diff --git a/tests/Test.hs b/tests/Test.hs index 6e7cdd407..9244fc46c 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -47,7 +47,7 @@ testCreateSecure = do Resp _ (MSG _ _ msg1) <- tGet fromServer h (msg1, "hello") #== "delivers message" - Resp _ ok4 <- sendRecv h ("123", "1", "ACK") + Resp _ ok4 <- sendRecv h ("123", rId, "ACK") (ok4, OK) #== "replies OK when message acknowledged if no more messages" Resp sId2 err1 <- sendRecv h ("456", sId, "SEND :hello") @@ -73,7 +73,7 @@ testCreateSecure = do Resp _ (MSG _ _ msg) <- tGet fromServer h (msg, "hello again") #== "delivers message 2" - Resp _ ok5 <- sendRecv h ("123", "1", "ACK") + Resp _ ok5 <- sendRecv h ("123", rId, "ACK") (ok5, OK) #== "replies OK when message acknowledged 2" Resp _ err5 <- sendRecv h ("", sId, "SEND :hello")