mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 16:26:02 +00:00
move connection store to STM
This commit is contained in:
@@ -18,15 +18,15 @@ data Connection = Connection
|
||||
data ConnStatus = ConnActive | ConnOff
|
||||
|
||||
class MonadConnStore s m where
|
||||
addConn :: s -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection)
|
||||
addConn :: s -> RecipientKey -> (RecipientId, SenderId) -> m (Either ErrorType ())
|
||||
getConn :: s -> SParty (a :: Party) -> ConnId -> m (Either ErrorType Connection)
|
||||
secureConn :: s -> RecipientId -> SenderKey -> m (Either ErrorType ())
|
||||
suspendConn :: s -> RecipientId -> m (Either ErrorType ())
|
||||
deleteConn :: s -> RecipientId -> m (Either ErrorType ())
|
||||
|
||||
-- TODO stub
|
||||
mkConnection :: (RecipientId, SenderId) -> RecipientKey -> Connection
|
||||
mkConnection (recipientId, senderId) recipientKey =
|
||||
mkConnection :: RecipientKey -> (RecipientId, SenderId) -> Connection
|
||||
mkConnection recipientKey (recipientId, senderId) =
|
||||
Connection
|
||||
{ recipientId,
|
||||
senderId,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
@@ -13,7 +12,6 @@
|
||||
module ConnStore.STM where
|
||||
|
||||
import ConnStore
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Transmission
|
||||
@@ -24,39 +22,30 @@ data ConnStoreData = ConnStoreData
|
||||
senders :: Map SenderId RecipientId
|
||||
}
|
||||
|
||||
type STMConnStore = TVar ConnStoreData
|
||||
type ConnStore = TVar ConnStoreData
|
||||
|
||||
newConnStore :: STM STMConnStore
|
||||
newConnStore :: STM ConnStore
|
||||
newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty}
|
||||
|
||||
instance MonadUnliftIO m => MonadConnStore STMConnStore m where
|
||||
addConn :: STMConnStore -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection)
|
||||
addConn = _addConn (3 :: Int)
|
||||
where
|
||||
_addConn 0 _ _ _ = return $ Left INTERNAL
|
||||
_addConn retry store getIds rKey = do
|
||||
getIds >>= atomically . insertConn >>= \case
|
||||
Nothing -> _addConn (retry - 1) store getIds rKey
|
||||
Just c -> return $ Right c
|
||||
where
|
||||
insertConn ids@(rId, sId) = do
|
||||
cs@ConnStoreData {connections, senders} <- readTVar store
|
||||
if M.member rId connections || M.member sId senders
|
||||
then return Nothing
|
||||
else do
|
||||
let c = mkConnection ids rKey
|
||||
writeTVar store $
|
||||
cs
|
||||
{ connections = M.insert rId c connections,
|
||||
senders = M.insert sId rId senders
|
||||
}
|
||||
return $ Just c
|
||||
instance MonadConnStore ConnStore STM where
|
||||
addConn :: ConnStore -> RecipientKey -> (RecipientId, SenderId) -> STM (Either ErrorType ())
|
||||
addConn store rKey ids@(rId, sId) = do
|
||||
cs@ConnStoreData {connections, senders} <- readTVar store
|
||||
if M.member rId connections || M.member sId senders
|
||||
then return $ Left DUPLICATE
|
||||
else do
|
||||
writeTVar store $
|
||||
cs
|
||||
{ connections = M.insert rId (mkConnection rKey ids) connections,
|
||||
senders = M.insert sId rId senders
|
||||
}
|
||||
return $ Right ()
|
||||
|
||||
getConn :: STMConnStore -> SParty (p :: Party) -> ConnId -> m (Either ErrorType Connection)
|
||||
getConn store SRecipient rId = atomically $ do
|
||||
getConn :: ConnStore -> SParty (p :: Party) -> ConnId -> STM (Either ErrorType Connection)
|
||||
getConn store SRecipient rId = do
|
||||
cs <- readTVar store
|
||||
return $ getRcpConn cs rId
|
||||
getConn store SSender sId = atomically $ do
|
||||
getConn store SSender sId = do
|
||||
cs <- readTVar store
|
||||
let rId = M.lookup sId $ senders cs
|
||||
return $ maybe (Left AUTH) (getRcpConn cs) rId
|
||||
@@ -69,12 +58,12 @@ instance MonadUnliftIO m => MonadConnStore STMConnStore m where
|
||||
Just _ -> (Left AUTH, cs)
|
||||
_ -> (Right (), cs {connections = M.insert rId c {senderKey = Just sKey} (connections cs)})
|
||||
|
||||
suspendConn :: STMConnStore -> RecipientId -> m (Either ErrorType ())
|
||||
suspendConn :: ConnStore -> RecipientId -> STM (Either ErrorType ())
|
||||
suspendConn store rId =
|
||||
updateConnections store rId $ \cs c ->
|
||||
(Right (), cs {connections = M.insert rId c {status = ConnOff} (connections cs)})
|
||||
|
||||
deleteConn :: STMConnStore -> RecipientId -> m (Either ErrorType ())
|
||||
deleteConn :: ConnStore -> RecipientId -> STM (Either ErrorType ())
|
||||
deleteConn store rId =
|
||||
updateConnections store rId $ \cs c ->
|
||||
( Right (),
|
||||
@@ -85,12 +74,11 @@ instance MonadUnliftIO m => MonadConnStore STMConnStore m where
|
||||
)
|
||||
|
||||
updateConnections ::
|
||||
MonadUnliftIO m =>
|
||||
STMConnStore ->
|
||||
ConnStore ->
|
||||
RecipientId ->
|
||||
(ConnStoreData -> Connection -> (Either ErrorType (), ConnStoreData)) ->
|
||||
m (Either ErrorType ())
|
||||
updateConnections store rId update = atomically $ do
|
||||
STM (Either ErrorType ())
|
||||
updateConnections store rId update = do
|
||||
cs <- readTVar store
|
||||
let conn = getRcpConn cs rId
|
||||
either (return . Left) (_update cs) conn
|
||||
|
||||
@@ -25,7 +25,7 @@ data Config = Config
|
||||
data Env = Env
|
||||
{ config :: Config,
|
||||
server :: Server,
|
||||
connStore :: STMConnStore,
|
||||
connStore :: ConnStore,
|
||||
msgStore :: STMMsgStore,
|
||||
idsDrg :: TVar ChaChaDRG
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
module Server (runSMPServer) where
|
||||
|
||||
import ConnStore
|
||||
import ConnStore.STM (ConnStore)
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
@@ -102,7 +103,7 @@ verifyTransmission signature connId cmd = do
|
||||
withConnection :: SParty (p :: Party) -> (Connection -> m Cmd) -> m Cmd
|
||||
withConnection party f = do
|
||||
store <- asks connStore
|
||||
conn <- getConn store party connId
|
||||
conn <- atomically $ getConn store party connId
|
||||
either (return . smpErr) f conn
|
||||
verifySend :: Maybe PublicKey -> m Cmd
|
||||
verifySend
|
||||
@@ -133,9 +134,9 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
CONN rKey -> createConn st rKey
|
||||
SUB -> subscribeConn connId
|
||||
ACK -> acknowledgeMsg
|
||||
KEY sKey -> okResponse <$> secureConn st connId sKey
|
||||
OFF -> okResponse <$> suspendConn st connId
|
||||
DEL -> okResponse <$> deleteConn st connId
|
||||
KEY sKey -> okResponse <$> atomically (secureConn st connId sKey)
|
||||
OFF -> okResponse <$> atomically (suspendConn st connId)
|
||||
DEL -> okResponse <$> atomically (deleteConn st connId)
|
||||
where
|
||||
ok :: Signed
|
||||
ok = (connId, Cmd SBroker OK)
|
||||
@@ -143,15 +144,26 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
okResponse :: Either ErrorType () -> Signed
|
||||
okResponse = mkSigned connId . either ERR (const OK)
|
||||
|
||||
createConn :: MonadConnStore s m => s -> RecipientKey -> m Signed
|
||||
createConn :: ConnStore -> RecipientKey -> m Signed
|
||||
createConn st rKey = mkSigned B.empty <$> addSubscribe
|
||||
where
|
||||
addSubscribe = do
|
||||
addConn st getIds rKey >>= \case
|
||||
Right Connection {recipientId = rId, senderId = sId} -> do
|
||||
addConnRetry 3 >>= \case
|
||||
Left e -> return $ ERR e
|
||||
Right (rId, sId) -> do
|
||||
void $ subscribeConn rId
|
||||
return $ IDS rId sId
|
||||
Left e -> return $ ERR e
|
||||
|
||||
addConnRetry :: Int -> m (Either ErrorType (RecipientId, SenderId))
|
||||
addConnRetry 0 = return $ Left INTERNAL
|
||||
addConnRetry n = do
|
||||
ids <- getIds
|
||||
atomically (addConn st rKey ids) >>= \case
|
||||
Left DUPLICATE -> addConnRetry $ n - 1
|
||||
Left e -> return $ Left e
|
||||
Right _ -> return $ Right ids
|
||||
|
||||
getIds :: m (RecipientId, SenderId)
|
||||
getIds = do
|
||||
n <- asks $ connIdBytes . config
|
||||
liftM2 (,) (randomId n) (randomId n)
|
||||
@@ -185,9 +197,9 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
Just () -> deliverMessage tryDelPeekMsg connId
|
||||
Nothing -> return . mkSigned connId $ ERR PROHIBITED
|
||||
|
||||
sendMessage :: MonadConnStore s m => s -> MsgBody -> m Signed
|
||||
sendMessage :: ConnStore -> MsgBody -> m Signed
|
||||
sendMessage st msgBody =
|
||||
getConn st SSender connId
|
||||
atomically (getConn st SSender connId)
|
||||
>>= fmap (mkSigned connId) . either (return . ERR) storeMessage
|
||||
where
|
||||
mkMessage :: m Message
|
||||
@@ -200,9 +212,11 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
storeMessage c = case status c of
|
||||
ConnActive -> do
|
||||
ms <- asks msgStore
|
||||
q <- atomically $ getMsgQueue ms (recipientId c)
|
||||
mkMessage >>= atomically . writeMsg q
|
||||
return OK
|
||||
msg <- mkMessage
|
||||
atomically $ do
|
||||
q <- getMsgQueue ms (recipientId c)
|
||||
writeMsg q msg
|
||||
return OK
|
||||
ConnOff -> return $ ERR AUTH
|
||||
|
||||
deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> m Signed
|
||||
|
||||
@@ -155,7 +155,7 @@ type MsgId = Encoded
|
||||
|
||||
type MsgBody = ByteString
|
||||
|
||||
data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL deriving (Show, Eq)
|
||||
data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq)
|
||||
|
||||
errBadTransmission :: Int
|
||||
errBadTransmission = 1
|
||||
|
||||
Reference in New Issue
Block a user