From caaa18a95a126f79ef33f1909f4b58503da1cf9f Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Wed, 21 Oct 2020 11:22:00 +0100 Subject: [PATCH] move connection store to STM --- src/ConnStore.hs | 6 ++--- src/ConnStore/STM.hs | 58 ++++++++++++++++++-------------------------- src/Env/STM.hs | 2 +- src/Server.hs | 40 ++++++++++++++++++++---------- src/Transmission.hs | 2 +- 5 files changed, 55 insertions(+), 53 deletions(-) diff --git a/src/ConnStore.hs b/src/ConnStore.hs index b180fadd5..8205c7030 100644 --- a/src/ConnStore.hs +++ b/src/ConnStore.hs @@ -18,15 +18,15 @@ data Connection = Connection data ConnStatus = ConnActive | ConnOff class MonadConnStore s m where - addConn :: s -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection) + addConn :: s -> RecipientKey -> (RecipientId, SenderId) -> m (Either ErrorType ()) getConn :: s -> SParty (a :: Party) -> ConnId -> m (Either ErrorType Connection) secureConn :: s -> RecipientId -> SenderKey -> m (Either ErrorType ()) suspendConn :: s -> RecipientId -> m (Either ErrorType ()) deleteConn :: s -> RecipientId -> m (Either ErrorType ()) -- TODO stub -mkConnection :: (RecipientId, SenderId) -> RecipientKey -> Connection -mkConnection (recipientId, senderId) recipientKey = +mkConnection :: RecipientKey -> (RecipientId, SenderId) -> Connection +mkConnection recipientKey (recipientId, senderId) = Connection { recipientId, senderId, diff --git a/src/ConnStore/STM.hs b/src/ConnStore/STM.hs index e87988965..0b6ad9c87 100644 --- a/src/ConnStore/STM.hs +++ b/src/ConnStore/STM.hs @@ -3,7 +3,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} @@ -13,7 +12,6 @@ module ConnStore.STM where import ConnStore -import Control.Monad.IO.Unlift import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Transmission @@ -24,39 +22,30 @@ data ConnStoreData = ConnStoreData senders :: Map SenderId RecipientId } -type STMConnStore = TVar ConnStoreData +type ConnStore = TVar ConnStoreData -newConnStore :: STM STMConnStore +newConnStore :: STM ConnStore newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty} -instance MonadUnliftIO m => MonadConnStore STMConnStore m where - addConn :: STMConnStore -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection) - addConn = _addConn (3 :: Int) - where - _addConn 0 _ _ _ = return $ Left INTERNAL - _addConn retry store getIds rKey = do - getIds >>= atomically . insertConn >>= \case - Nothing -> _addConn (retry - 1) store getIds rKey - Just c -> return $ Right c - where - insertConn ids@(rId, sId) = do - cs@ConnStoreData {connections, senders} <- readTVar store - if M.member rId connections || M.member sId senders - then return Nothing - else do - let c = mkConnection ids rKey - writeTVar store $ - cs - { connections = M.insert rId c connections, - senders = M.insert sId rId senders - } - return $ Just c +instance MonadConnStore ConnStore STM where + addConn :: ConnStore -> RecipientKey -> (RecipientId, SenderId) -> STM (Either ErrorType ()) + addConn store rKey ids@(rId, sId) = do + cs@ConnStoreData {connections, senders} <- readTVar store + if M.member rId connections || M.member sId senders + then return $ Left DUPLICATE + else do + writeTVar store $ + cs + { connections = M.insert rId (mkConnection rKey ids) connections, + senders = M.insert sId rId senders + } + return $ Right () - getConn :: STMConnStore -> SParty (p :: Party) -> ConnId -> m (Either ErrorType Connection) - getConn store SRecipient rId = atomically $ do + getConn :: ConnStore -> SParty (p :: Party) -> ConnId -> STM (Either ErrorType Connection) + getConn store SRecipient rId = do cs <- readTVar store return $ getRcpConn cs rId - getConn store SSender sId = atomically $ do + getConn store SSender sId = do cs <- readTVar store let rId = M.lookup sId $ senders cs return $ maybe (Left AUTH) (getRcpConn cs) rId @@ -69,12 +58,12 @@ instance MonadUnliftIO m => MonadConnStore STMConnStore m where Just _ -> (Left AUTH, cs) _ -> (Right (), cs {connections = M.insert rId c {senderKey = Just sKey} (connections cs)}) - suspendConn :: STMConnStore -> RecipientId -> m (Either ErrorType ()) + suspendConn :: ConnStore -> RecipientId -> STM (Either ErrorType ()) suspendConn store rId = updateConnections store rId $ \cs c -> (Right (), cs {connections = M.insert rId c {status = ConnOff} (connections cs)}) - deleteConn :: STMConnStore -> RecipientId -> m (Either ErrorType ()) + deleteConn :: ConnStore -> RecipientId -> STM (Either ErrorType ()) deleteConn store rId = updateConnections store rId $ \cs c -> ( Right (), @@ -85,12 +74,11 @@ instance MonadUnliftIO m => MonadConnStore STMConnStore m where ) updateConnections :: - MonadUnliftIO m => - STMConnStore -> + ConnStore -> RecipientId -> (ConnStoreData -> Connection -> (Either ErrorType (), ConnStoreData)) -> - m (Either ErrorType ()) -updateConnections store rId update = atomically $ do + STM (Either ErrorType ()) +updateConnections store rId update = do cs <- readTVar store let conn = getRcpConn cs rId either (return . Left) (_update cs) conn diff --git a/src/Env/STM.hs b/src/Env/STM.hs index 378df6232..6307add20 100644 --- a/src/Env/STM.hs +++ b/src/Env/STM.hs @@ -25,7 +25,7 @@ data Config = Config data Env = Env { config :: Config, server :: Server, - connStore :: STMConnStore, + connStore :: ConnStore, msgStore :: STMMsgStore, idsDrg :: TVar ChaChaDRG } diff --git a/src/Server.hs b/src/Server.hs index 0642ed391..dd1ad76ed 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -13,6 +13,7 @@ module Server (runSMPServer) where import ConnStore +import ConnStore.STM (ConnStore) import Control.Concurrent.STM (stateTVar) import Control.Monad import Control.Monad.IO.Unlift @@ -102,7 +103,7 @@ verifyTransmission signature connId cmd = do withConnection :: SParty (p :: Party) -> (Connection -> m Cmd) -> m Cmd withConnection party f = do store <- asks connStore - conn <- getConn store party connId + conn <- atomically $ getConn store party connId either (return . smpErr) f conn verifySend :: Maybe PublicKey -> m Cmd verifySend @@ -133,9 +134,9 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = CONN rKey -> createConn st rKey SUB -> subscribeConn connId ACK -> acknowledgeMsg - KEY sKey -> okResponse <$> secureConn st connId sKey - OFF -> okResponse <$> suspendConn st connId - DEL -> okResponse <$> deleteConn st connId + KEY sKey -> okResponse <$> atomically (secureConn st connId sKey) + OFF -> okResponse <$> atomically (suspendConn st connId) + DEL -> okResponse <$> atomically (deleteConn st connId) where ok :: Signed ok = (connId, Cmd SBroker OK) @@ -143,15 +144,26 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = okResponse :: Either ErrorType () -> Signed okResponse = mkSigned connId . either ERR (const OK) - createConn :: MonadConnStore s m => s -> RecipientKey -> m Signed + createConn :: ConnStore -> RecipientKey -> m Signed createConn st rKey = mkSigned B.empty <$> addSubscribe where addSubscribe = do - addConn st getIds rKey >>= \case - Right Connection {recipientId = rId, senderId = sId} -> do + addConnRetry 3 >>= \case + Left e -> return $ ERR e + Right (rId, sId) -> do void $ subscribeConn rId return $ IDS rId sId - Left e -> return $ ERR e + + addConnRetry :: Int -> m (Either ErrorType (RecipientId, SenderId)) + addConnRetry 0 = return $ Left INTERNAL + addConnRetry n = do + ids <- getIds + atomically (addConn st rKey ids) >>= \case + Left DUPLICATE -> addConnRetry $ n - 1 + Left e -> return $ Left e + Right _ -> return $ Right ids + + getIds :: m (RecipientId, SenderId) getIds = do n <- asks $ connIdBytes . config liftM2 (,) (randomId n) (randomId n) @@ -185,9 +197,9 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = Just () -> deliverMessage tryDelPeekMsg connId Nothing -> return . mkSigned connId $ ERR PROHIBITED - sendMessage :: MonadConnStore s m => s -> MsgBody -> m Signed + sendMessage :: ConnStore -> MsgBody -> m Signed sendMessage st msgBody = - getConn st SSender connId + atomically (getConn st SSender connId) >>= fmap (mkSigned connId) . either (return . ERR) storeMessage where mkMessage :: m Message @@ -200,9 +212,11 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = storeMessage c = case status c of ConnActive -> do ms <- asks msgStore - q <- atomically $ getMsgQueue ms (recipientId c) - mkMessage >>= atomically . writeMsg q - return OK + msg <- mkMessage + atomically $ do + q <- getMsgQueue ms (recipientId c) + writeMsg q msg + return OK ConnOff -> return $ ERR AUTH deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> m Signed diff --git a/src/Transmission.hs b/src/Transmission.hs index 672cd9a46..e7ef75833 100644 --- a/src/Transmission.hs +++ b/src/Transmission.hs @@ -155,7 +155,7 @@ type MsgId = Encoded type MsgBody = ByteString -data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL deriving (Show, Eq) +data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) errBadTransmission :: Int errBadTransmission = 1