smp server: batch commands (#1559)

* protocol: refactor types and encoding

* clean

* smp server: batch commands (#1560)

* smp server: batch commands verification into one DB transaction

* ghc 8.10.7

* flatten transmission tuples

* diff

* only use batch logic if there is more than one transmission

* func

* reset NTF service when adding notifier

* version

* Revert "smp server: use separate database pool for reading queues and creating service records (#1561)"

This reverts commit 3df2425162.

* version

* Revert "version"

This reverts commit d80a6b74c5.
This commit is contained in:
Evgeny
2025-06-12 23:05:04 +01:00
committed by GitHub
parent 1658048c2c
commit da37384335
24 changed files with 556 additions and 377 deletions
@@ -37,18 +37,20 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class
import Control.Monad.Trans.Except
import Data.Bifunctor (first)
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as BB
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Lazy as LB
import Data.Bitraversable (bimapM)
import Data.Either (fromRight)
import Data.Either (fromRight, lefts, rights)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (foldl', intersperse, partition)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe)
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import qualified Data.Set as S
import Data.Text (Text)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
@@ -62,7 +64,7 @@ import Database.PostgreSQL.Simple.ToField (Action (..), ToField (..))
import Database.PostgreSQL.Simple.Errors (ConstraintViolation (..), constraintViolation)
import Database.PostgreSQL.Simple.SqlQQ (sql)
import GHC.IO (catchAny)
import Simplex.Messaging.Agent.Client (withLockMap)
import Simplex.Messaging.Agent.Client (withLockMap, withLocksMap)
import Simplex.Messaging.Agent.Lock (Lock)
import Simplex.Messaging.Agent.Store.AgentStore ()
import Simplex.Messaging.Agent.Store.Postgres (createDBStore, closeDBStore)
@@ -81,7 +83,7 @@ import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPServiceRole (..))
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>))
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>), ($>>=))
import System.Exit (exitFailure)
import System.IO (IOMode (..), hFlush, stdout)
import UnliftIO.STM
@@ -180,18 +182,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
-- hasId = anyM [TM.memberIO rId queues, TM.memberIO senderId senders, hasNotifier]
-- hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.memberIO notifierId notifiers) notifier
getQueue_ :: DirectParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ :: QueueParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ st mkQ party qId = case party of
SRecipient -> getRcvQueue qId
SSender -> getSndQueue
SProxyService -> getSndQueue
SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue
-- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server
SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>))
where
PostgresQueueStore {queues, senders, links, notifiers} = st
getRcvQueue rId = TM.lookupIO rId queues >>= maybe (mask loadRcvQueue) (pure . Right)
getSndQueue = TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
loadRcvQueue = do
(rId, qRec) <- loadQueue " WHERE recipient_id = ?"
liftIO $ cacheQueue rId qRec $ \_ -> pure () -- recipient map already checked, not caching sender ref
@@ -228,6 +228,47 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
TM.insert rId sq queues
pure sq
getQueues_ :: forall p. BatchParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
getQueues_ st mkQ party qIds = case party of
SRecipient -> do
qs <- readTVarIO queues
let qs' = map (\qId -> get qs qId qId) qIds
E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue
SNotifier -> do
ns <- readTVarIO notifiers
qs <- readTVarIO queues
let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds
E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) ->
forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query
(nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs)
where
PostgresQueueStore {queues, notifiers} = st
get :: M.Map QueueId a -> QueueId -> QueueId -> Either QueueId a
get m qId = maybe (Left qId) Right . (`M.lookup` m)
loadQueues :: [Either QueueId q] -> Query -> ((RecipientId, QueueRec) -> IO (Maybe (QueueId, q))) -> IO [Either ErrorType q]
loadQueues qs' cond mkCacheQueue = do
let qIds' = lefts qs'
if null qIds'
then pure $ map (first (const INTERNAL)) qs'
else do
qs_ <-
runExceptT $ fmap M.fromList $
withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds')))
>>= liftIO . fmap catMaybes . mapM (mkCacheQueue . rowToQueueRec)
pure $ map (result qs_) qs'
where
result :: Either ErrorType (M.Map QueueId q) -> Either QueueId q -> Either ErrorType q
result _ (Right q) = Right q
result qs_ (Left qId) = maybe (Left AUTH) Right . M.lookup qId =<< qs_
cacheRcvQueue (rId, qRec) = do
sq <- mkQ True rId qRec
sq' <- withQueueLock sq "getQueue_" $ atomically $
-- checking the cache again for concurrent reads, use previously loaded queue if exists.
TM.lookup rId queues >>= \case
Just sq' -> pure sq'
Nothing -> sq <$ TM.insert rId sq queues
pure $ Just (rId, sq')
getQueueLinkData :: PostgresQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
getQueueLinkData st sq lnkId = runExceptT $ do
qr <- ExceptT $ readQueueRecIO $ queueRec sq
@@ -311,7 +352,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
db
[sql|
UPDATE msg_queues
SET notifier_id = ?, notifier_key = ?, rcv_ntf_dh_secret = ?
SET notifier_id = ?, notifier_key = ?, rcv_ntf_dh_secret = ?, ntf_service_id = NULL
WHERE recipient_id = ? AND deleted_at IS NULL
|]
(nId, notifierKey, rcvNtfDhSecret, rId)
@@ -333,7 +374,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
db
[sql|
UPDATE msg_queues
SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL
SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL, ntf_service_id = NULL
WHERE recipient_id = ? AND deleted_at IS NULL
|]
(Only rId)
@@ -402,15 +443,15 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
when new $ withLog "getCreateService" st (`logNewService` sr)
pure serviceId
setQueueService :: (PartyI p, SubscriberParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService st sq party serviceId = withQueueRec sq "setQueueService" $ \q -> case party of
SRecipient
SRecipientService
| rcvServiceId q == serviceId -> pure ()
| otherwise -> do
assertUpdated $ withDB' "setQueueService" st $ \db ->
DB.execute db "UPDATE msg_queues SET rcv_service_id = ? WHERE recipient_id = ? AND deleted_at IS NULL" (serviceId, rId)
updateQueueRec q {rcvServiceId = serviceId}
SNotifier -> case notifier q of
SNotifierService -> case notifier q of
Nothing -> throwE AUTH
Just nc@NtfCreds {ntfServiceId = prevSrvId}
| prevSrvId == serviceId -> pure ()
+19 -7
View File
@@ -128,17 +128,29 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.member notifierId notifiers) notifier
hasLink = maybe (pure False) (\(lnkId, _) -> TM.member lnkId links) queueData
getQueue_ :: DirectParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ :: QueueParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ st _ party qId =
maybe (Left AUTH) Right <$> case party of
SRecipient -> TM.lookupIO qId queues
SSender -> getSndQueue
SProxyService -> getSndQueue
SSender -> TM.lookupIO qId senders $>>= (`TM.lookupIO` queues)
SNotifier -> TM.lookupIO qId notifiers $>>= (`TM.lookupIO` queues)
SSenderLink -> TM.lookupIO qId links $>>= (`TM.lookupIO` queues)
where
STMQueueStore {queues, senders, notifiers, links} = st
getSndQueue = TM.lookupIO qId senders $>>= (`TM.lookupIO` queues)
getQueues_ :: BatchParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
getQueues_ st _ party qIds = case party of
SRecipient -> do
qs <- readTVarIO queues
pure $ map (get qs) qIds
SNotifier -> do
ns <- readTVarIO notifiers
qs <- readTVarIO queues
pure $ map (get qs <=< get ns) qIds
where
STMQueueStore {queues, notifiers} = st
get :: M.Map QueueId a -> QueueId -> Either ErrorType a
get m = maybe (Left AUTH) Right . (`M.lookup` m)
getQueueLinkData :: STMQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
getQueueLinkData _ q lnkId = atomically $ readQueueRec (queueRec q) $>>= pure . getData
@@ -292,7 +304,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
serviceNtfQueues <- newTVar S.empty
pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues}
setQueueService :: (PartyI p, SubscriberParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService :: (PartyI p, ServiceParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService st sq party serviceId =
atomically (readQueueRec qr $>>= setService)
$>> withLog "setQueueService" st (\sl -> logQueueService sl rId party serviceId)
@@ -301,13 +313,13 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
rId = recipientId sq
setService :: QueueRec -> STM (Either ErrorType ())
setService q@QueueRec {rcvServiceId = prevSrvId} = case party of
SRecipient
SRecipientService
| prevSrvId == serviceId -> pure $ Right ()
| otherwise -> do
updateServiceQueues serviceRcvQueues rId prevSrvId
let !q' = Just q {rcvServiceId = serviceId}
writeTVar qr q' $> Right ()
SNotifier -> case notifier q of
SNotifierService -> case notifier q of
Nothing -> pure $ Left AUTH
Just nc@NtfCreds {notifierId = nId, ntfServiceId = prevNtfSrvId}
| prevNtfSrvId == serviceId -> pure $ Right ()
@@ -31,7 +31,8 @@ class StoreQueueClass q => QueueStoreClass q s where
loadedQueues :: s -> TMap RecipientId q
compactQueues :: s -> IO Int64
addQueue_ :: s -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q)
getQueue_ :: DirectParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ :: QueueParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueues_ :: BatchParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
getQueueLinkData :: s -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
addQueueLinkData :: s -> q -> LinkId -> QueueLinkData -> IO (Either ErrorType ())
deleteQueueLinkData :: s -> q -> IO (Either ErrorType ())
@@ -45,7 +46,7 @@ class StoreQueueClass q => QueueStoreClass q s where
updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec)
deleteStoreQueue :: s -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q)))
getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId)
setQueueService :: (PartyI p, SubscriberParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]))
getNtfServiceQueueCount :: s -> ServiceId -> IO (Either ErrorType Int64)