From 1f126972799fcd8aa86e0ebbd025bcab6f8367c2 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Thu, 12 Jan 2023 14:59:46 +0000 Subject: [PATCH] strict writes to STM, remove type class (#600) --- src/Simplex/Messaging/Agent.hs | 2 +- src/Simplex/Messaging/Agent/Client.hs | 10 +- src/Simplex/Messaging/Crypto.hs | 6 +- src/Simplex/Messaging/Notifications/Server.hs | 2 +- .../Messaging/Notifications/Server/Stats.hs | 16 +-- src/Simplex/Messaging/Server.hs | 20 +-- src/Simplex/Messaging/Server/Env/STM.hs | 4 +- src/Simplex/Messaging/Server/MsgStore.hs | 17 +-- src/Simplex/Messaging/Server/MsgStore/STM.hs | 131 +++++++++--------- src/Simplex/Messaging/Server/QueueStore.hs | 29 ++-- .../Messaging/Server/QueueStore/STM.hs | 119 ++++++++-------- src/Simplex/Messaging/Server/Stats.hs | 20 +-- src/Simplex/Messaging/Transport.hs | 4 +- 13 files changed, 183 insertions(+), 197 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 444a9a63b..632610690 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -905,7 +905,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do atomically $ do srvs <- readTVar $ smpServers c let used' = if length used + 1 >= L.length srvs then initUsed else srv : used - writeTVar usedSrvs used' + writeTVar usedSrvs $! used' action srvAuth -- ^ ^ ^ async command processing / diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9a3642767..a94d866d0 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -607,7 +607,7 @@ subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED atomically $ do - modifyTVar (subscrConns c) $ S.insert connId + modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c withLogClient c server rcvId "SUB" $ \smp -> liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq) @@ -644,7 +644,7 @@ subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m ( subscribeQueues c srv qs = do (errs, qs_) <- partitionEithers <$> mapM checkQueue qs forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do - modifyTVar (subscrConns c) $ S.insert connId + modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c case L.nonEmpty qs_ of Just qs' -> do @@ -671,7 +671,7 @@ subscribeQueues c srv qs = do addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m () addSubscription c rq@RcvQueue {connId} = atomically $ do - modifyTVar (subscrConns c) $ S.insert connId + modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ activeSubs c RQ.deleteQueue rq $ pendingSubs c @@ -680,7 +680,7 @@ hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c removeSubscription :: AgentClient -> ConnId -> STM () removeSubscription c connId = do - modifyTVar (subscrConns c) $ S.delete connId + modifyTVar' (subscrConns c) $ S.delete connId RQ.deleteConn connId $ activeSubs c RQ.deleteConn connId $ pendingSubs c @@ -945,7 +945,7 @@ storeError = \case incStat :: AgentClient -> Int -> AgentStatsKey -> STM () incStat AgentClient {agentStats} n k = do TM.lookup k agentStats >>= \case - Just v -> modifyTVar v (+ n) + Just v -> modifyTVar' v (+ n) _ -> newTVar n >>= \v -> TM.insert k v agentStats incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO () diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index bfbd6b584..6b56a7847 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -964,11 +964,7 @@ pseudoRandomCbNonce :: TVar ChaChaDRG -> STM CbNonce pseudoRandomCbNonce gVar = CbNonce <$> pseudoRandomBytes 24 gVar pseudoRandomBytes :: Int -> TVar ChaChaDRG -> STM ByteString -pseudoRandomBytes n gVar = do - g <- readTVar gVar - let (bytes, g') = randomBytesGenerate n g - writeTVar gVar g' - return bytes +pseudoRandomBytes n gVar = stateTVar gVar $ randomBytesGenerate n instance Encoding CbNonce where smpEncode = unCbNonce diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 43b0d33de..80069a534 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -548,7 +548,7 @@ withNtfLog action = liftIO . mapM_ action =<< asks storeLog incNtfStat :: (NtfServerStats -> TVar Int) -> M () incNtfStat statSel = do stats <- asks serverStats - atomically $ modifyTVar (statSel stats) (+ 1) + atomically $ modifyTVar' (statSel stats) (+ 1) saveServerStats :: M () saveServerStats = diff --git a/src/Simplex/Messaging/Notifications/Server/Stats.hs b/src/Simplex/Messaging/Notifications/Server/Stats.hs index 6af4b0611..10703d284 100644 --- a/src/Simplex/Messaging/Notifications/Server/Stats.hs +++ b/src/Simplex/Messaging/Notifications/Server/Stats.hs @@ -70,14 +70,14 @@ getNtfServerStatsData s = do setNtfServerStats :: NtfServerStats -> NtfServerStatsData -> STM () setNtfServerStats s d = do - writeTVar (fromTime (s :: NtfServerStats)) (_fromTime (d :: NtfServerStatsData)) - writeTVar (tknCreated s) (_tknCreated d) - writeTVar (tknVerified s) (_tknVerified d) - writeTVar (tknDeleted s) (_tknDeleted d) - writeTVar (subCreated s) (_subCreated d) - writeTVar (subDeleted s) (_subDeleted d) - writeTVar (ntfReceived s) (_ntfReceived d) - writeTVar (ntfDelivered s) (_ntfDelivered d) + writeTVar (fromTime (s :: NtfServerStats)) $! _fromTime (d :: NtfServerStatsData) + writeTVar (tknCreated s) $! _tknCreated d + writeTVar (tknVerified s) $! _tknVerified d + writeTVar (tknDeleted s) $! _tknDeleted d + writeTVar (subCreated s) $! _subCreated d + writeTVar (subDeleted s) $! _subDeleted d + writeTVar (ntfReceived s) $! _ntfReceived d + writeTVar (ntfDelivered s) $! _ntfDelivered d setPeriodStats (activeTokens s) (_activeTokens d) setPeriodStats (activeSubs s) (_activeSubs d) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 52c3558ed..9f5c24a9f 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -65,9 +65,9 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.MsgStore -import Simplex.Messaging.Server.MsgStore.STM (MsgQueue) +import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.QueueStore -import Simplex.Messaging.Server.QueueStore.STM (QueueStore) +import Simplex.Messaging.Server.QueueStore.STM import Simplex.Messaging.Server.Stats import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) @@ -386,7 +386,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv Right _ -> do withLog (`logCreateById` rId) stats <- asks serverStats - atomically $ modifyTVar (qCreated stats) (+ 1) + atomically $ modifyTVar' (qCreated stats) (+ 1) subscribeQueue qr rId $> IDS (qik ids) logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO () @@ -404,7 +404,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv secureQueue_ st sKey = time "KEY" $ do withLog $ \s -> logSecureQueue s queueId sKey stats <- asks serverStats - atomically $ modifyTVar (qSecured stats) (+ 1) + atomically $ modifyTVar' (qSecured stats) (+ 1) atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg) @@ -528,7 +528,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv updateStats :: m () updateStats = do stats <- asks serverStats - atomically $ modifyTVar (msgRecv stats) (+ 1) + atomically $ modifyTVar' (msgRecv stats) (+ 1) atomically $ updatePeriodStats (activeQueues stats) queueId sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg) @@ -550,7 +550,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv when (notification msgFlags) $ atomically . trySendNotification msg =<< asks idsDrg stats <- asks serverStats - atomically $ modifyTVar (msgSent stats) (+ 1) + atomically $ modifyTVar' (msgSent stats) (+ 1) atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) pure ok where @@ -599,9 +599,9 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv where forkSub :: m () forkSub = do - atomically . modifyTVar sub $ \s -> s {subThread = SubPending} + atomically . modifyTVar' sub $ \s -> s {subThread = SubPending} t <- mkWeakThreadId =<< forkIO subscriber - atomically . modifyTVar sub $ \case + atomically . modifyTVar' sub $ \case s@Sub {subThread = SubPending} -> s {subThread = SubThread t} s -> s where @@ -612,7 +612,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)] s <- readTVar sub void $ setDelivered s msg - writeTVar sub s {subThread = NoSub} + writeTVar sub $! s {subThread = NoSub} time :: T.Text -> m a -> m a time name = timed name queueId @@ -646,7 +646,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv withLog (`logDeleteQueue` queueId) ms <- asks msgStore stats <- asks serverStats - atomically $ modifyTVar (qDeleted stats) (+ 1) + atomically $ modifyTVar' (qDeleted stats) (+ 1) atomically $ deleteQueue st queueId >>= \case Left e -> pure $ err e diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 644136571..15253fec8 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -164,8 +164,8 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, (qs, s') <- liftIO $ readWriteStoreLog s atomically $ do writeTVar queues =<< mapM newTVar qs - writeTVar senders $ M.foldr' addSender M.empty qs - writeTVar notifiers $ M.foldr' addNotifier M.empty qs + writeTVar senders $! M.foldr' addSender M.empty qs + writeTVar notifiers $! M.foldr' addNotifier M.empty qs pure s' addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/Server/MsgStore.hs b/src/Simplex/Messaging/Server/MsgStore.hs index 37f5822d3..55a9c5499 100644 --- a/src/Simplex/Messaging/Server/MsgStore.hs +++ b/src/Simplex/Messaging/Server/MsgStore.hs @@ -1,13 +1,11 @@ -{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Server.MsgStore where import Control.Applicative ((<|>)) -import Data.Int (Int64) import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (Message (..), MsgId, RcvMessage (..), RecipientId) +import Simplex.Messaging.Protocol (Message (..), RcvMessage (..), RecipientId) data MsgLogRecord = MLRv3 RecipientId Message | MLRv1 RecipientId RcvMessage @@ -16,16 +14,3 @@ instance StrEncoding MsgLogRecord where MLRv3 rId msg -> strEncode (Str "v3", rId, msg) MLRv1 rId msg -> strEncode (rId, msg) strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP - -class MonadMsgStore s q m | s -> q where - getMsgQueue :: s -> RecipientId -> Int -> m q - delMsgQueue :: s -> RecipientId -> m () - flushMsgQueue :: s -> RecipientId -> m [Message] - -class MonadMsgQueue q m where - writeMsg :: q -> Message -> m (Maybe Message) -- non blocking - tryPeekMsg :: q -> m (Maybe Message) -- non blocking - peekMsg :: q -> m Message -- blocking - tryDelMsg :: q -> MsgId -> m Bool -- non blocking - tryDelPeekMsg :: q -> MsgId -> m (Bool, Maybe Message) -- atomic delete (== read) last and peek next message, if available - deleteExpiredMsgs :: q -> Int64 -> m () diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 4e27e599f..4ecf6d152 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -1,7 +1,6 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} @@ -11,6 +10,15 @@ module Simplex.Messaging.Server.MsgStore.STM ( STMMsgStore, MsgQueue, newMsgStore, + getMsgQueue, + delMsgQueue, + flushMsgQueue, + writeMsg, + tryPeekMsg, + peekMsg, + tryDelMsg, + tryDelPeekMsg, + deleteExpiredMsgs, ) where @@ -21,7 +29,6 @@ import Data.Functor (($>)) import Data.Int (Int64) import Data.Time.Clock.System (SystemTime (systemSeconds)) import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId) -import Simplex.Messaging.Server.MsgStore import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import UnliftIO.STM @@ -38,75 +45,73 @@ type STMMsgStore = TMap RecipientId MsgQueue newMsgStore :: STM STMMsgStore newMsgStore = TM.empty -instance MonadMsgStore STMMsgStore MsgQueue STM where - getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue - getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st - where - newQ = do - msgQueue <- newTQueue - canWrite <- newTVar True - size <- newTVar 0 - let q = MsgQueue {msgQueue, quota, canWrite, size} - TM.insert rId q st - pure q +getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue +getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st + where + newQ = do + msgQueue <- newTQueue + canWrite <- newTVar True + size <- newTVar 0 + let q = MsgQueue {msgQueue, quota, canWrite, size} + TM.insert rId q st + pure q - delMsgQueue :: STMMsgStore -> RecipientId -> STM () - delMsgQueue st rId = TM.delete rId st +delMsgQueue :: STMMsgStore -> RecipientId -> STM () +delMsgQueue st rId = TM.delete rId st - flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] - flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue) +flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] +flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue) -instance MonadMsgQueue MsgQueue STM where - writeMsg :: MsgQueue -> Message -> STM (Maybe Message) - writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do - canWrt <- readTVar canWrite - empty <- isEmptyTQueue q - if canWrt || empty - then do - canWrt' <- (quota >) <$> readTVar size - writeTVar canWrite canWrt' - modifyTVar' size (+ 1) - if canWrt' - then writeTQueue q msg $> Just msg - else writeTQueue q msgQuota $> Nothing - else pure Nothing - where - msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg} +writeMsg :: MsgQueue -> Message -> STM (Maybe Message) +writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do + canWrt <- readTVar canWrite + empty <- isEmptyTQueue q + if canWrt || empty + then do + canWrt' <- (quota >) <$> readTVar size + writeTVar canWrite $! canWrt' + modifyTVar' size (+ 1) + if canWrt' + then writeTQueue q msg $> Just msg + else writeTQueue q msgQuota $> Nothing + else pure Nothing + where + msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg} - tryPeekMsg :: MsgQueue -> STM (Maybe Message) - tryPeekMsg = tryPeekTQueue . msgQueue - {-# INLINE tryPeekMsg #-} +tryPeekMsg :: MsgQueue -> STM (Maybe Message) +tryPeekMsg = tryPeekTQueue . msgQueue +{-# INLINE tryPeekMsg #-} - peekMsg :: MsgQueue -> STM Message - peekMsg = peekTQueue . msgQueue - {-# INLINE peekMsg #-} +peekMsg :: MsgQueue -> STM Message +peekMsg = peekTQueue . msgQueue +{-# INLINE peekMsg #-} - tryDelMsg :: MsgQueue -> MsgId -> STM Bool - tryDelMsg mq msgId' = - tryPeekMsg mq >>= \case - Just msg - | msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True - | otherwise -> pure False - _ -> pure False +tryDelMsg :: MsgQueue -> MsgId -> STM Bool +tryDelMsg mq msgId' = + tryPeekMsg mq >>= \case + Just msg + | msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True + | otherwise -> pure False + _ -> pure False - -- atomic delete (== read) last and peek next message if available - tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message) - tryDelPeekMsg mq msgId' = - tryPeekMsg mq >>= \case - msg_@(Just msg) - | msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq) - | otherwise -> pure (False, msg_) - _ -> pure (False, Nothing) +-- atomic delete (== read) last and peek next message if available +tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message) +tryDelPeekMsg mq msgId' = + tryPeekMsg mq >>= \case + msg_@(Just msg) + | msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq) + | otherwise -> pure (False, msg_) + _ -> pure (False, Nothing) - deleteExpiredMsgs :: MsgQueue -> Int64 -> STM () - deleteExpiredMsgs mq old = loop - where - loop = tryPeekMsg mq >>= mapM_ delOldMsg - delOldMsg = \case - Message {msgTs} -> - when (systemSeconds msgTs < old) $ - tryDeleteMsg mq >> loop - _ -> pure () +deleteExpiredMsgs :: MsgQueue -> Int64 -> STM () +deleteExpiredMsgs mq old = loop + where + loop = tryPeekMsg mq >>= mapM_ delOldMsg + delOldMsg = \case + Message {msgTs} -> + when (systemSeconds msgTs < old) $ + tryDeleteMsg mq >> loop + _ -> pure () tryDeleteMsg :: MsgQueue -> STM () tryDeleteMsg MsgQueue {msgQueue = q, size} = diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index c05a0d3a6..8a7856eb6 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -9,20 +9,20 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol data QueueRec = QueueRec - { recipientId :: RecipientId, - recipientKey :: RcvPublicVerifyKey, - rcvDhSecret :: RcvDhSecret, - senderId :: SenderId, - senderKey :: Maybe SndPublicVerifyKey, - notifier :: Maybe NtfCreds, - status :: ServerQueueStatus + { recipientId :: !RecipientId, + recipientKey :: !RcvPublicVerifyKey, + rcvDhSecret :: !RcvDhSecret, + senderId :: !SenderId, + senderKey :: !(Maybe SndPublicVerifyKey), + notifier :: !(Maybe NtfCreds), + status :: !ServerQueueStatus } deriving (Eq, Show) data NtfCreds = NtfCreds - { notifierId :: NotifierId, - notifierKey :: NtfPublicVerifyKey, - rcvNtfDhSecret :: RcvNtfDhSecret + { notifierId :: !NotifierId, + notifierKey :: !NtfPublicVerifyKey, + rcvNtfDhSecret :: !RcvNtfDhSecret } deriving (Eq, Show) @@ -33,12 +33,3 @@ instance StrEncoding NtfCreds where pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} data ServerQueueStatus = QueueActive | QueueOff deriving (Eq, Show) - -class MonadQueueStore s m where - addQueue :: s -> QueueRec -> m (Either ErrorType ()) - getQueue :: s -> SParty p -> QueueId -> m (Either ErrorType QueueRec) - secureQueue :: s -> RecipientId -> SndPublicVerifyKey -> m (Either ErrorType QueueRec) - addQueueNotifier :: s -> RecipientId -> NtfCreds -> m (Either ErrorType QueueRec) - deleteQueueNotifier :: s -> RecipientId -> m (Either ErrorType ()) - suspendQueue :: s -> RecipientId -> m (Either ErrorType ()) - deleteQueue :: s -> RecipientId -> m (Either ErrorType ()) diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 02375f168..b4c41c0de 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -10,7 +9,18 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE UndecidableInstances #-} -module Simplex.Messaging.Server.QueueStore.STM where +module Simplex.Messaging.Server.QueueStore.STM + ( QueueStore (..), + newQueueStore, + addQueue, + getQueue, + secureQueue, + addQueueNotifier, + deleteQueueNotifier, + suspendQueue, + deleteQueue, + ) +where import Control.Monad import Data.Functor (($>)) @@ -34,66 +44,65 @@ newQueueStore = do notifiers <- TM.empty pure QueueStore {queues, senders, notifiers} -instance MonadQueueStore QueueStore STM where - addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) - addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do - ifM hasId (pure $ Left DUPLICATE_) $ do - qVar <- newTVar q - TM.insert rId qVar queues - TM.insert sId rId senders - pure $ Right () - where - hasId = (||) <$> TM.member rId queues <*> TM.member sId senders +addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) +addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do + ifM hasId (pure $ Left DUPLICATE_) $ do + qVar <- newTVar q + TM.insert rId qVar queues + TM.insert sId rId senders + pure $ Right () + where + hasId = (||) <$> TM.member rId queues <*> TM.member sId senders - getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec) - getQueue QueueStore {queues, senders, notifiers} party qId = - toResult <$> (mapM readTVar =<< getVar) - where - getVar = case party of - SRecipient -> TM.lookup qId queues - SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues) - SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues) +getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec) +getQueue QueueStore {queues, senders, notifiers} party qId = + toResult <$> (mapM readTVar =<< getVar) + where + getVar = case party of + SRecipient -> TM.lookup qId queues + SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues) + SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues) - secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec) - secureQueue QueueStore {queues} rId sKey = - withQueue rId queues $ \qVar -> - readTVar qVar >>= \q -> case senderKey q of - Just k -> pure $ if sKey == k then Just q else Nothing - _ -> - let q' = q {senderKey = Just sKey} - in writeTVar qVar q' $> Just q' +secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec) +secureQueue QueueStore {queues} rId sKey = + withQueue rId queues $ \qVar -> + readTVar qVar >>= \q -> case senderKey q of + Just k -> pure $ if sKey == k then Just q else Nothing + _ -> + let q' = q {senderKey = Just sKey} + in writeTVar qVar q' $> Just q' - addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec) - addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do - ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $ - withQueue rId queues $ \qVar -> do - q <- readTVar qVar - forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId - writeTVar qVar q {notifier = Just ntfCreds} - TM.insert nId rId notifiers - pure $ Just q - - deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - deleteQueueNotifier QueueStore {queues, notifiers} rId = +addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec) +addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do + ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $ withQueue rId queues $ \qVar -> do q <- readTVar qVar - forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers - writeTVar qVar q {notifier = Nothing} - pure $ Just () + forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId + writeTVar qVar $! q {notifier = Just ntfCreds} + TM.insert nId rId notifiers + pure $ Just q - suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - suspendQueue QueueStore {queues} rId = - withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just () +deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ()) +deleteQueueNotifier QueueStore {queues, notifiers} rId = + withQueue rId queues $ \qVar -> do + q <- readTVar qVar + forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers + writeTVar qVar $! q {notifier = Nothing} + pure $ Just () - deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - deleteQueue QueueStore {queues, senders, notifiers} rId = do - TM.lookupDelete rId queues >>= \case - Just qVar -> - readTVar qVar >>= \q -> do - TM.delete (senderId q) senders - forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers - pure $ Right () - _ -> pure $ Left AUTH +suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) +suspendQueue QueueStore {queues} rId = + withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just () + +deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) +deleteQueue QueueStore {queues, senders, notifiers} rId = do + TM.lookupDelete rId queues >>= \case + Just qVar -> + readTVar qVar >>= \q -> do + TM.delete (senderId q) senders + forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers + pure $ Right () + _ -> pure $ Left AUTH toResult :: Maybe a -> Either ErrorType a toResult = maybe (Left AUTH) Right diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index 44c1b97d7..3f08d5cb8 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -61,12 +61,12 @@ getServerStatsData s = do setServerStats :: ServerStats -> ServerStatsData -> STM () setServerStats s d = do - writeTVar (fromTime s) (_fromTime d) - writeTVar (qCreated s) (_qCreated d) - writeTVar (qSecured s) (_qSecured d) - writeTVar (qDeleted s) (_qDeleted d) - writeTVar (msgSent s) (_msgSent d) - writeTVar (msgRecv s) (_msgRecv d) + writeTVar (fromTime s) $! _fromTime d + writeTVar (qCreated s) $! _qCreated d + writeTVar (qSecured s) $! _qSecured d + writeTVar (qDeleted s) $! _qDeleted d + writeTVar (msgSent s) $! _msgSent d + writeTVar (msgRecv s) $! _msgRecv d setPeriodStats (activeQueues s) (_activeQueues d) instance StrEncoding ServerStatsData where @@ -126,9 +126,9 @@ getPeriodStatsData s = do setPeriodStats :: PeriodStats a -> PeriodStatsData a -> STM () setPeriodStats s d = do - writeTVar (day s) (_day d) - writeTVar (week s) (_week d) - writeTVar (month s) (_month d) + writeTVar (day s) $! _day d + writeTVar (week s) $! _week d + writeTVar (month s) $! _month d instance (Ord a, StrEncoding a) => StrEncoding (PeriodStatsData a) where strEncode PeriodStatsData {_day, _week, _month} = @@ -165,4 +165,4 @@ updatePeriodStats stats pId = do updatePeriod week updatePeriod month where - updatePeriod pSel = modifyTVar (pSel stats) (S.insert pId) + updatePeriod pSel = modifyTVar' (pSel stats) (S.insert pId) diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 5f688f8f6..b33145324 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -214,7 +214,7 @@ instance Transport TLS where $ do b <- readChunks =<< readTVarIO buffer let (s, b') = B.splitAt n b - atomically $ writeTVar buffer b' + atomically $ writeTVar buffer $! b' pure s where readChunks :: ByteString -> IO ByteString @@ -237,7 +237,7 @@ instance Transport TLS where $ do b <- readChunks =<< readTVarIO buffer let (s, b') = B.break (== '\n') b - atomically $ writeTVar buffer (B.drop 1 b') -- drop '\n' we made a break at + atomically $ writeTVar buffer $! B.drop 1 b' -- drop '\n' we made a break at pure $ trimCR s where readChunks :: ByteString -> IO ByteString