move connection store to STM

This commit is contained in:
Evgeny Poberezkin
2020-10-21 11:22:00 +01:00
parent 0c17422fa1
commit caaa18a95a
5 changed files with 55 additions and 53 deletions

View File

@@ -18,15 +18,15 @@ data Connection = Connection
data ConnStatus = ConnActive | ConnOff
class MonadConnStore s m where
addConn :: s -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection)
addConn :: s -> RecipientKey -> (RecipientId, SenderId) -> m (Either ErrorType ())
getConn :: s -> SParty (a :: Party) -> ConnId -> m (Either ErrorType Connection)
secureConn :: s -> RecipientId -> SenderKey -> m (Either ErrorType ())
suspendConn :: s -> RecipientId -> m (Either ErrorType ())
deleteConn :: s -> RecipientId -> m (Either ErrorType ())
-- TODO stub
mkConnection :: (RecipientId, SenderId) -> RecipientKey -> Connection
mkConnection (recipientId, senderId) recipientKey =
mkConnection :: RecipientKey -> (RecipientId, SenderId) -> Connection
mkConnection recipientKey (recipientId, senderId) =
Connection
{ recipientId,
senderId,

View File

@@ -3,7 +3,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
@@ -13,7 +12,6 @@
module ConnStore.STM where
import ConnStore
import Control.Monad.IO.Unlift
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Transmission
@@ -24,39 +22,30 @@ data ConnStoreData = ConnStoreData
senders :: Map SenderId RecipientId
}
type STMConnStore = TVar ConnStoreData
type ConnStore = TVar ConnStoreData
newConnStore :: STM STMConnStore
newConnStore :: STM ConnStore
newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty}
instance MonadUnliftIO m => MonadConnStore STMConnStore m where
addConn :: STMConnStore -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection)
addConn = _addConn (3 :: Int)
where
_addConn 0 _ _ _ = return $ Left INTERNAL
_addConn retry store getIds rKey = do
getIds >>= atomically . insertConn >>= \case
Nothing -> _addConn (retry - 1) store getIds rKey
Just c -> return $ Right c
where
insertConn ids@(rId, sId) = do
cs@ConnStoreData {connections, senders} <- readTVar store
if M.member rId connections || M.member sId senders
then return Nothing
else do
let c = mkConnection ids rKey
writeTVar store $
cs
{ connections = M.insert rId c connections,
senders = M.insert sId rId senders
}
return $ Just c
instance MonadConnStore ConnStore STM where
addConn :: ConnStore -> RecipientKey -> (RecipientId, SenderId) -> STM (Either ErrorType ())
addConn store rKey ids@(rId, sId) = do
cs@ConnStoreData {connections, senders} <- readTVar store
if M.member rId connections || M.member sId senders
then return $ Left DUPLICATE
else do
writeTVar store $
cs
{ connections = M.insert rId (mkConnection rKey ids) connections,
senders = M.insert sId rId senders
}
return $ Right ()
getConn :: STMConnStore -> SParty (p :: Party) -> ConnId -> m (Either ErrorType Connection)
getConn store SRecipient rId = atomically $ do
getConn :: ConnStore -> SParty (p :: Party) -> ConnId -> STM (Either ErrorType Connection)
getConn store SRecipient rId = do
cs <- readTVar store
return $ getRcpConn cs rId
getConn store SSender sId = atomically $ do
getConn store SSender sId = do
cs <- readTVar store
let rId = M.lookup sId $ senders cs
return $ maybe (Left AUTH) (getRcpConn cs) rId
@@ -69,12 +58,12 @@ instance MonadUnliftIO m => MonadConnStore STMConnStore m where
Just _ -> (Left AUTH, cs)
_ -> (Right (), cs {connections = M.insert rId c {senderKey = Just sKey} (connections cs)})
suspendConn :: STMConnStore -> RecipientId -> m (Either ErrorType ())
suspendConn :: ConnStore -> RecipientId -> STM (Either ErrorType ())
suspendConn store rId =
updateConnections store rId $ \cs c ->
(Right (), cs {connections = M.insert rId c {status = ConnOff} (connections cs)})
deleteConn :: STMConnStore -> RecipientId -> m (Either ErrorType ())
deleteConn :: ConnStore -> RecipientId -> STM (Either ErrorType ())
deleteConn store rId =
updateConnections store rId $ \cs c ->
( Right (),
@@ -85,12 +74,11 @@ instance MonadUnliftIO m => MonadConnStore STMConnStore m where
)
updateConnections ::
MonadUnliftIO m =>
STMConnStore ->
ConnStore ->
RecipientId ->
(ConnStoreData -> Connection -> (Either ErrorType (), ConnStoreData)) ->
m (Either ErrorType ())
updateConnections store rId update = atomically $ do
STM (Either ErrorType ())
updateConnections store rId update = do
cs <- readTVar store
let conn = getRcpConn cs rId
either (return . Left) (_update cs) conn

View File

@@ -25,7 +25,7 @@ data Config = Config
data Env = Env
{ config :: Config,
server :: Server,
connStore :: STMConnStore,
connStore :: ConnStore,
msgStore :: STMMsgStore,
idsDrg :: TVar ChaChaDRG
}

View File

@@ -13,6 +13,7 @@
module Server (runSMPServer) where
import ConnStore
import ConnStore.STM (ConnStore)
import Control.Concurrent.STM (stateTVar)
import Control.Monad
import Control.Monad.IO.Unlift
@@ -102,7 +103,7 @@ verifyTransmission signature connId cmd = do
withConnection :: SParty (p :: Party) -> (Connection -> m Cmd) -> m Cmd
withConnection party f = do
store <- asks connStore
conn <- getConn store party connId
conn <- atomically $ getConn store party connId
either (return . smpErr) f conn
verifySend :: Maybe PublicKey -> m Cmd
verifySend
@@ -133,9 +134,9 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
CONN rKey -> createConn st rKey
SUB -> subscribeConn connId
ACK -> acknowledgeMsg
KEY sKey -> okResponse <$> secureConn st connId sKey
OFF -> okResponse <$> suspendConn st connId
DEL -> okResponse <$> deleteConn st connId
KEY sKey -> okResponse <$> atomically (secureConn st connId sKey)
OFF -> okResponse <$> atomically (suspendConn st connId)
DEL -> okResponse <$> atomically (deleteConn st connId)
where
ok :: Signed
ok = (connId, Cmd SBroker OK)
@@ -143,15 +144,26 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
okResponse :: Either ErrorType () -> Signed
okResponse = mkSigned connId . either ERR (const OK)
createConn :: MonadConnStore s m => s -> RecipientKey -> m Signed
createConn :: ConnStore -> RecipientKey -> m Signed
createConn st rKey = mkSigned B.empty <$> addSubscribe
where
addSubscribe = do
addConn st getIds rKey >>= \case
Right Connection {recipientId = rId, senderId = sId} -> do
addConnRetry 3 >>= \case
Left e -> return $ ERR e
Right (rId, sId) -> do
void $ subscribeConn rId
return $ IDS rId sId
Left e -> return $ ERR e
addConnRetry :: Int -> m (Either ErrorType (RecipientId, SenderId))
addConnRetry 0 = return $ Left INTERNAL
addConnRetry n = do
ids <- getIds
atomically (addConn st rKey ids) >>= \case
Left DUPLICATE -> addConnRetry $ n - 1
Left e -> return $ Left e
Right _ -> return $ Right ids
getIds :: m (RecipientId, SenderId)
getIds = do
n <- asks $ connIdBytes . config
liftM2 (,) (randomId n) (randomId n)
@@ -185,9 +197,9 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
Just () -> deliverMessage tryDelPeekMsg connId
Nothing -> return . mkSigned connId $ ERR PROHIBITED
sendMessage :: MonadConnStore s m => s -> MsgBody -> m Signed
sendMessage :: ConnStore -> MsgBody -> m Signed
sendMessage st msgBody =
getConn st SSender connId
atomically (getConn st SSender connId)
>>= fmap (mkSigned connId) . either (return . ERR) storeMessage
where
mkMessage :: m Message
@@ -200,9 +212,11 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
storeMessage c = case status c of
ConnActive -> do
ms <- asks msgStore
q <- atomically $ getMsgQueue ms (recipientId c)
mkMessage >>= atomically . writeMsg q
return OK
msg <- mkMessage
atomically $ do
q <- getMsgQueue ms (recipientId c)
writeMsg q msg
return OK
ConnOff -> return $ ERR AUTH
deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> m Signed

View File

@@ -155,7 +155,7 @@ type MsgId = Encoded
type MsgBody = ByteString
data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL deriving (Show, Eq)
data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq)
errBadTransmission :: Int
errBadTransmission = 1