servers: maintain xor-hash of all associated queue IDs in PostgreSQL (#1668)

* servers: maintain xor-hash of all associated queue IDs in PostgreSQL (#1615)

* ntf server: maintain xor-hash of all associated queue IDs via PostgreSQL triggers

* smp server: xor hash with triggers

* fix sql and using pgcrypto extension in tests

* track counts and hashes in smp/ntf servers via triggers, smp server stats for service subscription, update SMP protocol to pass expected count and hash in SSUB/NSSUB commands

* agent migrations with functions/triggers

* remove agent triggers

* try tracking service subs in the agent (WIP, does not compile)

* Revert "try tracking service subs in the agent (WIP, does not compile)"

This reverts commit 59e908100d.

* comment

* agent database triggers

* service subscriptions in the client

* test / fix client services

* update schema

* fix postgres migration

* update schema

* move schema test to the end

* use static function with SQLite to avoid dynamic wrapper
This commit is contained in:
Evgeny
2025-11-25 16:55:59 +00:00
committed by GitHub
parent 1ca4677b28
commit 3ccf854865
44 changed files with 2969 additions and 331 deletions
+15 -12
View File
@@ -211,7 +211,6 @@ import Simplex.Messaging.Protocol
ErrorType (AUTH),
MsgBody,
MsgFlags (..),
IdsHash,
NtfServer,
ProtoServerWithAuth (..),
ProtocolServer (..),
@@ -222,6 +221,7 @@ import Simplex.Messaging.Protocol
SMPMsgMeta,
SParty (..),
SProtocolType (..),
ServiceSub (..),
SndPublicAuthKey,
SubscriptionMode (..),
UserProtocol,
@@ -500,7 +500,7 @@ resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either Agen
resubscribeConnections c = withAgentEnv c . resubscribeConnections' c
{-# INLINE resubscribeConnections #-}
subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType (Int64, IdsHash)))
subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType ServiceSub))
subscribeClientServices c = withAgentEnv c . subscribeClientServices' c
{-# INLINE subscribeClientServices #-}
@@ -594,6 +594,7 @@ testProtocolServer c nm userId srv = withAgentEnv' c $ case protocolTypeI @p of
SPNTF -> runNTFServerTest c nm userId srv
-- | set SOCKS5 proxy on/off and optionally set TCP timeouts for fast network
-- TODO [certs rcv] should fail if any user is enabled to use services and per-connection isolation is chosen
setNetworkConfig :: AgentClient -> NetworkConfig -> IO ()
setNetworkConfig c@AgentClient {useNetworkConfig, proxySessTs} cfg' = do
ts <- getCurrentTime
@@ -771,6 +772,7 @@ deleteUser' c@AgentClient {smpServersStats, xftpServersStats} userId delSMPQueue
whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $
writeTBQueue (subQ c) ("", "", AEvt SAENone $ DEL_USER userId)
-- TODO [certs rcv] should fail enabling if per-connection isolation is set
setUserService' :: AgentClient -> UserId -> Bool -> AM ()
setUserService' c userId enable = do
wasEnabled <- liftIO $ fromMaybe False <$> TM.lookupIO userId (useClientServices c)
@@ -1507,15 +1509,15 @@ resubscribeConnections' c connIds = do
[] -> pure True
rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs'
-- TODO [certs rcv] compare hash with lock
subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType (Int64, IdsHash)))
-- TODO [certs rcv] compare hash. possibly, it should return both expected and returned counts
subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSub))
subscribeClientServices' c userId =
ifM useService subscribe $ throwError $ CMD PROHIBITED "no user service allowed"
where
useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c)
subscribe = do
srvs <- withStore' c (`getClientServiceServers` userId)
lift $ M.fromList . zip srvs <$> mapConcurrently (tryAllErrors' . subscribeClientService c userId) srvs
lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c userId srv n idsHash) srvs
-- requesting messages sequentially, to reduce memory usage
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
@@ -2829,12 +2831,13 @@ processSMPTransmissions :: AgentClient -> ServerTransmissionBatch SMPVersion Err
processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId, ts) = do
upConnIds <- newTVarIO []
forM_ ts $ \(entId, t) -> case t of
STEvent msgOrErr ->
withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of
Right msg -> runProcessSMP rq conn (toConnData conn) msg
Left e -> lift $ do
processClientNotice rq e
notifyErr connId e
STEvent msgOrErr
| entId == SMP.NoEntity -> pure () -- TODO [certs rcv] process SALL
| otherwise -> withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of
Right msg -> runProcessSMP rq conn (toConnData conn) msg
Left e -> lift $ do
processClientNotice rq e
notifyErr connId e
STResponse (Cmd SRecipient cmd) respOrErr ->
withRcvConn entId $ \rq conn -> case cmd of
SMP.SUB -> case respOrErr of
@@ -2870,7 +2873,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
processSubOk :: RcvQueue -> TVar [ConnId] -> IO ()
processSubOk rq@RcvQueue {connId} upConnIds =
atomically . whenM (isPendingSub rq) $ do
SS.addActiveSub tSess sessId (rcvQueueSub rq) $ currentSubs c
SS.addActiveSub tSess sessId rq $ currentSubs c
modifyTVar' upConnIds (connId :)
processSubErr :: RcvQueue -> SMPClientError -> AM' ()
processSubErr rq@RcvQueue {connId} e = do
+86 -38
View File
@@ -241,7 +241,7 @@ import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Stats
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.AgentStore
import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction)
import Simplex.Messaging.Agent.Store.Common (DBStore)
import qualified Simplex.Messaging.Agent.Store.DB as DB
import Simplex.Messaging.Agent.Store.Entity
import Simplex.Messaging.Agent.TSessionSubs (TSessionSubs)
@@ -279,6 +279,7 @@ import Simplex.Messaging.Protocol
RcvNtfPublicDhKey,
SMPMsgMeta (..),
SProtocolType (..),
ServiceSub (..),
SndPublicAuthKey,
SubscriptionMode (..),
NewNtfCreds (..),
@@ -499,6 +500,7 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther
deriving (Eq, Show)
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
-- TODO [certs rcv] should fail if both per-connection isolation is set and any users use services
newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Map (Maybe SMPServer) (Maybe SystemSeconds) -> Env -> IO AgentClient
newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, useServices, presetDomains, presetServers} currentTs notices agentEnv = do
let cfg = config agentEnv
@@ -622,9 +624,8 @@ getServiceCredentials c userId srv =
let tlsCreds = tlsCredentials [cred]
createClientService db userId srv tlsCreds
pure (tlsCreds, Nothing)
(_, pk) <- atomically $ C.generateKeyPair g
let serviceSignKey = C.APrivateSignKey C.SEd25519 pk
creds = ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey}
serviceSignKey <- liftEitherWith INTERNAL $ C.x509ToPrivate' $ snd serviceCreds
let creds = ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey}
pure (creds, serviceId_)
class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where
@@ -744,9 +745,11 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm
smp <- liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do
ts <- readTVarIO proxySessTs
ExceptT $ getProtocolClient g nm tSess cfg' presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs
-- TODO [certs rcv] add service to SS, possibly combine with SS.setSessionId
atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c
updateClientService service smp
pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs}
-- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queue
updateClientService service smp = case (service, smpClientService smp) of
(Just (_, serviceId_), Just THClientService {serviceId})
| serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId
@@ -763,32 +766,34 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess
-- we make active subscriptions pending only if the client for tSess was current (in the map) and active,
-- because we can have a race condition when a new current client could have already
-- made subscriptions active, and the old client would be processing diconnection later.
removeClientAndSubs :: IO ([RcvQueueSub], [ConnId])
removeClientAndSubs :: IO ([RcvQueueSub], [ConnId], Maybe ServiceSub)
removeClientAndSubs = atomically $ do
removeSessVar v tSess smpClients
ifM (readTVar active) removeSubs (pure ([], []))
ifM (readTVar active) removeSubs (pure ([], [], Nothing))
where
sessId = sessionId $ thParams client
removeSubs = do
mode <- getSessionMode c
subs <- SS.setSubsPending mode tSess sessId $ currentSubs c
(subs, serviceSub_) <- SS.setSubsPending mode tSess sessId $ currentSubs c
let qs = M.elems subs
cs = nubOrd $ map qConnId qs
-- this removes proxied relays that this client created sessions to
destSrvs <- M.keys <$> readTVar prs
forM_ destSrvs $ \destSrv -> TM.delete (userId, destSrv, cId) smpProxiedRelays
pure (qs, cs)
pure (qs, cs, serviceSub_)
serverDown :: ([RcvQueueSub], [ConnId]) -> IO ()
serverDown (qs, conns) = whenM (readTVarIO active) $ do
serverDown :: ([RcvQueueSub], [ConnId], Maybe ServiceSub) -> IO ()
serverDown (qs, conns, serviceSub_) = whenM (readTVarIO active) $ do
notifySub c $ hostEvent' DISCONNECT client
unless (null conns) $ notifySub c $ DOWN srv conns
unless (null qs) $ do
unless (null qs && isNothing serviceSub_) $ do
releaseGetLocksIO c qs
mode <- getSessionModeIO c
let resubscribe
| (mode == TSMEntity) == isJust cId = resubscribeSMPSession c tSess
| otherwise = void $ subscribeQueues c True qs
| otherwise = do
mapM_ (runExceptT . resubscribeClientService c tSess) serviceSub_
unless (null qs) $ void $ subscribeQueues c True qs
runReaderT resubscribe env
resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' ()
@@ -807,11 +812,12 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do
runSubWorker = do
ri <- asks $ reconnectInterval . config
withRetryForeground ri isForeground (isNetworkOnline c) $ \_ loop -> do
pending <- atomically $ SS.getPendingSubs tSess $ currentSubs c
unless (M.null pending) $ do
(pendingSubs, pendingSS) <- atomically $ SS.getPendingSubs tSess $ currentSubs c
unless (M.null pendingSubs && isNothing pendingSS) $ do
liftIO $ waitUntilForeground c
liftIO $ waitForUserNetwork c
handleNotify $ resubscribeSessQueues c tSess $ M.elems pending
mapM_ (handleNotify . void . runExceptT . resubscribeClientService c tSess) pendingSS
unless (M.null pendingSubs) $ handleNotify $ resubscribeSessQueues c tSess $ M.elems pendingSubs
loop
isForeground = (ASForeground ==) <$> readTVar (agentState c)
cleanup :: SessionVar (Async ()) -> STM ()
@@ -1508,25 +1514,25 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl
newErr :: String -> AM (Maybe ShortLinkCreds)
newErr = throwE . BROKER (B.unpack $ strEncode srv) . UNEXPECTED . ("Create queue: " <>)
processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM [(RcvQueueSub, Maybe ClientNotice)]
processSubResults c tSess@(userId, srv, _) sessId rs = do
pendingSubs <- SS.getPendingSubs tSess $ currentSubs c
let (failed, subscribed, notices, ignored) = foldr (partitionResults pendingSubs) (M.empty, [], [], 0) rs
processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> Maybe ServiceId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM ([RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)])
processSubResults c tSess@(userId, srv, _) sessId smpServiceId rs = do
pending <- SS.getPendingSubs tSess $ currentSubs c
let (failed, subscribed@(qs, sQs), notices, ignored) = foldr (partitionResults pending) (M.empty, ([], []), [], 0) rs
unless (M.null failed) $ do
incSMPServerStat' c userId srv connSubErrs $ M.size failed
failSubscriptions c tSess failed
unless (null subscribed) $ do
incSMPServerStat' c userId srv connSubscribed $ length subscribed
unless (null qs && null sQs) $ do
incSMPServerStat' c userId srv connSubscribed $ length qs + length sQs
SS.batchAddActiveSubs tSess sessId subscribed $ currentSubs c
unless (ignored == 0) $ incSMPServerStat' c userId srv connSubIgnored ignored
pure notices
pure (sQs, notices)
where
partitionResults ::
Map SMP.RecipientId RcvQueueSub ->
(Map SMP.RecipientId RcvQueueSub, Maybe ServiceSub) ->
(RcvQueueSub, Either SMPClientError (Maybe ServiceId)) ->
(Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int) ->
(Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int)
partitionResults pendingSubs (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed, notices, ignored) = case r of
(Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int) ->
(Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int)
partitionResults (pendingSubs, pendingSS) (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed@(qs, sQs), notices, ignored) = case r of
Left e -> case smpErrorClientNotice e of
Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored)
where
@@ -1536,8 +1542,12 @@ processSubResults c tSess@(userId, srv, _) sessId rs = do
| otherwise -> (failed', subscribed, notices, ignored)
where
failed' = M.insert rcvId e failed
Right _serviceId -- TODO [certs rcv] store association with the service
| rcvId `M.member` pendingSubs -> (failed, rq : subscribed, notices', ignored)
Right serviceId_
| rcvId `M.member` pendingSubs ->
let subscribed' = case (smpServiceId, serviceId_, pendingSS) of
(Just sId, Just sId', Just ServiceSub {serviceId}) | sId == sId' && sId == serviceId -> (qs, rq : sQs)
_ -> (rq : qs, sQs)
in (failed, subscribed', notices', ignored)
| otherwise -> (failed, subscribed, notices', ignored + 1)
where
notices' = if isJust clientNoticeId then (rq, Nothing) : notices else notices
@@ -1576,6 +1586,7 @@ serverHostError = \case
-- | Batch by transport session and subscribe queues. The list of results can have a different order.
subscribeQueues :: AgentClient -> Bool -> [RcvQueueSub] -> AM' [(RcvQueueSub, Either AgentErrorType (Maybe ServiceId))]
subscribeQueues _ _ [] = pure []
subscribeQueues c withEvents qs = do
(errs, qs') <- checkQueues c qs
atomically $ modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs'))
@@ -1632,6 +1643,7 @@ checkQueues c = fmap partitionEithers . mapM checkQueue
-- This function expects that all queues belong to one transport session,
-- and that they are already added to pending subscriptions.
resubscribeSessQueues :: AgentClient -> SMPTransportSession -> [RcvQueueSub] -> AM' ()
resubscribeSessQueues _ _ [] = pure ()
resubscribeSessQueues c tSess qs = do
(errs, qs_) <- checkQueues c qs
forM_ (L.nonEmpty qs_) $ \qs' -> void $ subscribeSessQueues_ c True (tSess, qs')
@@ -1650,13 +1662,15 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c
then Just . S.fromList . map qConnId . M.elems <$> atomically (SS.getActiveSubs tSess $ currentSubs c)
else pure Nothing
active <- E.uninterruptibleMask_ $ do
(active, notices) <- atomically $ do
r@(_, notices) <- ifM
(active, (serviceQs, notices)) <- atomically $ do
r@(_, (_, notices)) <- ifM
(activeClientSession c tSess sessId)
((True,) <$> processSubResults c tSess sessId rs)
((False, []) <$ incSMPServerStat' c userId srv connSubIgnored (length rs))
((True,) <$> processSubResults c tSess sessId smpServiceId rs)
((False, ([], [])) <$ incSMPServerStat' c userId srv connSubIgnored (length rs))
unless (null notices) $ takeTMVar $ clientNoticesLock c
pure r
unless (null serviceQs) $ void $
processRcvServiceAssocs c serviceQs `runReaderT` agentEnv c
unless (null notices) $ void $
(processClientNotices c tSess notices `runReaderT` agentEnv c)
`E.finally` atomically (putTMVar (clientNoticesLock c) ())
@@ -1677,6 +1691,13 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c
where
tSess = transportSession' smp
sessId = sessionId $ thParams smp
smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp
processRcvServiceAssocs :: AgentClient -> [RcvQueueSub] -> AM' ()
processRcvServiceAssocs c serviceQs =
withStore' c (`setRcvServiceAssocs` serviceQs) `catchAllErrors'` \e -> do
logError $ "processClientNotices error: " <> tshow e
notifySub' c "" $ ERR e
processClientNotices :: AgentClient -> SMPTransportSession -> [(RcvQueueSub, Maybe ClientNotice)] -> AM' ()
processClientNotices c@AgentClient {presetServers} tSess notices = do
@@ -1689,10 +1710,35 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do
logError $ "processClientNotices error: " <> tshow e
notifySub' c "" $ ERR e
subscribeClientService :: AgentClient -> UserId -> SMPServer -> AM (Int64, IdsHash)
subscribeClientService c userId srv =
withLogClient c NRMBackground (userId, srv, Nothing) B.empty "SUBS" $
(`subscribeService` SMP.SRecipientService) . connectedClient
resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSub
resubscribeClientService c tSess (ServiceSub _ n idsHash) =
withServiceClient c tSess $ \smp _ -> do
subscribeClientService_ c tSess smp n idsHash
subscribeClientService :: AgentClient -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSub
subscribeClientService c userId srv n idsHash =
withServiceClient c tSess $ \smp smpServiceId -> do
let serviceSub = ServiceSub smpServiceId n idsHash
atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c
subscribeClientService_ c tSess smp n idsHash
where
tSess = (userId, srv, Nothing)
withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a
withServiceClient c tSess action =
withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) ->
case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of
Just smpServiceId -> action smp smpServiceId
Nothing -> throwE PCEServiceUnavailable
subscribeClientService_ :: AgentClient -> SMPTransportSession -> SMPClient -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub
subscribeClientService_ c tSess smp n idsHash = do
-- TODO [certs rcv] handle error
serviceSub' <- subscribeService smp SMP.SRecipientService n idsHash
let sessId = sessionId $ thParams smp
atomically $ whenM (activeClientSession c tSess sessId) $
SS.setActiveServiceSub tSess sessId serviceSub' $ currentSubs c
pure serviceSub'
activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool
activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c)
@@ -1762,7 +1808,7 @@ addNewQueueSubscription c rq' tSess sessId = do
modifyTVar' (subscrConns c) $ S.insert $ qConnId rq
active <- activeClientSession c tSess sessId
if active
then SS.addActiveSub tSess sessId rq $ currentSubs c
then SS.addActiveSub tSess sessId rq' $ currentSubs c
else SS.addPendingSub tSess rq $ currentSubs c
pure active
unless same $ resubscribeSMPSession c tSess
@@ -1951,6 +1997,7 @@ releaseGetLock c rq =
{-# INLINE releaseGetLock #-}
releaseGetLocksIO :: SomeRcvQueue q => AgentClient -> [q] -> IO ()
releaseGetLocksIO _ [] = pure ()
releaseGetLocksIO c rqs = do
locks <- readTVarIO $ getMsgLocks c
forM_ rqs $ \rq ->
@@ -2301,7 +2348,8 @@ withStore c action = do
[ E.Handler $ \(e :: SQL.SQLError) ->
let se = SQL.sqlError e
busy = se == SQL.ErrorBusy || se == SQL.ErrorLocked
in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ bshow se,
err = tshow se <> ": " <> SQL.sqlErrorDetails e <> ", " <> SQL.sqlErrorContext e
in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ encodeUtf8 err,
E.Handler $ \(E.SomeException e) -> pure . Left $ SEInternal $ bshow e
]
#endif
@@ -314,7 +314,7 @@ runNtfWorker c srv Worker {doWork} =
_ -> ((ntfSubConnId sub, INTERNAL "NSACheck - no subscription ID") : errs, subs, subIds)
updateSub :: DB.Connection -> NtfServer -> UTCTime -> UTCTime -> (NtfSubscription, NtfSubStatus) -> IO (Maybe SMPServer)
updateSub db ntfServer ts nextCheckTs (sub, status)
| ntfShouldSubscribe status =
| status `elem` subscribeNtfStatuses =
let sub' = sub {ntfSubStatus = NASCreated status}
in Nothing <$ updateNtfSubscription db sub' (NSANtf NSACheck) nextCheckTs
-- ntf server stopped subscribing to this queue
+27 -14
View File
@@ -53,6 +53,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
getSubscriptionServers,
getUserServerRcvQueueSubs,
unsetQueuesToSubscribe,
setRcvServiceAssocs,
getConnIds,
getConn,
getDeletedConn,
@@ -401,29 +402,31 @@ deleteUsersWithoutConns db = do
pure userIds
createClientService :: DB.Connection -> UserId -> SMPServer -> (C.KeyHash, TLS.Credential) -> IO ()
createClientService db userId srv (kh, (cert, pk)) =
createClientService db userId srv (kh, (cert, pk)) = do
serverKeyHash_ <- createServer_ db srv
DB.execute
db
[sql|
INSERT INTO client_services
(user_id, host, port, service_cert_hash, service_cert, service_priv_key)
VALUES (?,?,?,?,?,?)
ON CONFLICT (user_id, host, port)
(user_id, host, port, server_key_hash, service_cert_hash, service_cert, service_priv_key)
VALUES (?,?,?,?,?,?,?)
ON CONFLICT (user_id, host, port, server_key_hash)
DO UPDATE SET
service_cert_hash = EXCLUDED.service_cert_hash,
service_cert = EXCLUDED.service_cert,
service_priv_key = EXCLUDED.service_priv_key,
rcv_service_id = NULL
service_id = NULL
|]
(userId, host srv, port srv, kh, cert, pk)
(userId, host srv, port srv, serverKeyHash_, kh, cert, pk)
-- TODO [certs rcv] get correct service based on key hash of the server
getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId))
getClientService db userId srv =
maybeFirstRow toService $
DB.query
db
[sql|
SELECT service_cert_hash, service_cert, service_priv_key, rcv_service_id
SELECT service_cert_hash, service_cert, service_priv_key, service_id
FROM client_services
WHERE user_id = ? AND host = ? AND port = ?
|]
@@ -431,19 +434,21 @@ getClientService db userId srv =
where
toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_)
getClientServiceServers :: DB.Connection -> UserId -> IO [SMPServer]
getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)]
getClientServiceServers db userId =
map toServer
<$> DB.query
db
[sql|
SELECT c.host, c.port, s.key_hash
SELECT c.host, c.port, s.key_hash, c.service_id, c.service_queue_count, c.service_queue_ids_hash
FROM client_services c
JOIN servers s ON s.host = c.host AND s.port = c.port
WHERE c.user_id = ?
|]
(Only userId)
where
toServer (host, port, kh) = SMPServer host port kh
toServer (host, port, kh, serviceId, n, Binary idsHash) =
(SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash))
setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO ()
setClientServiceId db userId srv serviceId =
@@ -451,7 +456,7 @@ setClientServiceId db userId srv serviceId =
db
[sql|
UPDATE client_services
SET rcv_service_id = ?
SET service_id = ?
WHERE user_id = ? AND host = ? AND port = ?
|]
(serviceId, userId, host srv, port srv)
@@ -2099,7 +2104,7 @@ insertRcvQueue_ db connId' rq@RcvQueue {..} subMode serverKeyHash_ = do
ntf_public_key, ntf_private_key, ntf_id, rcv_ntf_dh_secret
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|]
( (host server, port server, rcvId, rcvServiceAssoc, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret)
( (host server, port server, rcvId, BI rcvServiceAssoc, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret)
:. (sndId, queueMode, status, BI toSubscribe, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)
:. (shortLinkId <$> shortLink, shortLinkKey <$> shortLink, linkPrivSigKey <$> shortLink, linkEncFixedData <$> shortLink)
:. ntfCredsFields
@@ -2248,6 +2253,14 @@ getUserServerRcvQueueSubs db userId srv onlyNeeded =
unsetQueuesToSubscribe :: DB.Connection -> IO ()
unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1"
setRcvServiceAssocs :: DB.Connection -> [RcvQueueSub] -> IO ()
setRcvServiceAssocs db rqs =
#if defined(dbPostgres)
DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs)
#else
DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = " $ map (Only . queueId) rqs
#endif
-- * getConn helpers
getConnIds :: DB.Connection -> IO [ConnId]
@@ -2468,13 +2481,13 @@ rcvQueueQuery =
toRcvQueue ::
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe QueueMode)
:. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int, ServiceAssoc)
:. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int, BoolInt)
:. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret)
:. (Maybe SMP.LinkId, Maybe LinkKey, Maybe C.PrivateKeyEd25519, Maybe EncDataBytes) ->
RcvQueue
toRcvQueue
( (userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode)
:. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors, rcvServiceAssoc)
:. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors, BI rcvServiceAssoc)
:. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)
:. (shortLinkId_, shortLinkKey_, linkPrivSigKey_, linkEncFixedData_)
) =
@@ -10,6 +10,7 @@ import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250322_short_links
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250702_conn_invitations_remove_cascade_delete
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251009_queue_to_subscribe
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251010_client_notices
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251020_service_certs
import Simplex.Messaging.Agent.Store.Shared (Migration (..))
schemaMigrations :: [(String, Text, Maybe Text)]
@@ -19,7 +20,8 @@ schemaMigrations =
("20250322_short_links", m20250322_short_links, Just down_m20250322_short_links),
("20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete),
("20251009_queue_to_subscribe", m20251009_queue_to_subscribe, Just down_m20251009_queue_to_subscribe),
("20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices)
("20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices),
("20251020_service_certs", m20251020_service_certs, Just down_m20251020_service_certs)
]
-- | The list of migrations in ascending order by date
@@ -0,0 +1,114 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251020_service_certs where
import Data.Text (Text)
import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util
import Text.RawString.QQ (r)
m20251020_service_certs :: Text
m20251020_service_certs =
createXorHashFuncs <> [r|
CREATE TABLE client_services(
user_id BIGINT NOT NULL REFERENCES users ON UPDATE RESTRICT ON DELETE CASCADE,
host TEXT NOT NULL,
port TEXT NOT NULL,
server_key_hash BYTEA,
service_cert BYTEA NOT NULL,
service_cert_hash BYTEA NOT NULL,
service_priv_key BYTEA NOT NULL,
service_id BYTEA,
service_queue_count BIGINT NOT NULL DEFAULT 0,
service_queue_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000',
FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT
);
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port, server_key_hash);
CREATE INDEX idx_server_certs_host_port ON client_services(host, port);
ALTER TABLE rcv_queues ADD COLUMN rcv_service_assoc SMALLINT NOT NULL DEFAULT 0;
CREATE FUNCTION update_aggregates(p_conn_id BYTEA, p_host TEXT, p_port TEXT, p_change BIGINT, p_rcv_id BYTEA) RETURNS VOID
LANGUAGE plpgsql
AS $$
DECLARE q_user_id BIGINT;
BEGIN
SELECT user_id INTO q_user_id FROM connections WHERE conn_id = p_conn_id;
UPDATE client_services
SET service_queue_count = service_queue_count + p_change,
service_queue_ids_hash = xor_combine(service_queue_ids_hash, public.digest(p_rcv_id, 'md5'))
WHERE user_id = q_user_id AND host = p_host AND port = p_port;
END;
$$;
CREATE FUNCTION on_rcv_queue_insert() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN
PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id);
END IF;
RETURN NEW;
END;
$$;
CREATE FUNCTION on_rcv_queue_delete() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN
PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id);
END IF;
RETURN OLD;
END;
$$;
CREATE FUNCTION on_rcv_queue_update() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN
IF NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0) THEN
PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id);
END IF;
ELSIF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN
PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id);
END IF;
RETURN NEW;
END;
$$;
CREATE TRIGGER tr_rcv_queue_insert
AFTER INSERT ON rcv_queues
FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_insert();
CREATE TRIGGER tr_rcv_queue_delete
AFTER DELETE ON rcv_queues
FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_delete();
CREATE TRIGGER tr_rcv_queue_update
AFTER UPDATE ON rcv_queues
FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_update();
|]
down_m20251020_service_certs :: Text
down_m20251020_service_certs =
[r|
DROP TRIGGER tr_rcv_queue_insert ON rcv_queues;
DROP TRIGGER tr_rcv_queue_delete ON rcv_queues;
DROP TRIGGER tr_rcv_queue_update ON rcv_queues;
DROP FUNCTION on_rcv_queue_insert;
DROP FUNCTION on_rcv_queue_delete;
DROP FUNCTION on_rcv_queue_update;
DROP FUNCTION update_aggregates;
ALTER TABLE rcv_queues DROP COLUMN rcv_service_assoc;
DROP INDEX idx_server_certs_host_port;
DROP INDEX idx_server_certs_user_id_host_port;
DROP TABLE client_services;
|]
<> dropXorHashFuncs
@@ -0,0 +1,46 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.Postgres.Migrations.Util where
import Data.Text (Text)
import qualified Data.Text as T
import Text.RawString.QQ (r)
-- xor_combine is only applied to locally computed md5 hashes (128 bits/16 bytes),
-- so it is safe to require that all values are of the same length.
createXorHashFuncs :: Text
createXorHashFuncs =
T.pack
[r|
CREATE OR REPLACE FUNCTION xor_combine(state BYTEA, value BYTEA) RETURNS BYTEA
LANGUAGE plpgsql IMMUTABLE STRICT
AS $$
DECLARE
result BYTEA := state;
i INTEGER;
len INTEGER := octet_length(value);
BEGIN
IF octet_length(state) != len THEN
RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len;
END IF;
FOR i IN 0..len-1 LOOP
result := set_byte(result, i, get_byte(state, i) # get_byte(value, i));
END LOOP;
RETURN result;
END;
$$;
CREATE OR REPLACE AGGREGATE xor_aggregate(BYTEA) (
SFUNC = xor_combine,
STYPE = BYTEA,
INITCOND = '\x00000000000000000000000000000000' -- 16 bytes
);
|]
dropXorHashFuncs :: Text
dropXorHashFuncs =
T.pack
[r|
DROP AGGREGATE xor_aggregate(BYTEA);
DROP FUNCTION xor_combine;
|]
File diff suppressed because it is too large Load Diff
@@ -21,30 +21,32 @@ import Database.PostgreSQL.Simple.SqlQQ (sql)
createDBAndUserIfNotExists :: ConnectInfo -> IO ()
createDBAndUserIfNotExists ConnectInfo {connectUser = user, connectDatabase = dbName} = do
-- connect to the default "postgres" maintenance database
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $
\postgresDB -> do
void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING"
-- check if the user exists, create if not
[Only userExists] <-
PSQL.query
postgresDB
[sql|
SELECT EXISTS (
SELECT 1 FROM pg_catalog.pg_roles
WHERE rolname = ?
)
|]
(Only user)
unless userExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE USER " <> user)
-- check if the database exists, create if not
dbExists <- checkDBExists postgresDB dbName
unless dbExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE DATABASE " <> dbName <> " OWNER " <> user)
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ \db -> do
execSQL db "SET client_min_messages TO WARNING"
-- check if the user exists, create if not
[Only userExists] <-
PSQL.query
db
[sql|
SELECT EXISTS (
SELECT 1 FROM pg_catalog.pg_roles
WHERE rolname = ?
)
|]
(Only user)
unless userExists $ execSQL db $ "CREATE USER " <> user
-- check if the database exists, create if not
dbExists <- checkDBExists db dbName
unless dbExists $ do
execSQL db $ "CREATE DATABASE " <> dbName <> " OWNER " <> user
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = dbName}) PSQL.close $
(`execSQL` "CREATE EXTENSION IF NOT EXISTS pgcrypto")
checkDBExists :: PSQL.Connection -> String -> IO Bool
checkDBExists postgresDB dbName = do
checkDBExists db dbName = do
[Only dbExists] <-
PSQL.query
postgresDB
db
[sql|
SELECT EXISTS (
SELECT 1 FROM pg_catalog.pg_database
@@ -56,45 +58,45 @@ checkDBExists postgresDB dbName = do
dropSchema :: ConnectInfo -> String -> IO ()
dropSchema connectInfo schema =
bracket (PSQL.connect connectInfo) PSQL.close $
\db -> do
void $ PSQL.execute_ db "SET client_min_messages TO WARNING"
void $ PSQL.execute_ db (fromString $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE")
bracket (PSQL.connect connectInfo) PSQL.close $ \db -> do
execSQL db "SET client_min_messages TO WARNING"
execSQL db $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE"
dropAllSchemasExceptSystem :: ConnectInfo -> IO ()
dropAllSchemasExceptSystem connectInfo =
bracket (PSQL.connect connectInfo) PSQL.close $
\db -> do
void $ PSQL.execute_ db "SET client_min_messages TO WARNING"
schemaNames :: [Only String] <-
PSQL.query_
db
[sql|
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema')
|]
forM_ schemaNames $ \(Only schema) ->
PSQL.execute_ db (fromString $ "DROP SCHEMA " <> schema <> " CASCADE")
bracket (PSQL.connect connectInfo) PSQL.close $ \db -> do
execSQL db "SET client_min_messages TO WARNING"
schemaNames :: [Only String] <-
PSQL.query_
db
[sql|
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema')
|]
forM_ schemaNames $ \(Only schema) ->
execSQL db $ "DROP SCHEMA " <> schema <> " CASCADE"
dropDatabaseAndUser :: ConnectInfo -> IO ()
dropDatabaseAndUser ConnectInfo {connectUser = user, connectDatabase = dbName} =
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $
\postgresDB -> do
void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING"
dbExists <- checkDBExists postgresDB dbName
when dbExists $ do
void $ PSQL.execute_ postgresDB (fromString $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false")
-- terminate all connections to the database
_r :: [Only Bool] <-
PSQL.query
postgresDB
[sql|
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE datname = ?
AND pid <> pg_backend_pid()
|]
(Only dbName)
void $ PSQL.execute_ postgresDB (fromString $ "DROP DATABASE " <> dbName)
void $ PSQL.execute_ postgresDB (fromString $ "DROP USER IF EXISTS " <> user)
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ \db -> do
execSQL db "SET client_min_messages TO WARNING"
dbExists <- checkDBExists db dbName
when dbExists $ do
execSQL db $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false"
-- terminate all connections to the database
_r :: [Only Bool] <-
PSQL.query
db
[sql|
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE datname = ?
AND pid <> pg_backend_pid()
|]
(Only dbName)
execSQL db $ "DROP DATABASE " <> dbName
execSQL db $ "DROP USER IF EXISTS " <> user
execSQL :: PSQL.Connection -> String -> IO ()
execSQL db = void . PSQL.execute_ db . fromString
+29 -6
View File
@@ -42,9 +42,15 @@ module Simplex.Messaging.Agent.Store.SQLite
)
where
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception (bracketOnError, onException, throwIO)
import Control.Monad
import Data.Bits (xor)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Functor (($>))
import Data.IORef
import Data.Maybe (fromMaybe)
@@ -54,17 +60,19 @@ import Database.SQLite.Simple (Query (..))
import qualified Database.SQLite.Simple as SQL
import Database.SQLite.Simple.QQ (sql)
import qualified Database.SQLite3 as SQLite3
import Database.SQLite3.Bindings
import Foreign.C.Types
import Foreign.Ptr
import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSchema)
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Agent.Store.SQLite.Common
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..))
import Simplex.Messaging.Agent.Store.SQLite.Util (SQLiteFunc, createStaticFunction, mkSQLiteFunc)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Util (ifM, safeDecodeUtf8)
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
import System.FilePath (takeDirectory, takeFileName, (</>))
import UnliftIO.Exception (bracketOnError, onException)
import UnliftIO.MVar
import UnliftIO.STM
-- * SQLite Store implementation
@@ -109,9 +117,9 @@ connectDB path key track = do
pure db
where
prepare db = do
let exec = SQLite3.exec $ SQL.connectionHandle $ DB.conn db
unless (BA.null key) . exec $ "PRAGMA key = " <> keyString key <> ";"
exec . fromQuery $
let db' = SQL.connectionHandle $ DB.conn db
unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";"
SQLite3.exec db' . fromQuery $
[sql|
PRAGMA busy_timeout = 100;
PRAGMA foreign_keys = ON;
@@ -119,6 +127,21 @@ connectDB path key track = do
PRAGMA secure_delete = ON;
PRAGMA auto_vacuum = FULL;
|]
createStaticFunction db' "simplex_xor_md5_combine" 2 True sqliteXorMd5CombinePtr
>>= either (throwIO . userError . show) pure
foreign export ccall "simplex_xor_md5_combine" sqliteXorMd5Combine :: SQLiteFunc
foreign import ccall "&simplex_xor_md5_combine" sqliteXorMd5CombinePtr :: FunPtr SQLiteFunc
sqliteXorMd5Combine :: SQLiteFunc
sqliteXorMd5Combine = mkSQLiteFunc $ \cxt args -> do
idsHash <- SQLite3.funcArgBlob args 0
rId <- SQLite3.funcArgBlob args 1
SQLite3.funcResultBlob cxt $ xorMd5Combine idsHash rId
xorMd5Combine :: ByteString -> ByteString -> ByteString
xorMd5Combine idsHash rId = B.packZipWith xor idsHash $ C.md5Hash rId
closeDBStore :: DBStore -> IO ()
closeDBStore st@DBStore {dbClosed} =
@@ -53,6 +53,12 @@ withConnectionPriority DBStore {dbSem, dbConnection} priority action
| priority = E.bracket_ signal release $ withMVar dbConnection action
| otherwise = lowPriority
where
-- To debug FK errors, set foreign_keys = OFF in Simplex.Messaging.Agent.Store.SQLite and use action' instead of action
-- action' conn = do
-- r <- action conn
-- violations <- DB.query_ conn "PRAGMA foreign_key_check" :: IO [ (String, Int, String, Int)]
-- unless (null violations) $ print violations
-- pure r
lowPriority = wait >> withMVar dbConnection (\db -> ifM free (Just <$> action db) (pure Nothing)) >>= maybe lowPriority pure
signal = atomically $ modifyTVar' dbSem (+ 1)
release = atomically $ modifyTVar' dbSem $ \sem -> if sem > 0 then sem - 1 else 0
@@ -5,7 +5,6 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251020_service_certs w
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
-- TODO move date forward, create migration for postgres
m20251020_service_certs :: Query
m20251020_service_certs =
[sql|
@@ -13,27 +12,81 @@ CREATE TABLE client_services(
user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
host TEXT NOT NULL,
port TEXT NOT NULL,
server_key_hash BLOB,
service_cert BLOB NOT NULL,
service_cert_hash BLOB NOT NULL,
service_priv_key BLOB NOT NULL,
rcv_service_id BLOB,
service_id BLOB,
service_queue_count INTEGER NOT NULL DEFAULT 0,
service_queue_ids_hash BLOB NOT NULL DEFAULT x'00000000000000000000000000000000',
FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT
);
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port);
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port, server_key_hash);
CREATE INDEX idx_server_certs_host_port ON client_services(host, port);
ALTER TABLE rcv_queues ADD COLUMN rcv_service_assoc INTEGER NOT NULL DEFAULT 0;
CREATE TRIGGER tr_rcv_queue_insert
AFTER INSERT ON rcv_queues
FOR EACH ROW
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count + 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
AND host = NEW.host AND port = NEW.port;
END;
CREATE TRIGGER tr_rcv_queue_delete
AFTER DELETE ON rcv_queues
FOR EACH ROW
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count - 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
AND host = OLD.host AND port = OLD.port;
END;
CREATE TRIGGER tr_rcv_queue_update_remove
AFTER UPDATE ON rcv_queues
FOR EACH ROW
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 AND NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0)
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count - 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
AND host = OLD.host AND port = OLD.port;
END;
CREATE TRIGGER tr_rcv_queue_update_add
AFTER UPDATE ON rcv_queues
FOR EACH ROW
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 AND NOT (OLD.rcv_service_assoc != 0 AND OLD.deleted = 0)
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count + 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
AND host = NEW.host AND port = NEW.port;
END;
|]
down_m20251020_service_certs :: Query
down_m20251020_service_certs =
[sql|
DROP TRIGGER tr_rcv_queue_insert;
DROP TRIGGER tr_rcv_queue_delete;
DROP TRIGGER tr_rcv_queue_update_remove;
DROP TRIGGER tr_rcv_queue_update_add;
ALTER TABLE rcv_queues DROP COLUMN rcv_service_assoc;
DROP INDEX idx_server_certs_host_port;
DROP INDEX idx_server_certs_user_id_host_port;
DROP TABLE client_services;
@@ -455,10 +455,13 @@ CREATE TABLE client_services(
user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
host TEXT NOT NULL,
port TEXT NOT NULL,
server_key_hash BLOB,
service_cert BLOB NOT NULL,
service_cert_hash BLOB NOT NULL,
service_priv_key BLOB NOT NULL,
rcv_service_id BLOB,
service_id BLOB,
service_queue_count INTEGER NOT NULL DEFAULT 0,
service_queue_ids_hash BLOB NOT NULL DEFAULT x'00000000000000000000000000000000',
FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT
);
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id);
@@ -607,6 +610,51 @@ CREATE INDEX idx_rcv_queues_client_notice_id ON rcv_queues(client_notice_id);
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(
user_id,
host,
port
port,
server_key_hash
);
CREATE INDEX idx_server_certs_host_port ON client_services(host, port);
CREATE TRIGGER tr_rcv_queue_insert
AFTER INSERT ON rcv_queues
FOR EACH ROW
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count + 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
AND host = NEW.host AND port = NEW.port;
END;
CREATE TRIGGER tr_rcv_queue_delete
AFTER DELETE ON rcv_queues
FOR EACH ROW
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count - 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
AND host = OLD.host AND port = OLD.port;
END;
CREATE TRIGGER tr_rcv_queue_update_remove
AFTER UPDATE ON rcv_queues
FOR EACH ROW
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 AND NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0)
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count - 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
AND host = OLD.host AND port = OLD.port;
END;
CREATE TRIGGER tr_rcv_queue_update_add
AFTER UPDATE ON rcv_queues
FOR EACH ROW
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 AND NOT (OLD.rcv_service_assoc != 0 AND OLD.deleted = 0)
BEGIN
UPDATE client_services
SET service_queue_count = service_queue_count + 1,
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
AND host = NEW.host AND port = NEW.port;
END;
@@ -0,0 +1,41 @@
module Simplex.Messaging.Agent.Store.SQLite.Util where
import Control.Exception (SomeException, catch, mask_)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..))
import Database.SQLite3.Bindings
import Foreign.C.String
import Foreign.Ptr
import Foreign.StablePtr
data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal)
type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO ()
mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc
mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals)
{-# INLINE mkSQLiteFunc #-}
-- Based on createFunction from Database.SQLite3.Direct, but uses static function pointer to avoid dynamic wrapper that triggers DCL.
createStaticFunction :: Database -> ByteString -> CArgCount -> Bool -> FunPtr SQLiteFunc -> IO (Either Error ())
createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do
u <- newStablePtr $ CFuncPtrs funPtr nullFunPtr nullFunPtr
let flags = if isDet then c_SQLITE_DETERMINISTIC else 0
B.useAsCString name $ \namePtr ->
toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr
-- Convert a 'CError' to a 'Either Error', in the common case where
-- SQLITE_OK signals success and anything else signals an error.
--
-- Note that SQLITE_OK == 0.
toResult :: a -> CError -> Either Error a
toResult a (CError 0) = Right a
toResult _ code = Left $ decodeError code
-- call c_sqlite3_result_error in the event of an error
catchAsResultError :: Ptr CContext -> IO () -> IO ()
catchAsResultError ctx action = catch action $ \exn -> do
let msg = show (exn :: SomeException)
withCAStringLen msg $ \(ptr, len) ->
c_sqlite3_result_error ctx ptr (fromIntegral len)
+65 -19
View File
@@ -2,6 +2,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
module Simplex.Messaging.Agent.TSessionSubs
( TSessionSubs (sessionSubs),
@@ -12,7 +13,10 @@ module Simplex.Messaging.Agent.TSessionSubs
hasPendingSub,
addPendingSub,
setSessionId,
setPendingServiceSub,
setActiveServiceSub,
addActiveSub,
addActiveSub',
batchAddActiveSubs,
batchAddPendingSubs,
deletePendingSub,
@@ -38,13 +42,13 @@ import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import qualified Data.Set as S
import Simplex.Messaging.Agent.Protocol (SMPQueue (..))
import Simplex.Messaging.Agent.Store (RcvQueueSub (..), SomeRcvQueue)
import Simplex.Messaging.Agent.Store (RcvQueue, RcvQueueSub (..), SomeRcvQueue, StoredRcvQueue (rcvServiceAssoc), rcvQueueSub)
import Simplex.Messaging.Client (SMPTransportSession, TransportSessionMode (..))
import Simplex.Messaging.Protocol (RecipientId)
import Simplex.Messaging.Protocol (RecipientId, ServiceSub (..), queueIdHash)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Util (($>>=))
import Simplex.Messaging.Util (anyM, ($>>=))
data TSessionSubs = TSessionSubs
{ sessionSubs :: TMap SMPTransportSession SessSubs
@@ -53,7 +57,9 @@ data TSessionSubs = TSessionSubs
data SessSubs = SessSubs
{ subsSessId :: TVar (Maybe SessionId),
activeSubs :: TMap RecipientId RcvQueueSub,
pendingSubs :: TMap RecipientId RcvQueueSub
pendingSubs :: TMap RecipientId RcvQueueSub,
activeServiceSub :: TVar (Maybe ServiceSub),
pendingServiceSub :: TVar (Maybe ServiceSub)
}
emptyIO :: IO TSessionSubs
@@ -72,7 +78,7 @@ getSessSubs :: SMPTransportSession -> TSessionSubs -> STM SessSubs
getSessSubs tSess ss = lookupSubs tSess ss >>= maybe new pure
where
new = do
s <- SessSubs <$> newTVar Nothing <*> newTVar M.empty <*> newTVar M.empty
s <- SessSubs <$> newTVar Nothing <*> newTVar M.empty <*> newTVar M.empty <*> newTVar Nothing <*> newTVar Nothing
TM.insert tSess s $ sessionSubs ss
pure s
@@ -98,8 +104,27 @@ setSessionId tSess sessId ss = do
Nothing -> writeTVar (subsSessId s) (Just sessId)
Just sessId' -> unless (sessId == sessId') $ void $ setSubsPending_ s $ Just sessId
addActiveSub :: SMPTransportSession -> SessionId -> RcvQueueSub -> TSessionSubs -> STM ()
addActiveSub tSess sessId rq ss = do
setPendingServiceSub :: SMPTransportSession -> ServiceSub -> TSessionSubs -> STM ()
setPendingServiceSub tSess serviceSub ss = do
s <- getSessSubs tSess ss
writeTVar (pendingServiceSub s) $ Just serviceSub
setActiveServiceSub :: SMPTransportSession -> SessionId -> ServiceSub -> TSessionSubs -> STM ()
setActiveServiceSub tSess sessId serviceSub ss = do
s <- getSessSubs tSess ss
sessId' <- readTVar $ subsSessId s
if Just sessId == sessId'
then do
writeTVar (activeServiceSub s) $ Just serviceSub
writeTVar (pendingServiceSub s) Nothing
else writeTVar (pendingServiceSub s) $ Just serviceSub
addActiveSub :: SMPTransportSession -> SessionId -> RcvQueue -> TSessionSubs -> STM ()
addActiveSub tSess sessId rq = addActiveSub' tSess sessId (rcvQueueSub rq) (rcvServiceAssoc rq)
{-# INLINE addActiveSub #-}
addActiveSub' :: SMPTransportSession -> SessionId -> RcvQueueSub -> Bool -> TSessionSubs -> STM ()
addActiveSub' tSess sessId rq serviceAssoc ss = do
s <- getSessSubs tSess ss
sessId' <- readTVar $ subsSessId s
let rId = rcvId rq
@@ -107,10 +132,13 @@ addActiveSub tSess sessId rq ss = do
then do
TM.insert rId rq $ activeSubs s
TM.delete rId $ pendingSubs s
when serviceAssoc $
let updateServiceSub (ServiceSub serviceId n idsHash) = ServiceSub serviceId (n + 1) (idsHash <> queueIdHash rId)
in modifyTVar' (activeServiceSub s) (updateServiceSub <$>)
else TM.insert rId rq $ pendingSubs s
batchAddActiveSubs :: SMPTransportSession -> SessionId -> [RcvQueueSub] -> TSessionSubs -> STM ()
batchAddActiveSubs tSess sessId rqs ss = do
batchAddActiveSubs :: SMPTransportSession -> SessionId -> ([RcvQueueSub], [RcvQueueSub]) -> TSessionSubs -> STM ()
batchAddActiveSubs tSess sessId (rqs, serviceRQs) ss = do
s <- getSessSubs tSess ss
sessId' <- readTVar $ subsSessId s
let qs = M.fromList $ map (\rq -> (rcvId rq, rq)) rqs
@@ -118,6 +146,12 @@ batchAddActiveSubs tSess sessId rqs ss = do
then do
TM.union qs $ activeSubs s
modifyTVar' (pendingSubs s) (`M.difference` qs)
serviceSub_ <- readTVar $ activeServiceSub s
forM_ serviceSub_ $ \(ServiceSub serviceId n idsHash) -> do
unless (null serviceRQs) $ do
let idsHash' = idsHash <> mconcat (map (queueIdHash . rcvId) serviceRQs)
n' = n + fromIntegral (length serviceRQs)
writeTVar (activeServiceSub s) $ Just $ ServiceSub serviceId n' idsHash'
else TM.union qs $ pendingSubs s
batchAddPendingSubs :: SMPTransportSession -> [RcvQueueSub] -> TSessionSubs -> STM ()
@@ -143,11 +177,15 @@ batchDeleteSubs tSess rqs = lookupSubs tSess >=> mapM_ (\s -> delete (activeSubs
delete = (`modifyTVar'` (`M.withoutKeys` rIds))
hasPendingSubs :: SMPTransportSession -> TSessionSubs -> STM Bool
hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (fmap (not . null) . readTVar . pendingSubs)
hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (\s -> anyM [hasSubs s, hasServiceSub s])
where
hasSubs = fmap (not . null) . readTVar . pendingSubs
hasServiceSub = fmap isJust . readTVar . pendingServiceSub
getPendingSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
getPendingSubs = getSubs_ pendingSubs
{-# INLINE getPendingSubs #-}
getPendingSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub)
getPendingSubs tSess = lookupSubs tSess >=> maybe (pure (M.empty, Nothing)) get
where
get s = liftM2 (,) (readTVar $ pendingSubs s) (readTVar $ pendingServiceSub s)
getActiveSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
getActiveSubs = getSubs_ activeSubs
@@ -156,7 +194,7 @@ getActiveSubs = getSubs_ activeSubs
getSubs_ :: (SessSubs -> TMap RecipientId RcvQueueSub) -> SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
getSubs_ subs tSess = lookupSubs tSess >=> maybe (pure M.empty) (readTVar . subs)
setSubsPending :: TransportSessionMode -> SMPTransportSession -> SessionId -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
setSubsPending :: TransportSessionMode -> SMPTransportSession -> SessionId -> TSessionSubs -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub)
setSubsPending mode tSess@(uId, srv, connId_) sessId tss@(TSessionSubs ss)
| entitySession == isJust connId_ =
TM.lookup tSess ss >>= withSessSubs (`setSubsPending_` Nothing)
@@ -166,17 +204,17 @@ setSubsPending mode tSess@(uId, srv, connId_) sessId tss@(TSessionSubs ss)
entitySession = mode == TSMEntity
sessEntId = if entitySession then Just else const Nothing
withSessSubs run = \case
Nothing -> pure M.empty
Nothing -> pure (M.empty, Nothing)
Just s -> do
sessId' <- readTVar $ subsSessId s
if Just sessId == sessId' then run s else pure M.empty
if Just sessId == sessId' then run s else pure (M.empty, Nothing)
setPendingChangeMode s = do
subs <- M.union <$> readTVar (activeSubs s) <*> readTVar (pendingSubs s)
unless (null subs) $
forM_ subs $ \rq -> addPendingSub (uId, srv, sessEntId (connId rq)) rq tss
pure subs
(subs,) <$> setServiceSubPending_ s
setSubsPending_ :: SessSubs -> Maybe SessionId -> STM (Map RecipientId RcvQueueSub)
setSubsPending_ :: SessSubs -> Maybe SessionId -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub)
setSubsPending_ s sessId_ = do
writeTVar (subsSessId s) sessId_
let as = activeSubs s
@@ -184,7 +222,15 @@ setSubsPending_ s sessId_ = do
unless (null subs) $ do
writeTVar as M.empty
modifyTVar' (pendingSubs s) $ M.union subs
pure subs
(subs,) <$> setServiceSubPending_ s
setServiceSubPending_ :: SessSubs -> STM (Maybe ServiceSub)
setServiceSubPending_ s = do
serviceSub_ <- readTVar $ activeServiceSub s
forM_ serviceSub_ $ \serviceSub -> do
writeTVar (activeServiceSub s) Nothing
writeTVar (pendingServiceSub s) $ Just serviceSub
pure serviceSub_
updateClientNotices :: SMPTransportSession -> [(RecipientId, Maybe Int64)] -> TSessionSubs -> STM ()
updateClientNotices tSess noticeIds ss = do
+5 -5
View File
@@ -909,18 +909,18 @@ nsubResponse_ = \case
{-# INLINE nsubResponse_ #-}
-- This command is always sent in background request mode
subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO (Int64, IdsHash)
subscribeService c party = case smpClientService c of
subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub
subscribeService c party n idsHash = case smpClientService c of
Just THClientService {serviceId, serviceKey} -> do
liftIO $ enablePings c
sendSMPCommand c NRMBackground (Just (C.APrivateAuthKey C.SEd25519 serviceKey)) serviceId subCmd >>= \case
SOKS n idsHash -> pure (n, idsHash)
SOKS n' idsHash' -> pure $ ServiceSub serviceId n' idsHash'
r -> throwE $ unexpectedResponse r
where
subCmd :: Command p
subCmd = case party of
SRecipientService -> SUBS
SNotifierService -> NSUBS
SRecipientService -> SUBS n idsHash
SNotifierService -> NSUBS n idsHash
Nothing -> throwE PCEServiceUnavailable
smpClientService :: SMPClient -> Maybe THClientService
+29 -28
View File
@@ -45,7 +45,6 @@ import Crypto.Random (ChaChaDRG)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Constraint (Dict (..))
import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
@@ -69,10 +68,12 @@ import Simplex.Messaging.Protocol
ProtocolServer (..),
QueueId,
SMPServer,
ServiceSub (..),
SParty (..),
ServiceParty,
serviceParty,
partyServiceRole
partyServiceRole,
queueIdsHash,
)
import Simplex.Messaging.Session
import Simplex.Messaging.TMap (TMap)
@@ -91,14 +92,14 @@ data SMPClientAgentEvent
| CADisconnected SMPServer (NonEmpty QueueId)
| CASubscribed SMPServer (Maybe ServiceId) (NonEmpty QueueId)
| CASubError SMPServer (NonEmpty (QueueId, SMPClientError))
| CAServiceDisconnected SMPServer (ServiceId, Int64)
| CAServiceSubscribed SMPServer (ServiceId, Int64) Int64
| CAServiceSubError SMPServer (ServiceId, Int64) SMPClientError
| CAServiceDisconnected SMPServer ServiceSub
| CAServiceSubscribed {subServer :: SMPServer, expected :: ServiceSub, subscribed :: ServiceSub}
| CAServiceSubError SMPServer ServiceSub SMPClientError
-- CAServiceUnavailable is used when service ID in pending subscription is different from the current service in connection.
-- This will require resubscribing to all queues associated with this service ID individually, creating new associations.
-- It may happen if, for example, SMP server deletes service information (e.g. via downgrade and upgrade)
-- and assigns different service ID to the service certificate.
| CAServiceUnavailable SMPServer (ServiceId, Int64)
| CAServiceUnavailable SMPServer ServiceSub
data SMPClientAgentConfig = SMPClientAgentConfig
{ smpCfg :: ProtocolClientConfig SMPVersion,
@@ -142,11 +143,11 @@ data SMPClientAgent p = SMPClientAgent
-- Only one service subscription can exist per server with this agent.
-- With correctly functioning SMP server, queue and service subscriptions can't be
-- active at the same time.
activeServiceSubs :: TMap SMPServer (TVar (Maybe ((ServiceId, Int64), SessionId))),
activeServiceSubs :: TMap SMPServer (TVar (Maybe (ServiceSub, SessionId))),
activeQueueSubs :: TMap SMPServer (TMap QueueId (SessionId, C.APrivateAuthKey)),
-- Pending service subscriptions can co-exist with pending queue subscriptions
-- on the same SMP server during subscriptions being transitioned from per-queue to service.
pendingServiceSubs :: TMap SMPServer (TVar (Maybe (ServiceId, Int64))),
pendingServiceSubs :: TMap SMPServer (TVar (Maybe ServiceSub)),
pendingQueueSubs :: TMap SMPServer (TMap QueueId C.APrivateAuthKey),
smpSubWorkers :: TMap SMPServer (SessionVar (Async ())),
workerSeq :: TVar Int
@@ -256,7 +257,7 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random
removeClientAndSubs smp >>= serverDown
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeClientAndSubs :: SMPClient -> IO (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey))
removeClientAndSubs :: SMPClient -> IO (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey))
removeClientAndSubs smp = do
-- Looking up subscription vars outside of STM transaction to reduce re-evaluation.
-- It is possible because these vars are never removed, they are only added.
@@ -287,7 +288,7 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random
then pure Nothing
else Just subs <$ addSubs_ (pendingQueueSubs ca) srv subs
serverDown :: (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> IO ()
serverDown :: (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) -> IO ()
serverDown (sSub, qSubs) = do
mapM_ (notify ca . CAServiceDisconnected srv) sSub
let qIds = L.nonEmpty . M.keys =<< qSubs
@@ -317,7 +318,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
loop
ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
noPending (sSub, qSubs) = isNothing sSub && maybe True M.null qSubs
getPending :: Monad m => (forall a. SMPServer -> TMap SMPServer a -> m (Maybe a)) -> (forall a. TVar a -> m a) -> m (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey))
getPending :: Monad m => (forall a. SMPServer -> TMap SMPServer a -> m (Maybe a)) -> (forall a. TVar a -> m a) -> m (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey))
getPending lkup rd = do
sSub <- lkup srv (pendingServiceSubs ca) $>>= rd
qSubs <- lkup srv (pendingQueueSubs ca) >>= mapM rd
@@ -329,7 +330,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
whenM (isEmptyTMVar $ sessionVar v) retry
removeSessVar v srv smpSubWorkers
reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO ()
reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO ()
reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv (sSub_, qSubs_) =
withSMP ca srv $ \smp -> liftIO $ case serviceParty agentParty of
Just Dict -> resubscribe smp
@@ -430,7 +431,7 @@ smpSubscribeQueues ca smp srv subs = do
let acc@(_, _, (qOks, sQs), notPending) = foldr (groupSub pending) (False, [], ([], []), []) (L.zip subs rs)
unless (null qOks) $ addActiveSubs ca srv qOks
unless (null sQs) $ forM_ smpServiceId $ \serviceId ->
updateActiveServiceSub ca srv ((serviceId, fromIntegral $ length sQs), sessId)
updateActiveServiceSub ca srv (ServiceSub serviceId (fromIntegral $ length sQs) (queueIdsHash sQs), sessId)
unless (null notPending) $ removePendingSubs ca srv notPending
pure acc
sessId = sessionId $ thParams smp
@@ -454,24 +455,24 @@ smpSubscribeQueues ca smp srv subs = do
notify_ :: (SMPServer -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO ()
notify_ evt qs = mapM_ (notify ca . evt srv) $ L.nonEmpty qs
subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> (ServiceId, Int64) -> IO ()
subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> ServiceSub -> IO ()
subscribeServiceNtfs = subscribeService_
{-# INLINE subscribeServiceNtfs #-}
subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO ()
subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> ServiceSub -> IO ()
subscribeService_ ca srv serviceSub = do
atomically $ setPendingServiceSub ca srv $ Just serviceSub
runExceptT (getSMPServerClient' ca srv) >>= \case
Right smp -> smpSubscribeService ca smp srv serviceSub
Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that
smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO ()
smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService smp of
smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> ServiceSub -> IO ()
smpSubscribeService ca smp srv serviceSub@(ServiceSub serviceId n idsHash) = case smpClientService smp of
Just service | serviceAvailable service -> subscribe
_ -> notifyUnavailable
where
subscribe = do
r <- runExceptT $ subscribeService smp $ agentParty ca
r <- runExceptT $ subscribeService smp (agentParty ca) n idsHash
ok <-
atomically $
ifM
@@ -479,15 +480,15 @@ smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService
(True <$ processSubscription r)
(pure False)
if ok
then case r of -- TODO [certs rcv] compare hash
Right (n, _idsHash) -> notify ca $ CAServiceSubscribed srv serviceSub n
then case r of
Right serviceSub' -> notify ca $ CAServiceSubscribed srv serviceSub serviceSub'
Left e
| smpClientServiceError e -> notifyUnavailable
| temporaryClientError e -> reconnectClient ca srv
| otherwise -> notify ca $ CAServiceSubError srv serviceSub e
else reconnectClient ca srv
processSubscription = mapM_ $ \(n, _idsHash) -> do -- TODO [certs rcv] validate hash here?
setActiveServiceSub ca srv $ Just ((serviceId, n), sessId)
processSubscription = mapM_ $ \serviceSub' -> do -- TODO [certs rcv] validate hash here?
setActiveServiceSub ca srv $ Just (serviceSub', sessId)
setPendingServiceSub ca srv Nothing
serviceAvailable THClientService {serviceRole, serviceId = serviceId'} =
serviceId == serviceId' && partyServiceRole (agentParty ca) == serviceRole
@@ -529,11 +530,11 @@ addSubs_ subs srv ss =
Just m -> TM.union ss m
_ -> TM.insertM srv (newTVar ss) subs
setActiveServiceSub :: SMPClientAgent p -> SMPServer -> Maybe ((ServiceId, Int64), SessionId) -> STM ()
setActiveServiceSub :: SMPClientAgent p -> SMPServer -> Maybe (ServiceSub, SessionId) -> STM ()
setActiveServiceSub = setServiceSub_ activeServiceSubs
{-# INLINE setActiveServiceSub #-}
setPendingServiceSub :: SMPClientAgent p -> SMPServer -> Maybe (ServiceId, Int64) -> STM ()
setPendingServiceSub :: SMPClientAgent p -> SMPServer -> Maybe ServiceSub -> STM ()
setPendingServiceSub = setServiceSub_ pendingServiceSubs
{-# INLINE setPendingServiceSub #-}
@@ -548,12 +549,12 @@ setServiceSub_ subsSel ca srv sub =
Just v -> writeTVar v sub
Nothing -> TM.insertM srv (newTVar sub) (subsSel ca)
updateActiveServiceSub :: SMPClientAgent p -> SMPServer -> ((ServiceId, Int64), SessionId) -> STM ()
updateActiveServiceSub ca srv sub@((serviceId', n'), sessId') =
updateActiveServiceSub :: SMPClientAgent p -> SMPServer -> (ServiceSub, SessionId) -> STM ()
updateActiveServiceSub ca srv sub@(ServiceSub serviceId' n' idsHash', sessId') =
TM.lookup srv (activeServiceSubs ca) >>= \case
Just v -> modifyTVar' v $ \case
Just ((serviceId, n), sessId) | serviceId == serviceId' && sessId == sessId' ->
Just ((serviceId, n + n'), sessId)
Just (ServiceSub serviceId n idsHash, sessId) | serviceId == serviceId' && sessId == sessId' ->
Just (ServiceSub serviceId (n + n') (idsHash <> idsHash'), sessId)
_ -> Just sub
Nothing -> TM.insertM srv (newTVar $ Just sub) (activeServiceSubs ca)
+5 -1
View File
@@ -178,6 +178,7 @@ module Simplex.Messaging.Crypto
sha512Hash,
sha3_256,
sha3_384,
md5Hash,
-- * Message padding / un-padding
canPad,
@@ -216,7 +217,7 @@ import Crypto.Cipher.AES (AES256)
import qualified Crypto.Cipher.Types as AES
import qualified Crypto.Cipher.XSalsa as XSalsa
import qualified Crypto.Error as CE
import Crypto.Hash (Digest, SHA3_256, SHA3_384, SHA256 (..), SHA512 (..), hash, hashDigestSize)
import Crypto.Hash (Digest, MD5, SHA3_256, SHA3_384, SHA256 (..), SHA512 (..), hash, hashDigestSize)
import qualified Crypto.KDF.HKDF as H
import qualified Crypto.MAC.Poly1305 as Poly1305
import qualified Crypto.PubKey.Curve25519 as X25519
@@ -1024,6 +1025,9 @@ sha3_384 :: ByteString -> ByteString
sha3_384 = BA.convert . (hash :: ByteString -> Digest SHA3_384)
{-# INLINE sha3_384 #-}
md5Hash :: ByteString -> ByteString
md5Hash = BA.convert . (hash :: ByteString -> Digest MD5)
-- | AEAD-GCM encryption with associated data.
--
-- Used as part of double ratchet encryption.
@@ -489,17 +489,9 @@ data NtfSubStatus
NSErr ByteString
deriving (Eq, Ord, Show)
ntfShouldSubscribe :: NtfSubStatus -> Bool
ntfShouldSubscribe = \case
NSNew -> True
NSPending -> True
NSActive -> True
NSInactive -> True
NSEnd -> False
NSDeleted -> False
NSAuth -> False
NSService -> True
NSErr _ -> False
-- if these statuses change, the queue ID hashes for services need to be updated in a new migration (see m20250830_queue_ids_hash)
subscribeNtfStatuses :: [NtfSubStatus]
subscribeNtfStatuses = [NSNew, NSPending, NSActive, NSInactive]
instance Encoding NtfSubStatus where
smpEncode = \case
+10 -9
View File
@@ -62,7 +62,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessag
import Simplex.Messaging.Notifications.Server.Store.Postgres
import Simplex.Messaging.Notifications.Server.Store.Types
import Simplex.Messaging.Notifications.Transport
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut)
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceSub (..), SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Server
import Simplex.Messaging.Server.Control (CPClientRole (..))
@@ -257,9 +257,9 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
srvSubscribers <- getSMPWorkerMetrics a smpSubscribers
srvClients <- getSMPWorkerMetrics a smpClients
srvSubWorkers <- getSMPWorkerMetrics a smpSubWorkers
ntfActiveServiceSubs <- getSMPServiceSubMetrics a activeServiceSubs $ snd . fst
ntfActiveServiceSubs <- getSMPServiceSubMetrics a activeServiceSubs $ smpQueueCount . fst
ntfActiveQueueSubs <- getSMPSubMetrics a activeQueueSubs
ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs snd
ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs smpQueueCount
ntfPendingQueueSubs <- getSMPSubMetrics a pendingQueueSubs
smpSessionCount <- M.size <$> readTVarIO smpSessions
apnsPushQLength <- atomically $ lengthTBQueue pushQ
@@ -452,13 +452,13 @@ resubscribe NtfSubscriber {smpAgent = ca} = do
counts <- mapConcurrently (subscribeSrvSubs ca st batchSize) srvs
logNote $ "Completed all SMP resubscriptions for " <> tshow (length srvs) <> " servers (" <> tshow (sum counts) <> " subscriptions)"
subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int
subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe ServiceSub) -> IO Int
subscribeSrvSubs ca st batchSize (srv, srvId, service_) = do
let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv)
logNote $ "Starting SMP resubscriptions for " <> srvStr
forM_ service_ $ \(serviceId, n) -> do
logNote $ "Subscribing service to " <> srvStr <> " with " <> tshow n <> " associated queues"
subscribeServiceNtfs ca srv (serviceId, n)
forM_ service_ $ \serviceSub -> do
logNote $ "Subscribing service to " <> srvStr <> " with " <> tshow (smpQueueCount serviceSub) <> " associated queues"
subscribeServiceNtfs ca srv serviceSub
n <- subscribeLoop 0 Nothing
logNote $ "Completed SMP resubscriptions for " <> srvStr <> " (" <> tshow n <> " subscriptions)"
pure n
@@ -576,7 +576,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} =
-- TODO [certs] resubscribe queues with statuses NSErr and NSService
CAServiceDisconnected srv serviceSub ->
logNote $ "SMP server service disconnected " <> showService srv serviceSub
CAServiceSubscribed srv serviceSub@(_, expected) n
CAServiceSubscribed srv serviceSub@(ServiceSub _ expected _) (ServiceSub _ n _) -- TODO [certs rcv] compare hash
| expected == n -> logNote msg
| otherwise -> logWarn $ msg <> ", confirmed subs: " <> tshow n
where
@@ -593,7 +593,8 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} =
void $ subscribeSrvSubs ca st batchSize (srv, srvId, Nothing)
Left e -> logError $ "SMP server update and resubscription error " <> tshow e
where
showService srv (serviceId, n) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs"
-- TODO [certs rcv] compare hash
showService srv (ServiceSub serviceId n _idsHash) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs"
logSubErrors :: SMPServer -> NonEmpty (SMP.NotifierId, NtfSubStatus) -> Int -> IO ()
logSubErrors srv subs updated = forM_ (L.group $ L.sort $ L.map snd subs) $ \ss -> do
@@ -17,6 +17,7 @@ import Simplex.Messaging.Server.Stats
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
-- TODO [certs rcv] track service subscriptions and count/hash diffs for own and other servers + prometheus
data NtfServerStats = NtfServerStats
{ fromTime :: IORef UTCTime,
tknCreated :: IORef Int,
@@ -6,13 +6,15 @@ module Simplex.Messaging.Notifications.Server.Store.Migrations where
import Data.List (sortOn)
import Data.Text (Text)
import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util
import Simplex.Messaging.Agent.Store.Shared
import Text.RawString.QQ (r)
ntfServerSchemaMigrations :: [(String, Text, Maybe Text)]
ntfServerSchemaMigrations =
[ ("20250417_initial", m20250417_initial, Nothing),
("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert)
("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert),
("20250830_queue_ids_hash", m20250830_queue_ids_hash, Just down_m20250830_queue_ids_hash)
]
-- | The list of migrations in ascending order by date
@@ -101,3 +103,125 @@ ALTER TABLE smp_servers DROP COLUMN ntf_service_id;
ALTER TABLE subscriptions DROP COLUMN ntf_service_assoc;
|]
m20250830_queue_ids_hash :: Text
m20250830_queue_ids_hash =
createXorHashFuncs
<> [r|
ALTER TABLE smp_servers
ADD COLUMN smp_notifier_count BIGINT NOT NULL DEFAULT 0,
ADD COLUMN smp_notifier_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000';
CREATE FUNCTION should_subscribe_status(p_status TEXT) RETURNS BOOLEAN
LANGUAGE plpgsql IMMUTABLE STRICT
AS $$
BEGIN
RETURN p_status IN ('NEW', 'PENDING', 'ACTIVE', 'INACTIVE');
END;
$$;
CREATE FUNCTION update_all_aggregates() RETURNS VOID
LANGUAGE plpgsql
AS $$
BEGIN
WITH acc AS (
SELECT
s.smp_server_id,
count(smp_notifier_id) as notifier_count,
xor_aggregate(public.digest(s.smp_notifier_id, 'md5')) AS notifier_hash
FROM subscriptions s
WHERE s.ntf_service_assoc = true AND should_subscribe_status(s.status)
GROUP BY s.smp_server_id
)
UPDATE smp_servers srv
SET smp_notifier_count = COALESCE(acc.notifier_count, 0),
smp_notifier_ids_hash = COALESCE(acc.notifier_hash, '\x00000000000000000000000000000000')
FROM acc
WHERE srv.smp_server_id = acc.smp_server_id;
END;
$$;
SELECT update_all_aggregates();
CREATE FUNCTION update_aggregates(p_server_id BIGINT, p_change BIGINT, p_notifier_id BYTEA) RETURNS VOID
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE smp_servers
SET smp_notifier_count = smp_notifier_count + p_change,
smp_notifier_ids_hash = xor_combine(smp_notifier_ids_hash, public.digest(p_notifier_id, 'md5'))
WHERE smp_server_id = p_server_id;
END;
$$;
CREATE FUNCTION on_subscription_insert() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
END IF;
RETURN NEW;
END;
$$;
CREATE FUNCTION on_subscription_delete() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
END IF;
RETURN OLD;
END;
$$;
CREATE FUNCTION on_subscription_update() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
IF NOT (NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status)) THEN
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
END IF;
ELSIF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
END IF;
RETURN NEW;
END;
$$;
CREATE TRIGGER tr_subscriptions_insert
AFTER INSERT ON subscriptions
FOR EACH ROW EXECUTE PROCEDURE on_subscription_insert();
CREATE TRIGGER tr_subscriptions_delete
AFTER DELETE ON subscriptions
FOR EACH ROW EXECUTE PROCEDURE on_subscription_delete();
CREATE TRIGGER tr_subscriptions_update
AFTER UPDATE ON subscriptions
FOR EACH ROW EXECUTE PROCEDURE on_subscription_update();
|]
down_m20250830_queue_ids_hash :: Text
down_m20250830_queue_ids_hash =
[r|
DROP TRIGGER tr_subscriptions_insert ON subscriptions;
DROP TRIGGER tr_subscriptions_delete ON subscriptions;
DROP TRIGGER tr_subscriptions_update ON subscriptions;
DROP FUNCTION on_subscription_insert;
DROP FUNCTION on_subscription_delete;
DROP FUNCTION on_subscription_update;
DROP FUNCTION update_aggregates;
DROP FUNCTION update_all_aggregates;
DROP FUNCTION should_subscribe_status;
ALTER TABLE smp_servers
DROP COLUMN smp_notifier_count,
DROP COLUMN smp_notifier_ids_hash;
|]
<> dropXorHashFuncs
@@ -64,7 +64,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubDat
import Simplex.Messaging.Notifications.Server.Store.Migrations
import Simplex.Messaging.Notifications.Server.Store.Types
import Simplex.Messaging.Notifications.Server.StoreLog
import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer)
import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), IdsHash (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, ServiceSub (..), pattern SMPServer)
import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_)
import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..))
import Simplex.Messaging.Server.StoreLog (openWriteStoreLog)
@@ -239,7 +239,7 @@ updateTknCronInterval st tknId cronInt =
-- Reads servers that have subscriptions that need subscribing.
-- It is executed on server start, and it is supposed to crash on database error
getUsedSMPServers :: NtfPostgresStore -> IO [(SMPServer, Int64, Maybe (ServiceId, Int64))]
getUsedSMPServers :: NtfPostgresStore -> IO [(SMPServer, Int64, Maybe ServiceSub)]
getUsedSMPServers st =
withTransaction (dbStore st) $ \db ->
map rowToSrvSubs <$>
@@ -247,25 +247,17 @@ getUsedSMPServers st =
db
[sql|
SELECT
p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id,
SUM(CASE WHEN s.ntf_service_assoc THEN s.subs_count ELSE 0 END) :: BIGINT as service_subs_count
FROM smp_servers p
JOIN (
SELECT
smp_server_id,
ntf_service_assoc,
COUNT(1) as subs_count
FROM subscriptions
WHERE status IN ?
GROUP BY smp_server_id, ntf_service_assoc
) s ON s.smp_server_id = p.smp_server_id
GROUP BY p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id
smp_host, smp_port, smp_keyhash, smp_server_id,
ntf_service_id, smp_notifier_count, smp_notifier_ids_hash
FROM smp_servers
WHERE EXISTS (SELECT 1 FROM subscriptions WHERE status IN ?)
|]
(Only (In [NSNew, NSPending, NSActive, NSInactive]))
(Only (In subscribeNtfStatuses))
where
rowToSrvSubs :: SMPServerRow :. (Int64, Maybe ServiceId, Int64) -> (SMPServer, Int64, Maybe (ServiceId, Int64))
rowToSrvSubs ((host, port, kh) :. (srvId, serviceId_, subsCount)) =
(SMPServer host port kh, srvId, (,subsCount) <$> serviceId_)
rowToSrvSubs :: SMPServerRow :. (Int64, Maybe ServiceId, Int64, IdsHash) -> (SMPServer, Int64, Maybe ServiceSub)
rowToSrvSubs ((host, port, kh) :. (srvId, serviceId_, n, idsHash)) =
let service_ = (\serviceId -> ServiceSub serviceId n idsHash) <$> serviceId_
in (SMPServer host port kh, srvId, service_)
getServerNtfSubscriptions :: NtfPostgresStore -> Int64 -> Maybe NtfSubscriptionId -> Int -> IO (Either ErrorType [ServerNtfSub])
getServerNtfSubscriptions st srvId afterSubId_ count =
@@ -273,9 +265,9 @@ getServerNtfSubscriptions st srvId afterSubId_ count =
subs <-
map toServerNtfSub <$> case afterSubId_ of
Nothing ->
DB.query db (query <> orderLimit) (srvId, statusIn, count)
DB.query db (query <> orderLimit) (srvId, In subscribeNtfStatuses, count)
Just afterSubId ->
DB.query db (query <> " AND subscription_id > ?" <> orderLimit) (srvId, statusIn, afterSubId, count)
DB.query db (query <> " AND subscription_id > ?" <> orderLimit) (srvId, In subscribeNtfStatuses, afterSubId, count)
void $
DB.executeMany
db
@@ -296,7 +288,6 @@ getServerNtfSubscriptions st srvId afterSubId_ count =
WHERE smp_server_id = ? AND NOT ntf_service_assoc AND status IN ?
|]
orderLimit = " ORDER BY subscription_id LIMIT ?"
statusIn = In [NSNew, NSPending, NSActive, NSInactive]
toServerNtfSub (ntfSubId, notifierId, notifierKey) = (ntfSubId, (notifierId, notifierKey))
-- Returns token and subscription.
@@ -15,6 +15,123 @@ SET row_security = off;
CREATE SCHEMA ntf_server;
CREATE FUNCTION ntf_server.on_subscription_delete() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
END IF;
RETURN OLD;
END;
$$;
CREATE FUNCTION ntf_server.on_subscription_insert() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
END IF;
RETURN NEW;
END;
$$;
CREATE FUNCTION ntf_server.on_subscription_update() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
IF NOT (NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status)) THEN
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
END IF;
ELSIF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
END IF;
RETURN NEW;
END;
$$;
CREATE FUNCTION ntf_server.should_subscribe_status(p_status text) RETURNS boolean
LANGUAGE plpgsql IMMUTABLE STRICT
AS $$
BEGIN
RETURN p_status IN ('NEW', 'PENDING', 'ACTIVE', 'INACTIVE');
END;
$$;
CREATE FUNCTION ntf_server.update_aggregates(p_server_id bigint, p_change bigint, p_notifier_id bytea) RETURNS void
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE smp_servers
SET smp_notifier_count = smp_notifier_count + p_change,
smp_notifier_ids_hash = xor_combine(smp_notifier_ids_hash, public.digest(p_notifier_id, 'md5'))
WHERE smp_server_id = p_server_id;
END;
$$;
CREATE FUNCTION ntf_server.update_all_aggregates() RETURNS void
LANGUAGE plpgsql
AS $$
BEGIN
WITH acc AS (
SELECT
s.smp_server_id,
count(smp_notifier_id) as notifier_count,
xor_aggregate(public.digest(s.smp_notifier_id, 'md5')) AS notifier_hash
FROM subscriptions s
WHERE s.ntf_service_assoc = true AND should_subscribe_status(s.status)
GROUP BY s.smp_server_id
)
UPDATE smp_servers srv
SET smp_notifier_count = COALESCE(acc.notifier_count, 0),
smp_notifier_ids_hash = COALESCE(acc.notifier_hash, '\x00000000000000000000000000000000')
FROM acc
WHERE srv.smp_server_id = acc.smp_server_id;
END;
$$;
CREATE FUNCTION ntf_server.xor_combine(state bytea, value bytea) RETURNS bytea
LANGUAGE plpgsql IMMUTABLE STRICT
AS $$
DECLARE
result BYTEA := state;
i INTEGER;
len INTEGER := octet_length(value);
BEGIN
IF octet_length(state) != len THEN
RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len;
END IF;
FOR i IN 0..len-1 LOOP
result := set_byte(result, i, get_byte(state, i) # get_byte(value, i));
END LOOP;
RETURN result;
END;
$$;
CREATE AGGREGATE ntf_server.xor_aggregate(bytea) (
SFUNC = ntf_server.xor_combine,
STYPE = bytea,
INITCOND = '\x00000000000000000000000000000000'
);
SET default_table_access_method = heap;
@@ -53,7 +170,9 @@ CREATE TABLE ntf_server.smp_servers (
smp_host text NOT NULL,
smp_port text NOT NULL,
smp_keyhash bytea NOT NULL,
ntf_service_id bytea
ntf_service_id bytea,
smp_notifier_count bigint DEFAULT 0 NOT NULL,
smp_notifier_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL
);
@@ -158,6 +277,18 @@ CREATE INDEX idx_tokens_status_cron_interval_sent_at ON ntf_server.tokens USING
CREATE TRIGGER tr_subscriptions_delete AFTER DELETE ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_delete();
CREATE TRIGGER tr_subscriptions_insert AFTER INSERT ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_insert();
CREATE TRIGGER tr_subscriptions_update AFTER UPDATE ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_update();
ALTER TABLE ONLY ntf_server.last_notifications
ADD CONSTRAINT last_notifications_subscription_id_fkey FOREIGN KEY (subscription_id) REFERENCES ntf_server.subscriptions(subscription_id) ON UPDATE RESTRICT ON DELETE CASCADE;
+61 -11
View File
@@ -140,7 +140,10 @@ module Simplex.Messaging.Protocol
RcvMessage (..),
MsgId,
MsgBody,
IdsHash,
IdsHash (..),
ServiceSub (..),
queueIdsHash,
queueIdHash,
MaxMessageLen,
MaxRcvMessageLen,
EncRcvMsgBody (..),
@@ -223,6 +226,8 @@ import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (bimap, first)
import Data.Bits (xor)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
@@ -232,6 +237,7 @@ import Data.Constraint (Dict (..))
import Data.Functor (($>))
import Data.Int (Int64)
import Data.Kind
import Data.List (foldl')
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Maybe (isJust, isNothing)
@@ -241,7 +247,7 @@ import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
import Data.Time.Clock.System (SystemTime (..), systemToUTCTime)
import Data.Type.Equality
import Data.Word (Word16)
import Data.Word (Word8, Word16)
import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
import qualified GHC.TypeLits as TE
import qualified GHC.TypeLits as Type
@@ -548,7 +554,8 @@ data Command (p :: Party) where
NEW :: NewQueueReq -> Command Creator
SUB :: Command Recipient
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
SUBS :: Command RecipientService
-- Parameters are expected queue count and hash of all subscribed queues, it allows to monitor "state drift" on the server
SUBS :: Int64 -> IdsHash -> Command RecipientService
KEY :: SndPublicAuthKey -> Command Recipient
RKEY :: NonEmpty RcvPublicAuthKey -> Command Recipient
LSET :: LinkId -> QueueLinkData -> Command Recipient
@@ -572,7 +579,7 @@ data Command (p :: Party) where
-- SMP notification subscriber commands
NSUB :: Command Notifier
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
NSUBS :: Command NotifierService
NSUBS :: Int64 -> IdsHash -> Command NotifierService
PRXY :: SMPServer -> Maybe BasicAuth -> Command ProxiedClient -- request a relay server connection by URI
-- Transmission to proxy:
-- - entity ID: ID of the session with relay returned in PKEY (response to PRXY)
@@ -698,7 +705,7 @@ data BrokerMsg where
LNK :: SenderId -> QueueLinkData -> BrokerMsg
-- | Service subscription success - confirms when queue was associated with the service
SOK :: Maybe ServiceId -> BrokerMsg
-- | The number of queues subscribed with SUBS command
-- | The number of queues and XOR-hash of their IDs subscribed with SUBS command
SOKS :: Int64 -> IdsHash -> BrokerMsg
-- MSG v1/2 has to be supported for encoding/decoding
-- v1: MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg
@@ -1460,7 +1467,42 @@ type MsgId = ByteString
-- | SMP message body.
type MsgBody = ByteString
type IdsHash = ByteString
data ServiceSub = ServiceSub
{ serviceId :: ServiceId,
smpQueueCount :: Int64,
smpQueueIdsHash :: IdsHash
}
newtype IdsHash = IdsHash {unIdsHash :: BS.ByteString}
deriving (Eq, Show)
deriving newtype (Encoding, FromField)
instance ToField IdsHash where
toField (IdsHash s) = toField (Binary s)
{-# INLINE toField #-}
instance Semigroup IdsHash where
(IdsHash s1) <> (IdsHash s2) = IdsHash $! BS.pack $ BS.zipWith xor s1 s2
instance Monoid IdsHash where
mempty = IdsHash $ BS.replicate 16 0
mconcat ss =
let !s' = BS.pack $ foldl' (\ !r (IdsHash s) -> zipWith xor' r (BS.unpack s)) (replicate 16 0) ss -- to prevent packing/unpacking in <> on each step with default mappend
in IdsHash s'
xor' :: Word8 -> Word8 -> Word8
xor' x y = let !r = xor x y in r
noIdsHash ::IdsHash
noIdsHash = IdsHash B.empty
{-# INLINE noIdsHash #-}
queueIdsHash :: [QueueId] -> IdsHash
queueIdsHash = mconcat . map queueIdHash
queueIdHash :: QueueId -> IdsHash
queueIdHash = IdsHash . C.md5Hash . unEntityId
{-# INLINE queueIdHash #-}
data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock
@@ -1695,7 +1737,9 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
new = e (NEW_, ' ', rKey, dhKey)
auth = maybe "" (e . ('A',)) auth_
SUB -> e SUB_
SUBS -> e SUBS_
SUBS n idsHash
| v >= rcvServiceSMPVersion -> e (SUBS_, ' ', n, idsHash)
| otherwise -> e SUBS_
KEY k -> e (KEY_, ' ', k)
RKEY ks -> e (RKEY_, ' ', ks)
LSET lnkId d -> e (LSET_, ' ', lnkId, d)
@@ -1711,7 +1755,9 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
SEND flags msg -> e (SEND_, ' ', flags, ' ', Tail msg)
PING -> e PING_
NSUB -> e NSUB_
NSUBS -> e NSUBS_
NSUBS n idsHash
| v >= rcvServiceSMPVersion -> e (NSUBS_, ' ', n, idsHash)
| otherwise -> e NSUBS_
LKEY k -> e (LKEY_, ' ', k)
LGET -> e LGET_
PRXY host auth_ -> e (PRXY_, ' ', host, auth_)
@@ -1802,7 +1848,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
OFF_ -> pure OFF
DEL_ -> pure DEL
QUE_ -> pure QUE
CT SRecipientService SUBS_ -> pure $ Cmd SRecipientService SUBS
CT SRecipientService SUBS_
| v >= rcvServiceSMPVersion -> Cmd SRecipientService <$> (SUBS <$> _smpP <*> smpP)
| otherwise -> pure $ Cmd SRecipientService $ SUBS (-1) noIdsHash
CT SSender tag ->
Cmd SSender <$> case tag of
SKEY_ -> SKEY <$> _smpP
@@ -1819,7 +1867,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
PFWD_ -> PFWD <$> _smpP <*> smpP <*> (EncTransmission . unTail <$> smpP)
PRXY_ -> PRXY <$> _smpP <*> smpP
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
CT SNotifierService NSUBS_ -> pure $ Cmd SNotifierService NSUBS
CT SNotifierService NSUBS_
| v >= rcvServiceSMPVersion -> Cmd SNotifierService <$> (NSUBS <$> _smpP <*> smpP)
| otherwise -> pure $ Cmd SNotifierService $ NSUBS (-1) noIdsHash
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
@@ -1901,7 +1951,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
SOK_ -> SOK <$> _smpP
SOKS_
| v >= rcvServiceSMPVersion -> SOKS <$> _smpP <*> smpP
| otherwise -> SOKS <$> _smpP <*> pure B.empty
| otherwise -> SOKS <$> _smpP <*> pure noIdsHash
NID_ -> NID <$> _smpP <*> smpP
NMSG_ -> NMSG <$> _smpP <*> smpP
PKEY_ -> PKEY <$> _smpP <*> smpP <*> smpP
+29 -20
View File
@@ -6,6 +6,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
@@ -1247,7 +1248,7 @@ verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, entId, comma
vc SCreator (NEW NewQueueReq {rcvAuthKey = k}) = verifiedWith k
vc SRecipient SUB = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
vc SRecipient _ = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
vc SRecipientService SUBS = verifyServiceCmd
vc SRecipientService SUBS {} = verifyServiceCmd
vc SSender (SKEY k) = verifySecure k
-- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command
vc SSender SEND {} = verifyQueue $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified q_ else VRFailed AUTH
@@ -1255,7 +1256,7 @@ verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, entId, comma
vc SSenderLink (LKEY k) = verifySecure k
vc SSenderLink LGET = verifyQueue $ \q -> if isContactQueue (snd q) then VRVerified q_ else VRFailed AUTH
vc SNotifier NSUB = verifyQueue $ \q -> maybe dummyVerify (\n -> verifiedWith $ notifierKey n) (notifier $ snd q)
vc SNotifierService NSUBS = verifyServiceCmd
vc SNotifierService NSUBS {} = verifyServiceCmd
vc SProxiedClient _ = VRVerified Nothing
vc SProxyService (RFWD _) = VRVerified Nothing
checkRole = case (service, partyClientRole p) of
@@ -1465,8 +1466,8 @@ client
Cmd SNotifier NSUB -> response . (corrId,entId,) <$> case q_ of
Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds
_ -> pure $ ERR INTERNAL
Cmd SNotifierService NSUBS -> response . (corrId,entId,) <$> case clntServiceId of
Just serviceId -> subscribeServiceNotifications serviceId
Cmd SNotifierService (NSUBS n idsHash) -> response . (corrId,entId,) <$> case clntServiceId of
Just serviceId -> subscribeServiceNotifications serviceId (n, idsHash)
Nothing -> pure $ ERR INTERNAL
Cmd SCreator (NEW nqr@NewQueueReq {auth_}) ->
response <$> ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
@@ -1495,8 +1496,8 @@ client
OFF -> response <$> maybe (pure $ err INTERNAL) suspendQueue_ q_
DEL -> response <$> maybe (pure $ err INTERNAL) delQueueAndMsgs q_
QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr
Cmd SRecipientService SUBS -> response . (corrId,entId,) <$> case clntServiceId of
Just serviceId -> subscribeServiceMessages serviceId
Cmd SRecipientService (SUBS n idsHash)-> response . (corrId,entId,) <$> case clntServiceId of
Just serviceId -> subscribeServiceMessages serviceId (n, idsHash)
Nothing -> pure $ ERR INTERNAL -- it's "internal" because it should never get to this branch
where
createQueue :: NewQueueReq -> M s (Transmission BrokerMsg)
@@ -1795,9 +1796,9 @@ client
TM.insert entId sub $ clientSubs clnt
pure (False, Just sub)
subscribeServiceMessages :: ServiceId -> M s BrokerMsg
subscribeServiceMessages serviceId =
sharedSubscribeService SRecipientService serviceId subscribers serviceSubscribed serviceSubsCount >>= \case
subscribeServiceMessages :: ServiceId -> (Int64, IdsHash) -> M s BrokerMsg
subscribeServiceMessages serviceId expected =
sharedSubscribeService SRecipientService serviceId expected subscribers serviceSubscribed serviceSubsCount rcvServices >>= \case
Left e -> pure $ ERR e
Right (hasSub, (count, idsHash)) -> do
unless hasSub $ forkClient clnt "deliverServiceMessages" $ liftIO $ deliverServiceMessages count
@@ -1806,7 +1807,7 @@ client
deliverServiceMessages expectedCnt = do
(qCnt, _msgCnt, _dupCnt, _errCnt) <- foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, 0)
atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, SALL)]
-- TODO [cert rcv] compare with expected
-- TODO [certs rcv] compare with expected
logNote $ "Service subscriptions for " <> tshow serviceId <> " (" <> tshow qCnt <> " queues)"
deliverQueueMsg :: (Int, Int, Int, Int) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int, Int, Int, Int)
deliverQueueMsg (!qCnt, !msgCnt, !dupCnt, !errCnt) rId = \case
@@ -1831,25 +1832,33 @@ client
TM.insert rId sub $ subscriptions clnt
pure $ Just sub
subscribeServiceNotifications :: ServiceId -> M s BrokerMsg
subscribeServiceNotifications serviceId =
either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount
subscribeServiceNotifications :: ServiceId -> (Int64, IdsHash) -> M s BrokerMsg
subscribeServiceNotifications serviceId expected =
either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId expected ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount ntfServices
sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> M s (Either ErrorType (Bool, (Int64, IdsHash)))
sharedSubscribeService party serviceId srvSubscribers clientServiceSubscribed clientServiceSubs = do
sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> (Int64, IdsHash) -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> (ServerStats -> ServiceStats) -> M s (Either ErrorType (Bool, (Int64, IdsHash)))
sharedSubscribeService party serviceId (count, idsHash) srvSubscribers clientServiceSubscribed clientServiceSubs servicesSel = do
subscribed <- readTVarIO $ clientServiceSubscribed clnt
stats <- asks serverStats
liftIO $ runExceptT $
(subscribed,)
<$> if subscribed
then (,B.empty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash
then (,mempty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash
else do
count' <- ExceptT $ getServiceQueueCount @(StoreQueue s) (queueStore ms) party serviceId
(count', idsHash') <- ExceptT $ getServiceQueueCountHash @(StoreQueue s) (queueStore ms) party serviceId
incCount <- atomically $ do
writeTVar (clientServiceSubscribed clnt) True
count <- swapTVar (clientServiceSubs clnt) count'
pure $ count' - count
currCount <- swapTVar (clientServiceSubs clnt) count' -- TODO [certs rcv] maintain IDs hash here?
pure $ count' - currCount
let incSrvStat sel n = liftIO $ atomicModifyIORef'_ (sel $ servicesSel stats) (+ n)
diff = fromIntegral $ count' - count
if -- TODO [certs rcv] account for not provided counts/hashes (expected n = -1)
| diff == 0 && idsHash == idsHash' -> incSrvStat srvSubOk 1
| diff > 0 -> incSrvStat srvSubMore 1 >> incSrvStat srvSubMoreTotal diff
| diff < 0 -> incSrvStat srvSubFewer 1 >> incSrvStat srvSubFewerTotal (- diff)
| otherwise -> incSrvStat srvSubDiff 1
atomically $ writeTQueue (subQ srvSubscribers) (CSService serviceId incCount, clientId)
pure (count', B.empty) -- TODO [certs rcv] get IDs hash
pure (count', idsHash')
acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
acknowledgeMsg msgId q qr =
@@ -355,8 +355,8 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where
{-# INLINE setQueueService #-}
getQueueNtfServices = withQS (getQueueNtfServices @(JournalQueue s))
{-# INLINE getQueueNtfServices #-}
getServiceQueueCount = withQS (getServiceQueueCount @(JournalQueue s))
{-# INLINE getServiceQueueCount #-}
getServiceQueueCountHash = withQS (getServiceQueueCountHash @(JournalQueue s))
{-# INLINE getServiceQueueCountHash #-}
makeQueue_ :: JournalMsgStore s -> RecipientId -> QueueRec -> Lock -> IO (JournalQueue s)
makeQueue_ JournalMsgStore {sharedLock} rId qr queueLock = do
@@ -21,6 +21,7 @@ import Simplex.Messaging.Transport (simplexMQVersion)
import Simplex.Messaging.Transport.Server (SocketStats (..))
import Simplex.Messaging.Util (tshow)
-- TODO [certs rcv] add service subscriptions and count/hash diffs
data ServerMetrics = ServerMetrics
{ statsData :: ServerStatsData,
activeQueueCounts :: PeriodStatCounts,
@@ -65,6 +65,7 @@ data ServiceRec = ServiceRec
serviceCert :: X.CertificateChain,
serviceCertHash :: XV.Fingerprint, -- SHA512 hash of long-term service client certificate. See comment for ClientHandshake.
serviceCreatedAt :: SystemDate
-- entitiesHash :: IdsHash -- a xor-hash of all associated entities
}
deriving (Show)
@@ -524,15 +524,11 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs'
in ((serviceId, sNtfs) : ssNtfs, restNtfs)
getServiceQueueCount :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64)
getServiceQueueCount st party serviceId =
E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCount" st $ \db ->
maybeFirstRow' 0 fromOnly $
DB.query db query (Only serviceId)
where
query = case party of
SRecipientService -> "SELECT count(1) FROM msg_queues WHERE rcv_service_id = ? AND deleted_at IS NULL"
SNotifierService -> "SELECT count(1) FROM msg_queues WHERE ntf_service_id = ? AND deleted_at IS NULL"
getServiceQueueCountHash :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash))
getServiceQueueCountHash st party serviceId =
E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCountHash" st $ \db ->
maybeFirstRow' (0, mempty) id $
DB.query db ("SELECT queue_count, queue_ids_hash FROM services WHERE service_id = ? AND service_role = ?") (serviceId, partyServiceRole party)
batchInsertServices :: [STMService] -> PostgresQueueStore q -> IO Int64
batchInsertServices services' toStore =
@@ -793,6 +789,10 @@ instance ToField C.APublicAuthKey where toField = toField . Binary . C.encodePub
instance FromField C.APublicAuthKey where fromField = blobFieldDecoder C.decodePubKey
instance ToField IdsHash where toField (IdsHash s) = toField (Binary s)
deriving newtype instance FromField IdsHash
instance ToField EncDataBytes where toField (EncDataBytes s) = toField (Binary s)
deriving newtype instance FromField EncDataBytes
@@ -7,6 +7,7 @@ module Simplex.Messaging.Server.QueueStore.Postgres.Migrations where
import Data.List (sortOn)
import Data.Text (Text)
import Simplex.Messaging.Agent.Store.Shared
import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util
import Text.RawString.QQ (r)
serverSchemaMigrations :: [(String, Text, Maybe Text)]
@@ -15,7 +16,8 @@ serverSchemaMigrations =
("20250319_updated_index", m20250319_updated_index, Just down_m20250319_updated_index),
("20250320_short_links", m20250320_short_links, Just down_m20250320_short_links),
("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs),
("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages)
("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages),
("20250915_queue_ids_hash", m20250915_queue_ids_hash, Just down_m20250915_queue_ids_hash)
]
-- | The list of migrations in ascending order by date
@@ -447,3 +449,139 @@ ALTER TABLE msg_queues
DROP TABLE messages;
|]
m20250915_queue_ids_hash :: Text
m20250915_queue_ids_hash =
createXorHashFuncs
<> [r|
ALTER TABLE services
ADD COLUMN queue_count BIGINT NOT NULL DEFAULT 0,
ADD COLUMN queue_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000';
CREATE FUNCTION update_all_aggregates() RETURNS VOID
LANGUAGE plpgsql
AS $$
BEGIN
WITH acc AS (
SELECT
s.service_id,
count(1) as q_count,
xor_aggregate(public.digest(CASE WHEN s.service_role = 'M' THEN q.recipient_id ELSE COALESCE(q.notifier_id, '\x00000000000000000000000000000000') END, 'md5')) AS q_ids_hash
FROM services s
JOIN msg_queues q ON (s.service_id = q.rcv_service_id AND s.service_role = 'M') OR (s.service_id = q.ntf_service_id AND s.service_role = 'N')
WHERE q.deleted_at IS NULL
GROUP BY s.service_id
)
UPDATE services s
SET queue_count = COALESCE(acc.q_count, 0),
queue_ids_hash = COALESCE(acc.q_ids_hash, '\x00000000000000000000000000000000')
FROM acc
WHERE s.service_id = acc.service_id;
END;
$$;
SELECT update_all_aggregates();
CREATE FUNCTION update_aggregates(p_service_id BYTEA, p_role TEXT, p_queue_id BYTEA, p_change BIGINT) RETURNS VOID
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE services
SET queue_count = queue_count + p_change,
queue_ids_hash = xor_combine(queue_ids_hash, public.digest(p_queue_id, 'md5'))
WHERE service_id = p_service_id AND service_role = p_role;
END;
$$;
CREATE FUNCTION on_queue_insert() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF NEW.rcv_service_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
END IF;
IF NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
END IF;
RETURN NEW;
END;
$$;
CREATE FUNCTION on_queue_delete() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.deleted_at IS NULL THEN
IF OLD.rcv_service_id IS NOT NULL THEN
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
END IF;
IF OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
END IF;
END IF;
RETURN OLD;
END;
$$;
CREATE FUNCTION on_queue_update() RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.deleted_at IS NULL AND OLD.rcv_service_id IS NOT NULL THEN
IF NOT (NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL) THEN
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
ELSIF OLD.rcv_service_id IS DISTINCT FROM NEW.rcv_service_id THEN
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
END IF;
ELSIF NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
END IF;
IF OLD.deleted_at IS NULL AND OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
IF NOT (NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL) THEN
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
ELSIF OLD.ntf_service_id IS DISTINCT FROM NEW.ntf_service_id OR OLD.notifier_id IS DISTINCT FROM NEW.notifier_id THEN
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
END IF;
ELSIF NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
END IF;
RETURN NEW;
END;
$$;
CREATE TRIGGER tr_queue_insert
AFTER INSERT ON msg_queues
FOR EACH ROW EXECUTE PROCEDURE on_queue_insert();
CREATE TRIGGER tr_queue_delete
AFTER DELETE ON msg_queues
FOR EACH ROW EXECUTE PROCEDURE on_queue_delete();
CREATE TRIGGER tr_queue_update
AFTER UPDATE ON msg_queues
FOR EACH ROW EXECUTE PROCEDURE on_queue_update();
|]
down_m20250915_queue_ids_hash :: Text
down_m20250915_queue_ids_hash =
[r|
DROP TRIGGER tr_queue_insert ON msg_queues;
DROP TRIGGER tr_queue_delete ON msg_queues;
DROP TRIGGER tr_queue_update ON msg_queues;
DROP FUNCTION on_queue_insert;
DROP FUNCTION on_queue_delete;
DROP FUNCTION on_queue_update;
DROP FUNCTION update_aggregates;
DROP FUNCTION update_all_aggregates;
ALTER TABLE services
DROP COLUMN queue_count,
DROP COLUMN queue_ids_hash;
|]
<> dropXorHashFuncs
@@ -104,6 +104,71 @@ $$;
CREATE FUNCTION smp_server.on_queue_delete() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.deleted_at IS NULL THEN
IF OLD.rcv_service_id IS NOT NULL THEN
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
END IF;
IF OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
END IF;
END IF;
RETURN OLD;
END;
$$;
CREATE FUNCTION smp_server.on_queue_insert() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF NEW.rcv_service_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
END IF;
IF NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
END IF;
RETURN NEW;
END;
$$;
CREATE FUNCTION smp_server.on_queue_update() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF OLD.deleted_at IS NULL AND OLD.rcv_service_id IS NOT NULL THEN
IF NOT (NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL) THEN
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
ELSIF OLD.rcv_service_id IS DISTINCT FROM NEW.rcv_service_id THEN
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
END IF;
ELSIF NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
END IF;
IF OLD.deleted_at IS NULL AND OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
IF NOT (NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL) THEN
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
ELSIF OLD.ntf_service_id IS DISTINCT FROM NEW.ntf_service_id OR OLD.notifier_id IS DISTINCT FROM NEW.notifier_id THEN
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
END IF;
ELSIF NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
END IF;
RETURN NEW;
END;
$$;
CREATE FUNCTION smp_server.try_del_msg(p_recipient_id bytea, p_msg_id bytea) RETURNS TABLE(r_msg_id bytea, r_msg_ts bigint, r_msg_quota boolean, r_msg_ntf_flag boolean, r_msg_body bytea)
LANGUAGE plpgsql
AS $$
@@ -225,6 +290,43 @@ $$;
CREATE FUNCTION smp_server.update_aggregates(p_service_id bytea, p_role text, p_queue_id bytea, p_change bigint) RETURNS void
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE services
SET queue_count = queue_count + p_change,
queue_ids_hash = xor_combine(queue_ids_hash, public.digest(p_queue_id, 'md5'))
WHERE service_id = p_service_id AND service_role = p_role;
END;
$$;
CREATE FUNCTION smp_server.update_all_aggregates() RETURNS void
LANGUAGE plpgsql
AS $$
BEGIN
WITH acc AS (
SELECT
s.service_id,
count(1) as q_count,
xor_aggregate(public.digest(CASE WHEN s.service_role = 'M' THEN q.recipient_id ELSE COALESCE(q.notifier_id, '\x00000000000000000000000000000000') END, 'md5')) AS q_ids_hash
FROM services s
JOIN msg_queues q ON (s.service_id = q.rcv_service_id AND s.service_role = 'M') OR (s.service_id = q.ntf_service_id AND s.service_role = 'N')
WHERE q.deleted_at IS NULL
GROUP BY s.service_id
)
UPDATE services s
SET queue_count = COALESCE(acc.q_count, 0),
queue_ids_hash = COALESCE(acc.q_ids_hash, '\x00000000000000000000000000000000')
FROM acc
WHERE s.service_id = acc.service_id;
END;
$$;
CREATE FUNCTION smp_server.write_message(p_recipient_id bytea, p_msg_id bytea, p_msg_ts bigint, p_msg_quota boolean, p_msg_ntf_flag boolean, p_msg_body bytea, p_quota integer) RETURNS TABLE(quota_written boolean, was_empty boolean)
LANGUAGE plpgsql
AS $$
@@ -256,6 +358,34 @@ END;
$$;
CREATE FUNCTION smp_server.xor_combine(state bytea, value bytea) RETURNS bytea
LANGUAGE plpgsql IMMUTABLE STRICT
AS $$
DECLARE
result BYTEA := state;
i INTEGER;
len INTEGER := octet_length(value);
BEGIN
IF octet_length(state) != len THEN
RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len;
END IF;
FOR i IN 0..len-1 LOOP
result := set_byte(result, i, get_byte(state, i) # get_byte(value, i));
END LOOP;
RETURN result;
END;
$$;
CREATE AGGREGATE smp_server.xor_aggregate(bytea) (
SFUNC = smp_server.xor_combine,
STYPE = bytea,
INITCOND = '\x00000000000000000000000000000000'
);
SET default_table_access_method = heap;
@@ -320,7 +450,9 @@ CREATE TABLE smp_server.services (
service_role text NOT NULL,
service_cert bytea NOT NULL,
service_cert_hash bytea NOT NULL,
created_at bigint NOT NULL
created_at bigint NOT NULL,
queue_count bigint DEFAULT 0 NOT NULL,
queue_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL
);
@@ -390,6 +522,18 @@ CREATE INDEX idx_services_service_role ON smp_server.services USING btree (servi
CREATE TRIGGER tr_queue_delete AFTER DELETE ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_delete();
CREATE TRIGGER tr_queue_insert AFTER INSERT ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_insert();
CREATE TRIGGER tr_queue_update AFTER UPDATE ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_update();
ALTER TABLE ONLY smp_server.messages
ADD CONSTRAINT messages_recipient_id_fkey FOREIGN KEY (recipient_id) REFERENCES smp_server.msg_queues(recipient_id) ON UPDATE RESTRICT ON DELETE CASCADE;
+26 -18
View File
@@ -28,6 +28,7 @@ where
import qualified Control.Exception as E
import Control.Logger.Simple
import Control.Monad
import Data.Bifunctor (first)
import Data.Bitraversable (bimapM)
import Data.Functor (($>))
import Data.Int (Int64)
@@ -62,8 +63,8 @@ data STMQueueStore q = STMQueueStore
data STMService = STMService
{ serviceRec :: ServiceRec,
serviceRcvQueues :: TVar (Set RecipientId),
serviceNtfQueues :: TVar (Set NotifierId)
serviceRcvQueues :: TVar (Set RecipientId, IdsHash), -- TODO [certs rcv] get/maintain hash
serviceNtfQueues :: TVar (Set NotifierId, IdsHash) -- TODO [certs rcv] get/maintain hash
}
setStoreLog :: STMQueueStore q -> StoreLog 'WriteMode -> IO ()
@@ -113,7 +114,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
}
where
serviceCount role = M.foldl' (\ !n s -> if serviceRole (serviceRec s) == role then n + 1 else n) 0
serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size <$> readTVarIO (serviceSel s)) 0
serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size . fst <$> readTVarIO (serviceSel s)) 0
addQueue_ :: STMQueueStore q -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q)
addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData, rcvServiceId} = do
@@ -304,8 +305,8 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
TM.insert fp newSrvId serviceCerts
pure $ Right (newSrvId, True)
newSTMService = do
serviceRcvQueues <- newTVar S.empty
serviceNtfQueues <- newTVar S.empty
serviceRcvQueues <- newTVar (S.empty, mempty)
serviceNtfQueues <- newTVar (S.empty, mempty)
pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues}
setQueueService :: (PartyI p, ServiceParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
@@ -331,7 +332,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
let !q' = Just q {notifier = Just nc {ntfServiceId = serviceId}}
updateServiceQueues serviceNtfQueues nId prevNtfSrvId
writeTVar qr q' $> Right ()
updateServiceQueues :: (STMService -> TVar (Set QueueId)) -> QueueId -> Maybe ServiceId -> STM ()
updateServiceQueues :: (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> Maybe ServiceId -> STM ()
updateServiceQueues serviceSel qId prevSrvId = do
mapM_ (removeServiceQueue st serviceSel qId) prevSrvId
mapM_ (addServiceQueue st serviceSel qId) serviceId
@@ -346,16 +347,16 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
pure $ Right (ssNtfs', deleteNtfs)
where
addService (ssNtfs, ntfs') (serviceId, s) = do
snIds <- readTVarIO $ serviceNtfQueues s
(snIds, _) <- readTVarIO $ serviceNtfQueues s
let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs'
pure ((Just serviceId, sNtfs) : ssNtfs, restNtfs)
getServiceQueueCount :: (PartyI p, ServiceParty p) => STMQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64)
getServiceQueueCount st party serviceId =
getServiceQueueCountHash :: (PartyI p, ServiceParty p) => STMQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash))
getServiceQueueCountHash st party serviceId =
TM.lookupIO serviceId (services st) >>=
maybe (pure $ Left AUTH) (fmap (Right . fromIntegral . S.size) . readTVarIO . serviceSel)
maybe (pure $ Left AUTH) (fmap (Right . first (fromIntegral . S.size)) . readTVarIO . serviceSel)
where
serviceSel :: STMService -> TVar (Set QueueId)
serviceSel :: STMService -> TVar (Set QueueId, IdsHash)
serviceSel = case party of
SRecipientService -> serviceRcvQueues
SNotifierService -> serviceNtfQueues
@@ -366,7 +367,7 @@ foldRcvServiceQueues st serviceId f acc =
Nothing -> pure acc
Just s ->
readTVarIO (serviceRcvQueues s)
>>= foldM (\a -> get >=> maybe (pure a) (f a)) acc
>>= foldM (\a -> get >=> maybe (pure a) (f a)) acc . fst
where
get rId = TM.lookupIO rId (queues st) $>>= \q -> (q,) <$$> readTVarIO (queueRec q)
@@ -379,16 +380,23 @@ setStatus qr status =
Just q -> (Right (), Just q {status})
Nothing -> (Left AUTH, Nothing)
addServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM ()
addServiceQueue st serviceSel qId serviceId =
TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.insert qId))
addServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM ()
addServiceQueue = setServiceQueues_ S.insert
{-# INLINE addServiceQueue #-}
removeServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM ()
removeServiceQueue st serviceSel qId serviceId =
TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.delete qId))
removeServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM ()
removeServiceQueue = setServiceQueues_ S.delete
{-# INLINE removeServiceQueue #-}
setServiceQueues_ :: (QueueId -> Set QueueId -> Set QueueId) -> STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM ()
setServiceQueues_ updateSet st serviceSel qId serviceId =
TM.lookup serviceId (services st) >>= mapM_ (\v -> modifyTVar' (serviceSel v) update)
where
update (s, idsHash) =
let !s' = updateSet qId s
!idsHash' = queueIdHash qId <> idsHash
in (s', idsHash')
removeNotifier :: STMQueueStore q -> NtfCreds -> STM ()
removeNotifier st NtfCreds {notifierId = nId, ntfServiceId} = do
TM.delete nId $ notifiers st
@@ -47,7 +47,7 @@ class StoreQueueClass q => QueueStoreClass q s where
getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId)
setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]))
getServiceQueueCount :: (PartyI p, ServiceParty p) => s -> SParty p -> ServiceId -> IO (Either ErrorType Int64)
getServiceQueueCountHash :: (PartyI p, ServiceParty p) => s -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash))
data EntityCounts = EntityCounts
{ queueCount :: Int,
+75 -7
View File
@@ -821,7 +821,15 @@ data ServiceStats = ServiceStats
srvSubCount :: IORef Int,
srvSubDuplicate :: IORef Int,
srvSubQueues :: IORef Int,
srvSubEnd :: IORef Int
srvSubEnd :: IORef Int,
-- counts of subscriptions
srvSubOk :: IORef Int, -- server has the same queues as expected
srvSubMore :: IORef Int, -- server has more queues than expected
srvSubFewer :: IORef Int, -- server has fewer queues than expected
srvSubDiff :: IORef Int, -- server has the same count, but different queues than expected (based on xor-hash)
-- adds actual deviations
srvSubMoreTotal :: IORef Int, -- server has more queues than expected, adds diff
srvSubFewerTotal :: IORef Int
}
data ServiceStatsData = ServiceStatsData
@@ -832,7 +840,13 @@ data ServiceStatsData = ServiceStatsData
_srvSubCount :: Int,
_srvSubDuplicate :: Int,
_srvSubQueues :: Int,
_srvSubEnd :: Int
_srvSubEnd :: Int,
_srvSubOk :: Int,
_srvSubMore :: Int,
_srvSubFewer :: Int,
_srvSubDiff :: Int,
_srvSubMoreTotal :: Int,
_srvSubFewerTotal :: Int
}
deriving (Show)
@@ -846,7 +860,13 @@ newServiceStatsData =
_srvSubCount = 0,
_srvSubDuplicate = 0,
_srvSubQueues = 0,
_srvSubEnd = 0
_srvSubEnd = 0,
_srvSubOk = 0,
_srvSubMore = 0,
_srvSubFewer = 0,
_srvSubDiff = 0,
_srvSubMoreTotal = 0,
_srvSubFewerTotal = 0
}
newServiceStats :: IO ServiceStats
@@ -859,6 +879,12 @@ newServiceStats = do
srvSubDuplicate <- newIORef 0
srvSubQueues <- newIORef 0
srvSubEnd <- newIORef 0
srvSubOk <- newIORef 0
srvSubMore <- newIORef 0
srvSubFewer <- newIORef 0
srvSubDiff <- newIORef 0
srvSubMoreTotal <- newIORef 0
srvSubFewerTotal <- newIORef 0
pure
ServiceStats
{ srvAssocNew,
@@ -868,7 +894,13 @@ newServiceStats = do
srvSubCount,
srvSubDuplicate,
srvSubQueues,
srvSubEnd
srvSubEnd,
srvSubOk,
srvSubMore,
srvSubFewer,
srvSubDiff,
srvSubMoreTotal,
srvSubFewerTotal
}
getServiceStatsData :: ServiceStats -> IO ServiceStatsData
@@ -881,6 +913,12 @@ getServiceStatsData s = do
_srvSubDuplicate <- readIORef $ srvSubDuplicate s
_srvSubQueues <- readIORef $ srvSubQueues s
_srvSubEnd <- readIORef $ srvSubEnd s
_srvSubOk <- readIORef $ srvSubOk s
_srvSubMore <- readIORef $ srvSubMore s
_srvSubFewer <- readIORef $ srvSubFewer s
_srvSubDiff <- readIORef $ srvSubDiff s
_srvSubMoreTotal <- readIORef $ srvSubMoreTotal s
_srvSubFewerTotal <- readIORef $ srvSubFewerTotal s
pure
ServiceStatsData
{ _srvAssocNew,
@@ -890,7 +928,13 @@ getServiceStatsData s = do
_srvSubCount,
_srvSubDuplicate,
_srvSubQueues,
_srvSubEnd
_srvSubEnd,
_srvSubOk,
_srvSubMore,
_srvSubFewer,
_srvSubDiff,
_srvSubMoreTotal,
_srvSubFewerTotal
}
getResetServiceStatsData :: ServiceStats -> IO ServiceStatsData
@@ -903,6 +947,12 @@ getResetServiceStatsData s = do
_srvSubDuplicate <- atomicSwapIORef (srvSubDuplicate s) 0
_srvSubQueues <- atomicSwapIORef (srvSubQueues s) 0
_srvSubEnd <- atomicSwapIORef (srvSubEnd s) 0
_srvSubOk <- atomicSwapIORef (srvSubOk s) 0
_srvSubMore <- atomicSwapIORef (srvSubMore s) 0
_srvSubFewer <- atomicSwapIORef (srvSubFewer s) 0
_srvSubDiff <- atomicSwapIORef (srvSubDiff s) 0
_srvSubMoreTotal <- atomicSwapIORef (srvSubMoreTotal s) 0
_srvSubFewerTotal <- atomicSwapIORef (srvSubFewerTotal s) 0
pure
ServiceStatsData
{ _srvAssocNew,
@@ -912,7 +962,13 @@ getResetServiceStatsData s = do
_srvSubCount,
_srvSubDuplicate,
_srvSubQueues,
_srvSubEnd
_srvSubEnd,
_srvSubOk,
_srvSubMore,
_srvSubFewer,
_srvSubDiff,
_srvSubMoreTotal,
_srvSubFewerTotal
}
-- this function is not thread safe, it is used on server start only
@@ -926,6 +982,12 @@ setServiceStats s d = do
writeIORef (srvSubDuplicate s) $! _srvSubDuplicate d
writeIORef (srvSubQueues s) $! _srvSubQueues d
writeIORef (srvSubEnd s) $! _srvSubEnd d
writeIORef (srvSubOk s) $! _srvSubOk d
writeIORef (srvSubMore s) $! _srvSubMore d
writeIORef (srvSubFewer s) $! _srvSubFewer d
writeIORef (srvSubDiff s) $! _srvSubDiff d
writeIORef (srvSubMoreTotal s) $! _srvSubMoreTotal d
writeIORef (srvSubFewerTotal s) $! _srvSubFewerTotal d
instance StrEncoding ServiceStatsData where
strEncode ServiceStatsData {_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd} =
@@ -963,7 +1025,13 @@ instance StrEncoding ServiceStatsData where
_srvSubCount,
_srvSubDuplicate,
_srvSubQueues,
_srvSubEnd
_srvSubEnd,
_srvSubOk = 0,
_srvSubMore = 0,
_srvSubFewer = 0,
_srvSubDiff = 0,
_srvSubMoreTotal = 0,
_srvSubFewerTotal = 0
}
data TimeBuckets = TimeBuckets
@@ -61,7 +61,7 @@ readQueueStore tty mkQ f st = readLogLines tty f $ \_ -> processLine
Left e -> logError $ errPfx <> tshow e
where
errPfx = "STORE: getCreateService, stored service " <> decodeLatin1 (strEncode serviceId) <> ", "
QueueService rId (ASP party) serviceId -> withQueue rId "QueueService" $ \q -> setQueueService st q party serviceId
QueueService qId (ASP party) serviceId -> withQueue qId "QueueService" $ \q -> setQueueService st q party serviceId
printError :: String -> IO ()
printError e = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s
withQueue :: forall a. RecipientId -> T.Text -> (q -> IO (Either ErrorType a)) -> IO ()