diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 5fe48baaf..5187c0d73 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -111,15 +111,12 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile} return Env {config, server, serverIdentity, queueStore, msgStore, idsDrg, storeLog = s', tlsServerParams} where restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode) - restoreQueues queueStore s = do - (queues, s') <- liftIO $ readWriteStoreLog s - atomically $ - modifyTVar' queueStore $ \d -> - d - { queues, - senders = M.foldr' addSender M.empty queues, - notifiers = M.foldr' addNotifier M.empty queues - } + restoreQueues QueueStore {queues, senders, notifiers} s = do + (qs, s') <- liftIO $ readWriteStoreLog s + atomically $ do + writeTVar (TM.tVar queues) =<< mapM newTVar qs + writeTVar (TM.tVar senders) $ M.foldr' addSender M.empty qs + writeTVar (TM.tVar notifiers) $ M.foldr' addNotifier M.empty qs pure s' addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 9ebe55bbf..86d6db996 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -6,36 +6,31 @@ module Simplex.Messaging.Server.MsgStore.STM where -import Data.Map.Strict (Map) -import qualified Data.Map.Strict as M import Numeric.Natural import Simplex.Messaging.Protocol (RecipientId) import Simplex.Messaging.Server.MsgStore +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM import UnliftIO.STM newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message} -newtype MsgStoreData = MsgStoreData {messages :: Map RecipientId MsgQueue} - -type STMMsgStore = TVar MsgStoreData +type STMMsgStore = TMap RecipientId MsgQueue newMsgStore :: STM STMMsgStore -newMsgStore = newTVar $ MsgStoreData M.empty +newMsgStore = TM.empty instance MonadMsgStore STMMsgStore MsgQueue STM where getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue - getMsgQueue store rId quota = do - m <- messages <$> readTVar store - maybe (newQ m) return $ M.lookup rId m + getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st where - newQ m' = do + newQ = do q <- MsgQueue <$> newTBQueue quota - writeTVar store . MsgStoreData $ M.insert rId q m' + TM.insert rId q st return q delMsgQueue :: STMMsgStore -> RecipientId -> STM () - delMsgQueue store rId = - modifyTVar' store $ MsgStoreData . M.delete rId . messages + delMsgQueue st rId = TM.delete rId st instance MonadMsgQueue MsgQueue STM where isFull :: MsgQueue -> STM Bool diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index ed859422a..544bf35b9 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -15,6 +15,7 @@ data QueueRec = QueueRec notifier :: Maybe (NotifierId, NtfPublicVerifyKey), status :: QueueStatus } + deriving (Eq, Show) data QueueStatus = QueueActive | QueueOff deriving (Eq, Show) diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index b3424f6e8..401e7ee30 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -3,6 +3,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} @@ -11,107 +12,83 @@ module Simplex.Messaging.Server.QueueStore.STM where -import Data.Map.Strict (Map) -import qualified Data.Map.Strict as M +import Control.Monad +import Data.Functor (($>)) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Util (ifM) import UnliftIO.STM -data QueueStoreData = QueueStoreData - { queues :: Map RecipientId QueueRec, - senders :: Map SenderId RecipientId, - notifiers :: Map NotifierId RecipientId +data QueueStore = QueueStore + { queues :: TMap RecipientId (TVar QueueRec), + senders :: TMap SenderId RecipientId, + notifiers :: TMap NotifierId RecipientId } -type QueueStore = TVar QueueStoreData - newQueueStore :: STM QueueStore -newQueueStore = newTVar QueueStoreData {queues = M.empty, senders = M.empty, notifiers = M.empty} +newQueueStore = do + queues <- TM.empty + senders <- TM.empty + notifiers <- TM.empty + pure QueueStore {queues, senders, notifiers} instance MonadQueueStore QueueStore STM where addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) - addQueue store qRec@QueueRec {recipientId = rId, senderId = sId} = do - cs@QueueStoreData {queues, senders} <- readTVar store - if M.member rId queues || M.member sId senders - then return $ Left DUPLICATE_ - else do - writeTVar store $ - cs - { queues = M.insert rId qRec queues, - senders = M.insert sId rId senders - } - return $ Right () + addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do + ifM hasId (pure $ Left DUPLICATE_) $ do + qVar <- newTVar q + TM.insert rId qVar queues + TM.insert sId rId senders + pure $ Right () + where + hasId = (||) <$> TM.member rId queues <*> TM.member sId senders getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec) - getQueue st party qId = do - cs <- readTVar st - pure $ case party of - SRecipient -> getRcpQueue cs qId - SSender -> getPartyQueue cs senders - SNotifier -> getPartyQueue cs notifiers + getQueue QueueStore {queues, senders, notifiers} party qId = + toResult <$> (mapM readTVar =<< getVar) where - getPartyQueue :: - QueueStoreData -> - (QueueStoreData -> Map QueueId RecipientId) -> - Either ErrorType QueueRec - getPartyQueue cs recipientIds = - case M.lookup qId $ recipientIds cs of - Just rId -> getRcpQueue cs rId - Nothing -> Left AUTH + getVar = case party of + SRecipient -> TM.lookup qId queues + SSender -> TM.lookup qId senders >>= get + SNotifier -> TM.lookup qId notifiers >>= get + get = fmap join . mapM (`TM.lookup` queues) secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec) - secureQueue store rId sKey = - updateQueues store rId $ \cs c -> - case senderKey c of - Just _ -> (Left AUTH, cs) - _ -> (Right c, cs {queues = M.insert rId c {senderKey = Just sKey} (queues cs)}) + secureQueue QueueStore {queues} rId sKey = + withQueue rId queues $ \qVar -> + readTVar qVar >>= \q -> case senderKey q of + Just _ -> pure Nothing + _ -> writeTVar qVar q {senderKey = Just sKey} $> Just q addQueueNotifier :: QueueStore -> RecipientId -> NotifierId -> NtfPublicVerifyKey -> STM (Either ErrorType QueueRec) - addQueueNotifier store rId nId nKey = do - cs@QueueStoreData {queues, notifiers} <- readTVar store - if M.member nId notifiers - then pure $ Left DUPLICATE_ - else case M.lookup rId queues of - Nothing -> pure $ Left AUTH - Just q -> case notifier q of - Just _ -> pure $ Left AUTH + addQueueNotifier QueueStore {queues, notifiers} rId nId nKey = do + ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $ + withQueue rId queues $ \qVar -> + readTVar qVar >>= \q -> case notifier q of + Just _ -> pure Nothing _ -> do - writeTVar store $ - cs - { queues = M.insert rId q {notifier = Just (nId, nKey)} queues, - notifiers = M.insert nId rId notifiers - } - pure $ Right q + writeTVar qVar q {notifier = Just (nId, nKey)} + TM.insert nId rId notifiers + pure $ Just q suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - suspendQueue store rId = - updateQueues store rId $ \cs c -> - (Right (), cs {queues = M.insert rId c {status = QueueOff} (queues cs)}) + suspendQueue QueueStore {queues} rId = + withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just () deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - deleteQueue store rId = - updateQueues store rId $ \cs c -> - ( Right (), - cs - { queues = M.delete rId (queues cs), - senders = M.delete (senderId c) (senders cs) - } - ) + deleteQueue QueueStore {queues, senders, notifiers} rId = do + TM.lookupDelete rId queues >>= \case + Just qVar -> + readTVar qVar >>= \q -> do + TM.delete (senderId q) senders + forM_ (notifier q) $ \(nId, _) -> TM.delete nId notifiers + pure $ Right () + _ -> pure $ Left AUTH -updateQueues :: - QueueStore -> - RecipientId -> - (QueueStoreData -> QueueRec -> (Either ErrorType a, QueueStoreData)) -> - STM (Either ErrorType a) -updateQueues store rId update = do - cs <- readTVar store - let conn = getRcpQueue cs rId - either (return . Left) (_update cs) conn - where - _update cs c = do - let (res, cs') = update cs c - writeTVar store cs' - return res +toResult :: Maybe a -> Either ErrorType a +toResult = maybe (Left AUTH) Right -getRcpQueue :: QueueStoreData -> RecipientId -> Either ErrorType QueueRec -getRcpQueue cs rId = maybe (Left AUTH) Right . M.lookup rId $ queues cs +withQueue :: RecipientId -> TMap RecipientId (TVar QueueRec) -> (TVar QueueRec -> STM (Maybe a)) -> STM (Either ErrorType a) +withQueue rId queues f = toResult <$> (TM.lookup rId queues >>= fmap join . mapM f) diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index de9b293f0..012adde4b 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -2,6 +2,7 @@ module Simplex.Messaging.TMap ( TMap (..), empty, Simplex.Messaging.TMap.lookup, + member, insert, delete, lookupInsert, @@ -26,6 +27,10 @@ lookup :: Ord k => k -> TMap k a -> STM (Maybe a) lookup k (TMap m) = M.lookup k <$> readTVar m {-# INLINE lookup #-} +member :: Ord k => k -> TMap k a -> STM Bool +member k (TMap m) = M.member k <$> readTVar m +{-# INLINE member #-} + insert :: Ord k => k -> a -> TMap k a -> STM () insert k v (TMap m) = modifyTVar' m $ M.insert k v {-# INLINE insert #-}