strict writes to STM, remove type class (#600)

This commit is contained in:
Evgeny Poberezkin
2023-01-12 14:59:46 +00:00
committed by GitHub
parent 92a379e75c
commit 1f12697279
13 changed files with 183 additions and 197 deletions
+2 -2
View File
@@ -164,8 +164,8 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile,
(qs, s') <- liftIO $ readWriteStoreLog s
atomically $ do
writeTVar queues =<< mapM newTVar qs
writeTVar senders $ M.foldr' addSender M.empty qs
writeTVar notifiers $ M.foldr' addNotifier M.empty qs
writeTVar senders $! M.foldr' addSender M.empty qs
writeTVar notifiers $! M.foldr' addNotifier M.empty qs
pure s'
addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId
addSender q = M.insert (senderId q) (recipientId q)
+1 -16
View File
@@ -1,13 +1,11 @@
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Server.MsgStore where
import Control.Applicative ((<|>))
import Data.Int (Int64)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (Message (..), MsgId, RcvMessage (..), RecipientId)
import Simplex.Messaging.Protocol (Message (..), RcvMessage (..), RecipientId)
data MsgLogRecord = MLRv3 RecipientId Message | MLRv1 RecipientId RcvMessage
@@ -16,16 +14,3 @@ instance StrEncoding MsgLogRecord where
MLRv3 rId msg -> strEncode (Str "v3", rId, msg)
MLRv1 rId msg -> strEncode (rId, msg)
strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP
class MonadMsgStore s q m | s -> q where
getMsgQueue :: s -> RecipientId -> Int -> m q
delMsgQueue :: s -> RecipientId -> m ()
flushMsgQueue :: s -> RecipientId -> m [Message]
class MonadMsgQueue q m where
writeMsg :: q -> Message -> m (Maybe Message) -- non blocking
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
peekMsg :: q -> m Message -- blocking
tryDelMsg :: q -> MsgId -> m Bool -- non blocking
tryDelPeekMsg :: q -> MsgId -> m (Bool, Maybe Message) -- atomic delete (== read) last and peek next message, if available
deleteExpiredMsgs :: q -> Int64 -> m ()
+68 -63
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -11,6 +10,15 @@ module Simplex.Messaging.Server.MsgStore.STM
( STMMsgStore,
MsgQueue,
newMsgStore,
getMsgQueue,
delMsgQueue,
flushMsgQueue,
writeMsg,
tryPeekMsg,
peekMsg,
tryDelMsg,
tryDelPeekMsg,
deleteExpiredMsgs,
)
where
@@ -21,7 +29,6 @@ import Data.Functor (($>))
import Data.Int (Int64)
import Data.Time.Clock.System (SystemTime (systemSeconds))
import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId)
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import UnliftIO.STM
@@ -38,75 +45,73 @@ type STMMsgStore = TMap RecipientId MsgQueue
newMsgStore :: STM STMMsgStore
newMsgStore = TM.empty
instance MonadMsgStore STMMsgStore MsgQueue STM where
getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
where
newQ = do
msgQueue <- newTQueue
canWrite <- newTVar True
size <- newTVar 0
let q = MsgQueue {msgQueue, quota, canWrite, size}
TM.insert rId q st
pure q
getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
where
newQ = do
msgQueue <- newTQueue
canWrite <- newTVar True
size <- newTVar 0
let q = MsgQueue {msgQueue, quota, canWrite, size}
TM.insert rId q st
pure q
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
delMsgQueue st rId = TM.delete rId st
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
delMsgQueue st rId = TM.delete rId st
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
instance MonadMsgQueue MsgQueue STM where
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
canWrt <- readTVar canWrite
empty <- isEmptyTQueue q
if canWrt || empty
then do
canWrt' <- (quota >) <$> readTVar size
writeTVar canWrite canWrt'
modifyTVar' size (+ 1)
if canWrt'
then writeTQueue q msg $> Just msg
else writeTQueue q msgQuota $> Nothing
else pure Nothing
where
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
canWrt <- readTVar canWrite
empty <- isEmptyTQueue q
if canWrt || empty
then do
canWrt' <- (quota >) <$> readTVar size
writeTVar canWrite $! canWrt'
modifyTVar' size (+ 1)
if canWrt'
then writeTQueue q msg $> Just msg
else writeTQueue q msgQuota $> Nothing
else pure Nothing
where
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
tryPeekMsg = tryPeekTQueue . msgQueue
{-# INLINE tryPeekMsg #-}
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
tryPeekMsg = tryPeekTQueue . msgQueue
{-# INLINE tryPeekMsg #-}
peekMsg :: MsgQueue -> STM Message
peekMsg = peekTQueue . msgQueue
{-# INLINE peekMsg #-}
peekMsg :: MsgQueue -> STM Message
peekMsg = peekTQueue . msgQueue
{-# INLINE peekMsg #-}
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
tryDelMsg mq msgId' =
tryPeekMsg mq >>= \case
Just msg
| msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True
| otherwise -> pure False
_ -> pure False
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
tryDelMsg mq msgId' =
tryPeekMsg mq >>= \case
Just msg
| msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True
| otherwise -> pure False
_ -> pure False
-- atomic delete (== read) last and peek next message if available
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
tryDelPeekMsg mq msgId' =
tryPeekMsg mq >>= \case
msg_@(Just msg)
| msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq)
| otherwise -> pure (False, msg_)
_ -> pure (False, Nothing)
-- atomic delete (== read) last and peek next message if available
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
tryDelPeekMsg mq msgId' =
tryPeekMsg mq >>= \case
msg_@(Just msg)
| msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq)
| otherwise -> pure (False, msg_)
_ -> pure (False, Nothing)
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
deleteExpiredMsgs mq old = loop
where
loop = tryPeekMsg mq >>= mapM_ delOldMsg
delOldMsg = \case
Message {msgTs} ->
when (systemSeconds msgTs < old) $
tryDeleteMsg mq >> loop
_ -> pure ()
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
deleteExpiredMsgs mq old = loop
where
loop = tryPeekMsg mq >>= mapM_ delOldMsg
delOldMsg = \case
Message {msgTs} ->
when (systemSeconds msgTs < old) $
tryDeleteMsg mq >> loop
_ -> pure ()
tryDeleteMsg :: MsgQueue -> STM ()
tryDeleteMsg MsgQueue {msgQueue = q, size} =
+10 -19
View File
@@ -9,20 +9,20 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
data QueueRec = QueueRec
{ recipientId :: RecipientId,
recipientKey :: RcvPublicVerifyKey,
rcvDhSecret :: RcvDhSecret,
senderId :: SenderId,
senderKey :: Maybe SndPublicVerifyKey,
notifier :: Maybe NtfCreds,
status :: ServerQueueStatus
{ recipientId :: !RecipientId,
recipientKey :: !RcvPublicVerifyKey,
rcvDhSecret :: !RcvDhSecret,
senderId :: !SenderId,
senderKey :: !(Maybe SndPublicVerifyKey),
notifier :: !(Maybe NtfCreds),
status :: !ServerQueueStatus
}
deriving (Eq, Show)
data NtfCreds = NtfCreds
{ notifierId :: NotifierId,
notifierKey :: NtfPublicVerifyKey,
rcvNtfDhSecret :: RcvNtfDhSecret
{ notifierId :: !NotifierId,
notifierKey :: !NtfPublicVerifyKey,
rcvNtfDhSecret :: !RcvNtfDhSecret
}
deriving (Eq, Show)
@@ -33,12 +33,3 @@ instance StrEncoding NtfCreds where
pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
data ServerQueueStatus = QueueActive | QueueOff deriving (Eq, Show)
class MonadQueueStore s m where
addQueue :: s -> QueueRec -> m (Either ErrorType ())
getQueue :: s -> SParty p -> QueueId -> m (Either ErrorType QueueRec)
secureQueue :: s -> RecipientId -> SndPublicVerifyKey -> m (Either ErrorType QueueRec)
addQueueNotifier :: s -> RecipientId -> NtfCreds -> m (Either ErrorType QueueRec)
deleteQueueNotifier :: s -> RecipientId -> m (Either ErrorType ())
suspendQueue :: s -> RecipientId -> m (Either ErrorType ())
deleteQueue :: s -> RecipientId -> m (Either ErrorType ())
+64 -55
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@@ -10,7 +9,18 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
module Simplex.Messaging.Server.QueueStore.STM where
module Simplex.Messaging.Server.QueueStore.STM
( QueueStore (..),
newQueueStore,
addQueue,
getQueue,
secureQueue,
addQueueNotifier,
deleteQueueNotifier,
suspendQueue,
deleteQueue,
)
where
import Control.Monad
import Data.Functor (($>))
@@ -34,66 +44,65 @@ newQueueStore = do
notifiers <- TM.empty
pure QueueStore {queues, senders, notifiers}
instance MonadQueueStore QueueStore STM where
addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ())
addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do
ifM hasId (pure $ Left DUPLICATE_) $ do
qVar <- newTVar q
TM.insert rId qVar queues
TM.insert sId rId senders
pure $ Right ()
where
hasId = (||) <$> TM.member rId queues <*> TM.member sId senders
addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ())
addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do
ifM hasId (pure $ Left DUPLICATE_) $ do
qVar <- newTVar q
TM.insert rId qVar queues
TM.insert sId rId senders
pure $ Right ()
where
hasId = (||) <$> TM.member rId queues <*> TM.member sId senders
getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec)
getQueue QueueStore {queues, senders, notifiers} party qId =
toResult <$> (mapM readTVar =<< getVar)
where
getVar = case party of
SRecipient -> TM.lookup qId queues
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec)
getQueue QueueStore {queues, senders, notifiers} party qId =
toResult <$> (mapM readTVar =<< getVar)
where
getVar = case party of
SRecipient -> TM.lookup qId queues
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
secureQueue QueueStore {queues} rId sKey =
withQueue rId queues $ \qVar ->
readTVar qVar >>= \q -> case senderKey q of
Just k -> pure $ if sKey == k then Just q else Nothing
_ ->
let q' = q {senderKey = Just sKey}
in writeTVar qVar q' $> Just q'
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
secureQueue QueueStore {queues} rId sKey =
withQueue rId queues $ \qVar ->
readTVar qVar >>= \q -> case senderKey q of
Just k -> pure $ if sKey == k then Just q else Nothing
_ ->
let q' = q {senderKey = Just sKey}
in writeTVar qVar q' $> Just q'
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do
ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $
withQueue rId queues $ \qVar -> do
q <- readTVar qVar
forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId
writeTVar qVar q {notifier = Just ntfCreds}
TM.insert nId rId notifiers
pure $ Just q
deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueueNotifier QueueStore {queues, notifiers} rId =
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do
ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $
withQueue rId queues $ \qVar -> do
q <- readTVar qVar
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
writeTVar qVar q {notifier = Nothing}
pure $ Just ()
forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId
writeTVar qVar $! q {notifier = Just ntfCreds}
TM.insert nId rId notifiers
pure $ Just q
suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
suspendQueue QueueStore {queues} rId =
withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just ()
deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueueNotifier QueueStore {queues, notifiers} rId =
withQueue rId queues $ \qVar -> do
q <- readTVar qVar
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
writeTVar qVar $! q {notifier = Nothing}
pure $ Just ()
deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueue QueueStore {queues, senders, notifiers} rId = do
TM.lookupDelete rId queues >>= \case
Just qVar ->
readTVar qVar >>= \q -> do
TM.delete (senderId q) senders
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
pure $ Right ()
_ -> pure $ Left AUTH
suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
suspendQueue QueueStore {queues} rId =
withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just ()
deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueue QueueStore {queues, senders, notifiers} rId = do
TM.lookupDelete rId queues >>= \case
Just qVar ->
readTVar qVar >>= \q -> do
TM.delete (senderId q) senders
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
pure $ Right ()
_ -> pure $ Left AUTH
toResult :: Maybe a -> Either ErrorType a
toResult = maybe (Left AUTH) Right
+10 -10
View File
@@ -61,12 +61,12 @@ getServerStatsData s = do
setServerStats :: ServerStats -> ServerStatsData -> STM ()
setServerStats s d = do
writeTVar (fromTime s) (_fromTime d)
writeTVar (qCreated s) (_qCreated d)
writeTVar (qSecured s) (_qSecured d)
writeTVar (qDeleted s) (_qDeleted d)
writeTVar (msgSent s) (_msgSent d)
writeTVar (msgRecv s) (_msgRecv d)
writeTVar (fromTime s) $! _fromTime d
writeTVar (qCreated s) $! _qCreated d
writeTVar (qSecured s) $! _qSecured d
writeTVar (qDeleted s) $! _qDeleted d
writeTVar (msgSent s) $! _msgSent d
writeTVar (msgRecv s) $! _msgRecv d
setPeriodStats (activeQueues s) (_activeQueues d)
instance StrEncoding ServerStatsData where
@@ -126,9 +126,9 @@ getPeriodStatsData s = do
setPeriodStats :: PeriodStats a -> PeriodStatsData a -> STM ()
setPeriodStats s d = do
writeTVar (day s) (_day d)
writeTVar (week s) (_week d)
writeTVar (month s) (_month d)
writeTVar (day s) $! _day d
writeTVar (week s) $! _week d
writeTVar (month s) $! _month d
instance (Ord a, StrEncoding a) => StrEncoding (PeriodStatsData a) where
strEncode PeriodStatsData {_day, _week, _month} =
@@ -165,4 +165,4 @@ updatePeriodStats stats pId = do
updatePeriod week
updatePeriod month
where
updatePeriod pSel = modifyTVar (pSel stats) (S.insert pId)
updatePeriod pSel = modifyTVar' (pSel stats) (S.insert pId)