mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-31 20:36:22 +00:00
strict writes to STM, remove type class (#600)
This commit is contained in:
committed by
GitHub
parent
92a379e75c
commit
1f12697279
@@ -905,7 +905,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
atomically $ do
|
||||
srvs <- readTVar $ smpServers c
|
||||
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
|
||||
writeTVar usedSrvs used'
|
||||
writeTVar usedSrvs $! used'
|
||||
action srvAuth
|
||||
-- ^ ^ ^ async command processing /
|
||||
|
||||
|
||||
@@ -607,7 +607,7 @@ subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
|
||||
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
|
||||
atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
modifyTVar' (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
withLogClient c server rcvId "SUB" $ \smp ->
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
|
||||
@@ -644,7 +644,7 @@ subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
|
||||
forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
modifyTVar' (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
@@ -671,7 +671,7 @@ subscribeQueues c srv qs = do
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
|
||||
addSubscription c rq@RcvQueue {connId} = atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
modifyTVar' (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ activeSubs c
|
||||
RQ.deleteQueue rq $ pendingSubs c
|
||||
|
||||
@@ -680,7 +680,7 @@ hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c
|
||||
|
||||
removeSubscription :: AgentClient -> ConnId -> STM ()
|
||||
removeSubscription c connId = do
|
||||
modifyTVar (subscrConns c) $ S.delete connId
|
||||
modifyTVar' (subscrConns c) $ S.delete connId
|
||||
RQ.deleteConn connId $ activeSubs c
|
||||
RQ.deleteConn connId $ pendingSubs c
|
||||
|
||||
@@ -945,7 +945,7 @@ storeError = \case
|
||||
incStat :: AgentClient -> Int -> AgentStatsKey -> STM ()
|
||||
incStat AgentClient {agentStats} n k = do
|
||||
TM.lookup k agentStats >>= \case
|
||||
Just v -> modifyTVar v (+ n)
|
||||
Just v -> modifyTVar' v (+ n)
|
||||
_ -> newTVar n >>= \v -> TM.insert k v agentStats
|
||||
|
||||
incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
|
||||
|
||||
@@ -964,11 +964,7 @@ pseudoRandomCbNonce :: TVar ChaChaDRG -> STM CbNonce
|
||||
pseudoRandomCbNonce gVar = CbNonce <$> pseudoRandomBytes 24 gVar
|
||||
|
||||
pseudoRandomBytes :: Int -> TVar ChaChaDRG -> STM ByteString
|
||||
pseudoRandomBytes n gVar = do
|
||||
g <- readTVar gVar
|
||||
let (bytes, g') = randomBytesGenerate n g
|
||||
writeTVar gVar g'
|
||||
return bytes
|
||||
pseudoRandomBytes n gVar = stateTVar gVar $ randomBytesGenerate n
|
||||
|
||||
instance Encoding CbNonce where
|
||||
smpEncode = unCbNonce
|
||||
|
||||
@@ -548,7 +548,7 @@ withNtfLog action = liftIO . mapM_ action =<< asks storeLog
|
||||
incNtfStat :: (NtfServerStats -> TVar Int) -> M ()
|
||||
incNtfStat statSel = do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (statSel stats) (+ 1)
|
||||
atomically $ modifyTVar' (statSel stats) (+ 1)
|
||||
|
||||
saveServerStats :: M ()
|
||||
saveServerStats =
|
||||
|
||||
@@ -70,14 +70,14 @@ getNtfServerStatsData s = do
|
||||
|
||||
setNtfServerStats :: NtfServerStats -> NtfServerStatsData -> STM ()
|
||||
setNtfServerStats s d = do
|
||||
writeTVar (fromTime (s :: NtfServerStats)) (_fromTime (d :: NtfServerStatsData))
|
||||
writeTVar (tknCreated s) (_tknCreated d)
|
||||
writeTVar (tknVerified s) (_tknVerified d)
|
||||
writeTVar (tknDeleted s) (_tknDeleted d)
|
||||
writeTVar (subCreated s) (_subCreated d)
|
||||
writeTVar (subDeleted s) (_subDeleted d)
|
||||
writeTVar (ntfReceived s) (_ntfReceived d)
|
||||
writeTVar (ntfDelivered s) (_ntfDelivered d)
|
||||
writeTVar (fromTime (s :: NtfServerStats)) $! _fromTime (d :: NtfServerStatsData)
|
||||
writeTVar (tknCreated s) $! _tknCreated d
|
||||
writeTVar (tknVerified s) $! _tknVerified d
|
||||
writeTVar (tknDeleted s) $! _tknDeleted d
|
||||
writeTVar (subCreated s) $! _subCreated d
|
||||
writeTVar (subDeleted s) $! _subDeleted d
|
||||
writeTVar (ntfReceived s) $! _ntfReceived d
|
||||
writeTVar (ntfDelivered s) $! _ntfDelivered d
|
||||
setPeriodStats (activeTokens s) (_activeTokens d)
|
||||
setPeriodStats (activeSubs s) (_activeSubs d)
|
||||
|
||||
|
||||
@@ -65,9 +65,9 @@ import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.Server.MsgStore.STM (MsgQueue)
|
||||
import Simplex.Messaging.Server.MsgStore.STM
|
||||
import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.Server.QueueStore.STM (QueueStore)
|
||||
import Simplex.Messaging.Server.QueueStore.STM
|
||||
import Simplex.Messaging.Server.Stats
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
@@ -386,7 +386,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
Right _ -> do
|
||||
withLog (`logCreateById` rId)
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (qCreated stats) (+ 1)
|
||||
atomically $ modifyTVar' (qCreated stats) (+ 1)
|
||||
subscribeQueue qr rId $> IDS (qik ids)
|
||||
|
||||
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
|
||||
@@ -404,7 +404,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
secureQueue_ st sKey = time "KEY" $ do
|
||||
withLog $ \s -> logSecureQueue s queueId sKey
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (qSecured stats) (+ 1)
|
||||
atomically $ modifyTVar' (qSecured stats) (+ 1)
|
||||
atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey
|
||||
|
||||
addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg)
|
||||
@@ -528,7 +528,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
updateStats :: m ()
|
||||
updateStats = do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgRecv stats) (+ 1)
|
||||
atomically $ modifyTVar' (msgRecv stats) (+ 1)
|
||||
atomically $ updatePeriodStats (activeQueues stats) queueId
|
||||
|
||||
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
|
||||
@@ -550,7 +550,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
when (notification msgFlags) $
|
||||
atomically . trySendNotification msg =<< asks idsDrg
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgSent stats) (+ 1)
|
||||
atomically $ modifyTVar' (msgSent stats) (+ 1)
|
||||
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
|
||||
pure ok
|
||||
where
|
||||
@@ -599,9 +599,9 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
where
|
||||
forkSub :: m ()
|
||||
forkSub = do
|
||||
atomically . modifyTVar sub $ \s -> s {subThread = SubPending}
|
||||
atomically . modifyTVar' sub $ \s -> s {subThread = SubPending}
|
||||
t <- mkWeakThreadId =<< forkIO subscriber
|
||||
atomically . modifyTVar sub $ \case
|
||||
atomically . modifyTVar' sub $ \case
|
||||
s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
|
||||
s -> s
|
||||
where
|
||||
@@ -612,7 +612,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
|
||||
s <- readTVar sub
|
||||
void $ setDelivered s msg
|
||||
writeTVar sub s {subThread = NoSub}
|
||||
writeTVar sub $! s {subThread = NoSub}
|
||||
|
||||
time :: T.Text -> m a -> m a
|
||||
time name = timed name queueId
|
||||
@@ -646,7 +646,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
ms <- asks msgStore
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (qDeleted stats) (+ 1)
|
||||
atomically $ modifyTVar' (qDeleted stats) (+ 1)
|
||||
atomically $
|
||||
deleteQueue st queueId >>= \case
|
||||
Left e -> pure $ err e
|
||||
|
||||
@@ -164,8 +164,8 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile,
|
||||
(qs, s') <- liftIO $ readWriteStoreLog s
|
||||
atomically $ do
|
||||
writeTVar queues =<< mapM newTVar qs
|
||||
writeTVar senders $ M.foldr' addSender M.empty qs
|
||||
writeTVar notifiers $ M.foldr' addNotifier M.empty qs
|
||||
writeTVar senders $! M.foldr' addSender M.empty qs
|
||||
writeTVar notifiers $! M.foldr' addNotifier M.empty qs
|
||||
pure s'
|
||||
addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId
|
||||
addSender q = M.insert (senderId q) (recipientId q)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Server.MsgStore where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Data.Int (Int64)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (Message (..), MsgId, RcvMessage (..), RecipientId)
|
||||
import Simplex.Messaging.Protocol (Message (..), RcvMessage (..), RecipientId)
|
||||
|
||||
data MsgLogRecord = MLRv3 RecipientId Message | MLRv1 RecipientId RcvMessage
|
||||
|
||||
@@ -16,16 +14,3 @@ instance StrEncoding MsgLogRecord where
|
||||
MLRv3 rId msg -> strEncode (Str "v3", rId, msg)
|
||||
MLRv1 rId msg -> strEncode (rId, msg)
|
||||
strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP
|
||||
|
||||
class MonadMsgStore s q m | s -> q where
|
||||
getMsgQueue :: s -> RecipientId -> Int -> m q
|
||||
delMsgQueue :: s -> RecipientId -> m ()
|
||||
flushMsgQueue :: s -> RecipientId -> m [Message]
|
||||
|
||||
class MonadMsgQueue q m where
|
||||
writeMsg :: q -> Message -> m (Maybe Message) -- non blocking
|
||||
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
|
||||
peekMsg :: q -> m Message -- blocking
|
||||
tryDelMsg :: q -> MsgId -> m Bool -- non blocking
|
||||
tryDelPeekMsg :: q -> MsgId -> m (Bool, Maybe Message) -- atomic delete (== read) last and peek next message, if available
|
||||
deleteExpiredMsgs :: q -> Int64 -> m ()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -11,6 +10,15 @@ module Simplex.Messaging.Server.MsgStore.STM
|
||||
( STMMsgStore,
|
||||
MsgQueue,
|
||||
newMsgStore,
|
||||
getMsgQueue,
|
||||
delMsgQueue,
|
||||
flushMsgQueue,
|
||||
writeMsg,
|
||||
tryPeekMsg,
|
||||
peekMsg,
|
||||
tryDelMsg,
|
||||
tryDelPeekMsg,
|
||||
deleteExpiredMsgs,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -21,7 +29,6 @@ import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Time.Clock.System (SystemTime (systemSeconds))
|
||||
import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId)
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import UnliftIO.STM
|
||||
@@ -38,75 +45,73 @@ type STMMsgStore = TMap RecipientId MsgQueue
|
||||
newMsgStore :: STM STMMsgStore
|
||||
newMsgStore = TM.empty
|
||||
|
||||
instance MonadMsgStore STMMsgStore MsgQueue STM where
|
||||
getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue
|
||||
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
|
||||
where
|
||||
newQ = do
|
||||
msgQueue <- newTQueue
|
||||
canWrite <- newTVar True
|
||||
size <- newTVar 0
|
||||
let q = MsgQueue {msgQueue, quota, canWrite, size}
|
||||
TM.insert rId q st
|
||||
pure q
|
||||
getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue
|
||||
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
|
||||
where
|
||||
newQ = do
|
||||
msgQueue <- newTQueue
|
||||
canWrite <- newTVar True
|
||||
size <- newTVar 0
|
||||
let q = MsgQueue {msgQueue, quota, canWrite, size}
|
||||
TM.insert rId q st
|
||||
pure q
|
||||
|
||||
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
|
||||
delMsgQueue st rId = TM.delete rId st
|
||||
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
|
||||
delMsgQueue st rId = TM.delete rId st
|
||||
|
||||
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
|
||||
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
|
||||
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
|
||||
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
|
||||
|
||||
instance MonadMsgQueue MsgQueue STM where
|
||||
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
|
||||
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
|
||||
canWrt <- readTVar canWrite
|
||||
empty <- isEmptyTQueue q
|
||||
if canWrt || empty
|
||||
then do
|
||||
canWrt' <- (quota >) <$> readTVar size
|
||||
writeTVar canWrite canWrt'
|
||||
modifyTVar' size (+ 1)
|
||||
if canWrt'
|
||||
then writeTQueue q msg $> Just msg
|
||||
else writeTQueue q msgQuota $> Nothing
|
||||
else pure Nothing
|
||||
where
|
||||
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
|
||||
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
|
||||
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
|
||||
canWrt <- readTVar canWrite
|
||||
empty <- isEmptyTQueue q
|
||||
if canWrt || empty
|
||||
then do
|
||||
canWrt' <- (quota >) <$> readTVar size
|
||||
writeTVar canWrite $! canWrt'
|
||||
modifyTVar' size (+ 1)
|
||||
if canWrt'
|
||||
then writeTQueue q msg $> Just msg
|
||||
else writeTQueue q msgQuota $> Nothing
|
||||
else pure Nothing
|
||||
where
|
||||
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
|
||||
|
||||
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
|
||||
tryPeekMsg = tryPeekTQueue . msgQueue
|
||||
{-# INLINE tryPeekMsg #-}
|
||||
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
|
||||
tryPeekMsg = tryPeekTQueue . msgQueue
|
||||
{-# INLINE tryPeekMsg #-}
|
||||
|
||||
peekMsg :: MsgQueue -> STM Message
|
||||
peekMsg = peekTQueue . msgQueue
|
||||
{-# INLINE peekMsg #-}
|
||||
peekMsg :: MsgQueue -> STM Message
|
||||
peekMsg = peekTQueue . msgQueue
|
||||
{-# INLINE peekMsg #-}
|
||||
|
||||
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
|
||||
tryDelMsg mq msgId' =
|
||||
tryPeekMsg mq >>= \case
|
||||
Just msg
|
||||
| msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True
|
||||
| otherwise -> pure False
|
||||
_ -> pure False
|
||||
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
|
||||
tryDelMsg mq msgId' =
|
||||
tryPeekMsg mq >>= \case
|
||||
Just msg
|
||||
| msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure True
|
||||
| otherwise -> pure False
|
||||
_ -> pure False
|
||||
|
||||
-- atomic delete (== read) last and peek next message if available
|
||||
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
|
||||
tryDelPeekMsg mq msgId' =
|
||||
tryPeekMsg mq >>= \case
|
||||
msg_@(Just msg)
|
||||
| msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq)
|
||||
| otherwise -> pure (False, msg_)
|
||||
_ -> pure (False, Nothing)
|
||||
-- atomic delete (== read) last and peek next message if available
|
||||
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
|
||||
tryDelPeekMsg mq msgId' =
|
||||
tryPeekMsg mq >>= \case
|
||||
msg_@(Just msg)
|
||||
| msgId msg == msgId' || B.null msgId' -> (True,) <$> (tryDeleteMsg mq >> tryPeekMsg mq)
|
||||
| otherwise -> pure (False, msg_)
|
||||
_ -> pure (False, Nothing)
|
||||
|
||||
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
|
||||
deleteExpiredMsgs mq old = loop
|
||||
where
|
||||
loop = tryPeekMsg mq >>= mapM_ delOldMsg
|
||||
delOldMsg = \case
|
||||
Message {msgTs} ->
|
||||
when (systemSeconds msgTs < old) $
|
||||
tryDeleteMsg mq >> loop
|
||||
_ -> pure ()
|
||||
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
|
||||
deleteExpiredMsgs mq old = loop
|
||||
where
|
||||
loop = tryPeekMsg mq >>= mapM_ delOldMsg
|
||||
delOldMsg = \case
|
||||
Message {msgTs} ->
|
||||
when (systemSeconds msgTs < old) $
|
||||
tryDeleteMsg mq >> loop
|
||||
_ -> pure ()
|
||||
|
||||
tryDeleteMsg :: MsgQueue -> STM ()
|
||||
tryDeleteMsg MsgQueue {msgQueue = q, size} =
|
||||
|
||||
@@ -9,20 +9,20 @@ import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
|
||||
data QueueRec = QueueRec
|
||||
{ recipientId :: RecipientId,
|
||||
recipientKey :: RcvPublicVerifyKey,
|
||||
rcvDhSecret :: RcvDhSecret,
|
||||
senderId :: SenderId,
|
||||
senderKey :: Maybe SndPublicVerifyKey,
|
||||
notifier :: Maybe NtfCreds,
|
||||
status :: ServerQueueStatus
|
||||
{ recipientId :: !RecipientId,
|
||||
recipientKey :: !RcvPublicVerifyKey,
|
||||
rcvDhSecret :: !RcvDhSecret,
|
||||
senderId :: !SenderId,
|
||||
senderKey :: !(Maybe SndPublicVerifyKey),
|
||||
notifier :: !(Maybe NtfCreds),
|
||||
status :: !ServerQueueStatus
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data NtfCreds = NtfCreds
|
||||
{ notifierId :: NotifierId,
|
||||
notifierKey :: NtfPublicVerifyKey,
|
||||
rcvNtfDhSecret :: RcvNtfDhSecret
|
||||
{ notifierId :: !NotifierId,
|
||||
notifierKey :: !NtfPublicVerifyKey,
|
||||
rcvNtfDhSecret :: !RcvNtfDhSecret
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -33,12 +33,3 @@ instance StrEncoding NtfCreds where
|
||||
pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
|
||||
|
||||
data ServerQueueStatus = QueueActive | QueueOff deriving (Eq, Show)
|
||||
|
||||
class MonadQueueStore s m where
|
||||
addQueue :: s -> QueueRec -> m (Either ErrorType ())
|
||||
getQueue :: s -> SParty p -> QueueId -> m (Either ErrorType QueueRec)
|
||||
secureQueue :: s -> RecipientId -> SndPublicVerifyKey -> m (Either ErrorType QueueRec)
|
||||
addQueueNotifier :: s -> RecipientId -> NtfCreds -> m (Either ErrorType QueueRec)
|
||||
deleteQueueNotifier :: s -> RecipientId -> m (Either ErrorType ())
|
||||
suspendQueue :: s -> RecipientId -> m (Either ErrorType ())
|
||||
deleteQueue :: s -> RecipientId -> m (Either ErrorType ())
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
@@ -10,7 +9,18 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
module Simplex.Messaging.Server.QueueStore.STM where
|
||||
module Simplex.Messaging.Server.QueueStore.STM
|
||||
( QueueStore (..),
|
||||
newQueueStore,
|
||||
addQueue,
|
||||
getQueue,
|
||||
secureQueue,
|
||||
addQueueNotifier,
|
||||
deleteQueueNotifier,
|
||||
suspendQueue,
|
||||
deleteQueue,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad
|
||||
import Data.Functor (($>))
|
||||
@@ -34,66 +44,65 @@ newQueueStore = do
|
||||
notifiers <- TM.empty
|
||||
pure QueueStore {queues, senders, notifiers}
|
||||
|
||||
instance MonadQueueStore QueueStore STM where
|
||||
addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ())
|
||||
addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do
|
||||
ifM hasId (pure $ Left DUPLICATE_) $ do
|
||||
qVar <- newTVar q
|
||||
TM.insert rId qVar queues
|
||||
TM.insert sId rId senders
|
||||
pure $ Right ()
|
||||
where
|
||||
hasId = (||) <$> TM.member rId queues <*> TM.member sId senders
|
||||
addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ())
|
||||
addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do
|
||||
ifM hasId (pure $ Left DUPLICATE_) $ do
|
||||
qVar <- newTVar q
|
||||
TM.insert rId qVar queues
|
||||
TM.insert sId rId senders
|
||||
pure $ Right ()
|
||||
where
|
||||
hasId = (||) <$> TM.member rId queues <*> TM.member sId senders
|
||||
|
||||
getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec)
|
||||
getQueue QueueStore {queues, senders, notifiers} party qId =
|
||||
toResult <$> (mapM readTVar =<< getVar)
|
||||
where
|
||||
getVar = case party of
|
||||
SRecipient -> TM.lookup qId queues
|
||||
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
|
||||
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
|
||||
getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec)
|
||||
getQueue QueueStore {queues, senders, notifiers} party qId =
|
||||
toResult <$> (mapM readTVar =<< getVar)
|
||||
where
|
||||
getVar = case party of
|
||||
SRecipient -> TM.lookup qId queues
|
||||
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
|
||||
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
|
||||
|
||||
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
|
||||
secureQueue QueueStore {queues} rId sKey =
|
||||
withQueue rId queues $ \qVar ->
|
||||
readTVar qVar >>= \q -> case senderKey q of
|
||||
Just k -> pure $ if sKey == k then Just q else Nothing
|
||||
_ ->
|
||||
let q' = q {senderKey = Just sKey}
|
||||
in writeTVar qVar q' $> Just q'
|
||||
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
|
||||
secureQueue QueueStore {queues} rId sKey =
|
||||
withQueue rId queues $ \qVar ->
|
||||
readTVar qVar >>= \q -> case senderKey q of
|
||||
Just k -> pure $ if sKey == k then Just q else Nothing
|
||||
_ ->
|
||||
let q' = q {senderKey = Just sKey}
|
||||
in writeTVar qVar q' $> Just q'
|
||||
|
||||
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
|
||||
addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do
|
||||
ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $
|
||||
withQueue rId queues $ \qVar -> do
|
||||
q <- readTVar qVar
|
||||
forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId
|
||||
writeTVar qVar q {notifier = Just ntfCreds}
|
||||
TM.insert nId rId notifiers
|
||||
pure $ Just q
|
||||
|
||||
deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
deleteQueueNotifier QueueStore {queues, notifiers} rId =
|
||||
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
|
||||
addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do
|
||||
ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $
|
||||
withQueue rId queues $ \qVar -> do
|
||||
q <- readTVar qVar
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
|
||||
writeTVar qVar q {notifier = Nothing}
|
||||
pure $ Just ()
|
||||
forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId
|
||||
writeTVar qVar $! q {notifier = Just ntfCreds}
|
||||
TM.insert nId rId notifiers
|
||||
pure $ Just q
|
||||
|
||||
suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
suspendQueue QueueStore {queues} rId =
|
||||
withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just ()
|
||||
deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
deleteQueueNotifier QueueStore {queues, notifiers} rId =
|
||||
withQueue rId queues $ \qVar -> do
|
||||
q <- readTVar qVar
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
|
||||
writeTVar qVar $! q {notifier = Nothing}
|
||||
pure $ Just ()
|
||||
|
||||
deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
deleteQueue QueueStore {queues, senders, notifiers} rId = do
|
||||
TM.lookupDelete rId queues >>= \case
|
||||
Just qVar ->
|
||||
readTVar qVar >>= \q -> do
|
||||
TM.delete (senderId q) senders
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
|
||||
pure $ Right ()
|
||||
_ -> pure $ Left AUTH
|
||||
suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
suspendQueue QueueStore {queues} rId =
|
||||
withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just ()
|
||||
|
||||
deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
deleteQueue QueueStore {queues, senders, notifiers} rId = do
|
||||
TM.lookupDelete rId queues >>= \case
|
||||
Just qVar ->
|
||||
readTVar qVar >>= \q -> do
|
||||
TM.delete (senderId q) senders
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
|
||||
pure $ Right ()
|
||||
_ -> pure $ Left AUTH
|
||||
|
||||
toResult :: Maybe a -> Either ErrorType a
|
||||
toResult = maybe (Left AUTH) Right
|
||||
|
||||
@@ -61,12 +61,12 @@ getServerStatsData s = do
|
||||
|
||||
setServerStats :: ServerStats -> ServerStatsData -> STM ()
|
||||
setServerStats s d = do
|
||||
writeTVar (fromTime s) (_fromTime d)
|
||||
writeTVar (qCreated s) (_qCreated d)
|
||||
writeTVar (qSecured s) (_qSecured d)
|
||||
writeTVar (qDeleted s) (_qDeleted d)
|
||||
writeTVar (msgSent s) (_msgSent d)
|
||||
writeTVar (msgRecv s) (_msgRecv d)
|
||||
writeTVar (fromTime s) $! _fromTime d
|
||||
writeTVar (qCreated s) $! _qCreated d
|
||||
writeTVar (qSecured s) $! _qSecured d
|
||||
writeTVar (qDeleted s) $! _qDeleted d
|
||||
writeTVar (msgSent s) $! _msgSent d
|
||||
writeTVar (msgRecv s) $! _msgRecv d
|
||||
setPeriodStats (activeQueues s) (_activeQueues d)
|
||||
|
||||
instance StrEncoding ServerStatsData where
|
||||
@@ -126,9 +126,9 @@ getPeriodStatsData s = do
|
||||
|
||||
setPeriodStats :: PeriodStats a -> PeriodStatsData a -> STM ()
|
||||
setPeriodStats s d = do
|
||||
writeTVar (day s) (_day d)
|
||||
writeTVar (week s) (_week d)
|
||||
writeTVar (month s) (_month d)
|
||||
writeTVar (day s) $! _day d
|
||||
writeTVar (week s) $! _week d
|
||||
writeTVar (month s) $! _month d
|
||||
|
||||
instance (Ord a, StrEncoding a) => StrEncoding (PeriodStatsData a) where
|
||||
strEncode PeriodStatsData {_day, _week, _month} =
|
||||
@@ -165,4 +165,4 @@ updatePeriodStats stats pId = do
|
||||
updatePeriod week
|
||||
updatePeriod month
|
||||
where
|
||||
updatePeriod pSel = modifyTVar (pSel stats) (S.insert pId)
|
||||
updatePeriod pSel = modifyTVar' (pSel stats) (S.insert pId)
|
||||
|
||||
@@ -214,7 +214,7 @@ instance Transport TLS where
|
||||
$ do
|
||||
b <- readChunks =<< readTVarIO buffer
|
||||
let (s, b') = B.splitAt n b
|
||||
atomically $ writeTVar buffer b'
|
||||
atomically $ writeTVar buffer $! b'
|
||||
pure s
|
||||
where
|
||||
readChunks :: ByteString -> IO ByteString
|
||||
@@ -237,7 +237,7 @@ instance Transport TLS where
|
||||
$ do
|
||||
b <- readChunks =<< readTVarIO buffer
|
||||
let (s, b') = B.break (== '\n') b
|
||||
atomically $ writeTVar buffer (B.drop 1 b') -- drop '\n' we made a break at
|
||||
atomically $ writeTVar buffer $! B.drop 1 b' -- drop '\n' we made a break at
|
||||
pure $ trimCR s
|
||||
where
|
||||
readChunks :: ByteString -> IO ByteString
|
||||
|
||||
Reference in New Issue
Block a user