mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-24 16:55:24 +00:00
servers: maintain xor-hash of all associated queue IDs in PostgreSQL (#1668)
* servers: maintain xor-hash of all associated queue IDs in PostgreSQL (#1615)
* ntf server: maintain xor-hash of all associated queue IDs via PostgreSQL triggers
* smp server: xor hash with triggers
* fix sql and using pgcrypto extension in tests
* track counts and hashes in smp/ntf servers via triggers, smp server stats for service subscription, update SMP protocol to pass expected count and hash in SSUB/NSSUB commands
* agent migrations with functions/triggers
* remove agent triggers
* try tracking service subs in the agent (WIP, does not compile)
* Revert "try tracking service subs in the agent (WIP, does not compile)"
This reverts commit 59e908100d.
* comment
* agent database triggers
* service subscriptions in the client
* test / fix client services
* update schema
* fix postgres migration
* update schema
* move schema test to the end
* use static function with SQLite to avoid dynamic wrapper
This commit is contained in:
@@ -211,7 +211,6 @@ import Simplex.Messaging.Protocol
|
||||
ErrorType (AUTH),
|
||||
MsgBody,
|
||||
MsgFlags (..),
|
||||
IdsHash,
|
||||
NtfServer,
|
||||
ProtoServerWithAuth (..),
|
||||
ProtocolServer (..),
|
||||
@@ -222,6 +221,7 @@ import Simplex.Messaging.Protocol
|
||||
SMPMsgMeta,
|
||||
SParty (..),
|
||||
SProtocolType (..),
|
||||
ServiceSub (..),
|
||||
SndPublicAuthKey,
|
||||
SubscriptionMode (..),
|
||||
UserProtocol,
|
||||
@@ -500,7 +500,7 @@ resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either Agen
|
||||
resubscribeConnections c = withAgentEnv c . resubscribeConnections' c
|
||||
{-# INLINE resubscribeConnections #-}
|
||||
|
||||
subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType (Int64, IdsHash)))
|
||||
subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType ServiceSub))
|
||||
subscribeClientServices c = withAgentEnv c . subscribeClientServices' c
|
||||
{-# INLINE subscribeClientServices #-}
|
||||
|
||||
@@ -594,6 +594,7 @@ testProtocolServer c nm userId srv = withAgentEnv' c $ case protocolTypeI @p of
|
||||
SPNTF -> runNTFServerTest c nm userId srv
|
||||
|
||||
-- | set SOCKS5 proxy on/off and optionally set TCP timeouts for fast network
|
||||
-- TODO [certs rcv] should fail if any user is enabled to use services and per-connection isolation is chosen
|
||||
setNetworkConfig :: AgentClient -> NetworkConfig -> IO ()
|
||||
setNetworkConfig c@AgentClient {useNetworkConfig, proxySessTs} cfg' = do
|
||||
ts <- getCurrentTime
|
||||
@@ -771,6 +772,7 @@ deleteUser' c@AgentClient {smpServersStats, xftpServersStats} userId delSMPQueue
|
||||
whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $
|
||||
writeTBQueue (subQ c) ("", "", AEvt SAENone $ DEL_USER userId)
|
||||
|
||||
-- TODO [certs rcv] should fail enabling if per-connection isolation is set
|
||||
setUserService' :: AgentClient -> UserId -> Bool -> AM ()
|
||||
setUserService' c userId enable = do
|
||||
wasEnabled <- liftIO $ fromMaybe False <$> TM.lookupIO userId (useClientServices c)
|
||||
@@ -1507,15 +1509,15 @@ resubscribeConnections' c connIds = do
|
||||
[] -> pure True
|
||||
rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs'
|
||||
|
||||
-- TODO [certs rcv] compare hash with lock
|
||||
subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType (Int64, IdsHash)))
|
||||
-- TODO [certs rcv] compare hash. possibly, it should return both expected and returned counts
|
||||
subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSub))
|
||||
subscribeClientServices' c userId =
|
||||
ifM useService subscribe $ throwError $ CMD PROHIBITED "no user service allowed"
|
||||
where
|
||||
useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c)
|
||||
subscribe = do
|
||||
srvs <- withStore' c (`getClientServiceServers` userId)
|
||||
lift $ M.fromList . zip srvs <$> mapConcurrently (tryAllErrors' . subscribeClientService c userId) srvs
|
||||
lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c userId srv n idsHash) srvs
|
||||
|
||||
-- requesting messages sequentially, to reduce memory usage
|
||||
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
|
||||
@@ -2829,12 +2831,13 @@ processSMPTransmissions :: AgentClient -> ServerTransmissionBatch SMPVersion Err
|
||||
processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId, ts) = do
|
||||
upConnIds <- newTVarIO []
|
||||
forM_ ts $ \(entId, t) -> case t of
|
||||
STEvent msgOrErr ->
|
||||
withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of
|
||||
Right msg -> runProcessSMP rq conn (toConnData conn) msg
|
||||
Left e -> lift $ do
|
||||
processClientNotice rq e
|
||||
notifyErr connId e
|
||||
STEvent msgOrErr
|
||||
| entId == SMP.NoEntity -> pure () -- TODO [certs rcv] process SALL
|
||||
| otherwise -> withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of
|
||||
Right msg -> runProcessSMP rq conn (toConnData conn) msg
|
||||
Left e -> lift $ do
|
||||
processClientNotice rq e
|
||||
notifyErr connId e
|
||||
STResponse (Cmd SRecipient cmd) respOrErr ->
|
||||
withRcvConn entId $ \rq conn -> case cmd of
|
||||
SMP.SUB -> case respOrErr of
|
||||
@@ -2870,7 +2873,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
processSubOk :: RcvQueue -> TVar [ConnId] -> IO ()
|
||||
processSubOk rq@RcvQueue {connId} upConnIds =
|
||||
atomically . whenM (isPendingSub rq) $ do
|
||||
SS.addActiveSub tSess sessId (rcvQueueSub rq) $ currentSubs c
|
||||
SS.addActiveSub tSess sessId rq $ currentSubs c
|
||||
modifyTVar' upConnIds (connId :)
|
||||
processSubErr :: RcvQueue -> SMPClientError -> AM' ()
|
||||
processSubErr rq@RcvQueue {connId} e = do
|
||||
|
||||
@@ -241,7 +241,7 @@ import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Stats
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.AgentStore
|
||||
import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction)
|
||||
import Simplex.Messaging.Agent.Store.Common (DBStore)
|
||||
import qualified Simplex.Messaging.Agent.Store.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.Agent.TSessionSubs (TSessionSubs)
|
||||
@@ -279,6 +279,7 @@ import Simplex.Messaging.Protocol
|
||||
RcvNtfPublicDhKey,
|
||||
SMPMsgMeta (..),
|
||||
SProtocolType (..),
|
||||
ServiceSub (..),
|
||||
SndPublicAuthKey,
|
||||
SubscriptionMode (..),
|
||||
NewNtfCreds (..),
|
||||
@@ -499,6 +500,7 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
|
||||
-- TODO [certs rcv] should fail if both per-connection isolation is set and any users use services
|
||||
newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Map (Maybe SMPServer) (Maybe SystemSeconds) -> Env -> IO AgentClient
|
||||
newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, useServices, presetDomains, presetServers} currentTs notices agentEnv = do
|
||||
let cfg = config agentEnv
|
||||
@@ -622,9 +624,8 @@ getServiceCredentials c userId srv =
|
||||
let tlsCreds = tlsCredentials [cred]
|
||||
createClientService db userId srv tlsCreds
|
||||
pure (tlsCreds, Nothing)
|
||||
(_, pk) <- atomically $ C.generateKeyPair g
|
||||
let serviceSignKey = C.APrivateSignKey C.SEd25519 pk
|
||||
creds = ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey}
|
||||
serviceSignKey <- liftEitherWith INTERNAL $ C.x509ToPrivate' $ snd serviceCreds
|
||||
let creds = ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey}
|
||||
pure (creds, serviceId_)
|
||||
|
||||
class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where
|
||||
@@ -744,9 +745,11 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm
|
||||
smp <- liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do
|
||||
ts <- readTVarIO proxySessTs
|
||||
ExceptT $ getProtocolClient g nm tSess cfg' presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs
|
||||
-- TODO [certs rcv] add service to SS, possibly combine with SS.setSessionId
|
||||
atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c
|
||||
updateClientService service smp
|
||||
pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs}
|
||||
-- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queue
|
||||
updateClientService service smp = case (service, smpClientService smp) of
|
||||
(Just (_, serviceId_), Just THClientService {serviceId})
|
||||
| serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId
|
||||
@@ -763,32 +766,34 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess
|
||||
-- we make active subscriptions pending only if the client for tSess was current (in the map) and active,
|
||||
-- because we can have a race condition when a new current client could have already
|
||||
-- made subscriptions active, and the old client would be processing diconnection later.
|
||||
removeClientAndSubs :: IO ([RcvQueueSub], [ConnId])
|
||||
removeClientAndSubs :: IO ([RcvQueueSub], [ConnId], Maybe ServiceSub)
|
||||
removeClientAndSubs = atomically $ do
|
||||
removeSessVar v tSess smpClients
|
||||
ifM (readTVar active) removeSubs (pure ([], []))
|
||||
ifM (readTVar active) removeSubs (pure ([], [], Nothing))
|
||||
where
|
||||
sessId = sessionId $ thParams client
|
||||
removeSubs = do
|
||||
mode <- getSessionMode c
|
||||
subs <- SS.setSubsPending mode tSess sessId $ currentSubs c
|
||||
(subs, serviceSub_) <- SS.setSubsPending mode tSess sessId $ currentSubs c
|
||||
let qs = M.elems subs
|
||||
cs = nubOrd $ map qConnId qs
|
||||
-- this removes proxied relays that this client created sessions to
|
||||
destSrvs <- M.keys <$> readTVar prs
|
||||
forM_ destSrvs $ \destSrv -> TM.delete (userId, destSrv, cId) smpProxiedRelays
|
||||
pure (qs, cs)
|
||||
pure (qs, cs, serviceSub_)
|
||||
|
||||
serverDown :: ([RcvQueueSub], [ConnId]) -> IO ()
|
||||
serverDown (qs, conns) = whenM (readTVarIO active) $ do
|
||||
serverDown :: ([RcvQueueSub], [ConnId], Maybe ServiceSub) -> IO ()
|
||||
serverDown (qs, conns, serviceSub_) = whenM (readTVarIO active) $ do
|
||||
notifySub c $ hostEvent' DISCONNECT client
|
||||
unless (null conns) $ notifySub c $ DOWN srv conns
|
||||
unless (null qs) $ do
|
||||
unless (null qs && isNothing serviceSub_) $ do
|
||||
releaseGetLocksIO c qs
|
||||
mode <- getSessionModeIO c
|
||||
let resubscribe
|
||||
| (mode == TSMEntity) == isJust cId = resubscribeSMPSession c tSess
|
||||
| otherwise = void $ subscribeQueues c True qs
|
||||
| otherwise = do
|
||||
mapM_ (runExceptT . resubscribeClientService c tSess) serviceSub_
|
||||
unless (null qs) $ void $ subscribeQueues c True qs
|
||||
runReaderT resubscribe env
|
||||
|
||||
resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' ()
|
||||
@@ -807,11 +812,12 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do
|
||||
runSubWorker = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryForeground ri isForeground (isNetworkOnline c) $ \_ loop -> do
|
||||
pending <- atomically $ SS.getPendingSubs tSess $ currentSubs c
|
||||
unless (M.null pending) $ do
|
||||
(pendingSubs, pendingSS) <- atomically $ SS.getPendingSubs tSess $ currentSubs c
|
||||
unless (M.null pendingSubs && isNothing pendingSS) $ do
|
||||
liftIO $ waitUntilForeground c
|
||||
liftIO $ waitForUserNetwork c
|
||||
handleNotify $ resubscribeSessQueues c tSess $ M.elems pending
|
||||
mapM_ (handleNotify . void . runExceptT . resubscribeClientService c tSess) pendingSS
|
||||
unless (M.null pendingSubs) $ handleNotify $ resubscribeSessQueues c tSess $ M.elems pendingSubs
|
||||
loop
|
||||
isForeground = (ASForeground ==) <$> readTVar (agentState c)
|
||||
cleanup :: SessionVar (Async ()) -> STM ()
|
||||
@@ -1508,25 +1514,25 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl
|
||||
newErr :: String -> AM (Maybe ShortLinkCreds)
|
||||
newErr = throwE . BROKER (B.unpack $ strEncode srv) . UNEXPECTED . ("Create queue: " <>)
|
||||
|
||||
processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM [(RcvQueueSub, Maybe ClientNotice)]
|
||||
processSubResults c tSess@(userId, srv, _) sessId rs = do
|
||||
pendingSubs <- SS.getPendingSubs tSess $ currentSubs c
|
||||
let (failed, subscribed, notices, ignored) = foldr (partitionResults pendingSubs) (M.empty, [], [], 0) rs
|
||||
processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> Maybe ServiceId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM ([RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)])
|
||||
processSubResults c tSess@(userId, srv, _) sessId smpServiceId rs = do
|
||||
pending <- SS.getPendingSubs tSess $ currentSubs c
|
||||
let (failed, subscribed@(qs, sQs), notices, ignored) = foldr (partitionResults pending) (M.empty, ([], []), [], 0) rs
|
||||
unless (M.null failed) $ do
|
||||
incSMPServerStat' c userId srv connSubErrs $ M.size failed
|
||||
failSubscriptions c tSess failed
|
||||
unless (null subscribed) $ do
|
||||
incSMPServerStat' c userId srv connSubscribed $ length subscribed
|
||||
unless (null qs && null sQs) $ do
|
||||
incSMPServerStat' c userId srv connSubscribed $ length qs + length sQs
|
||||
SS.batchAddActiveSubs tSess sessId subscribed $ currentSubs c
|
||||
unless (ignored == 0) $ incSMPServerStat' c userId srv connSubIgnored ignored
|
||||
pure notices
|
||||
pure (sQs, notices)
|
||||
where
|
||||
partitionResults ::
|
||||
Map SMP.RecipientId RcvQueueSub ->
|
||||
(Map SMP.RecipientId RcvQueueSub, Maybe ServiceSub) ->
|
||||
(RcvQueueSub, Either SMPClientError (Maybe ServiceId)) ->
|
||||
(Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int) ->
|
||||
(Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int)
|
||||
partitionResults pendingSubs (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed, notices, ignored) = case r of
|
||||
(Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int) ->
|
||||
(Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int)
|
||||
partitionResults (pendingSubs, pendingSS) (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed@(qs, sQs), notices, ignored) = case r of
|
||||
Left e -> case smpErrorClientNotice e of
|
||||
Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored)
|
||||
where
|
||||
@@ -1536,8 +1542,12 @@ processSubResults c tSess@(userId, srv, _) sessId rs = do
|
||||
| otherwise -> (failed', subscribed, notices, ignored)
|
||||
where
|
||||
failed' = M.insert rcvId e failed
|
||||
Right _serviceId -- TODO [certs rcv] store association with the service
|
||||
| rcvId `M.member` pendingSubs -> (failed, rq : subscribed, notices', ignored)
|
||||
Right serviceId_
|
||||
| rcvId `M.member` pendingSubs ->
|
||||
let subscribed' = case (smpServiceId, serviceId_, pendingSS) of
|
||||
(Just sId, Just sId', Just ServiceSub {serviceId}) | sId == sId' && sId == serviceId -> (qs, rq : sQs)
|
||||
_ -> (rq : qs, sQs)
|
||||
in (failed, subscribed', notices', ignored)
|
||||
| otherwise -> (failed, subscribed, notices', ignored + 1)
|
||||
where
|
||||
notices' = if isJust clientNoticeId then (rq, Nothing) : notices else notices
|
||||
@@ -1576,6 +1586,7 @@ serverHostError = \case
|
||||
|
||||
-- | Batch by transport session and subscribe queues. The list of results can have a different order.
|
||||
subscribeQueues :: AgentClient -> Bool -> [RcvQueueSub] -> AM' [(RcvQueueSub, Either AgentErrorType (Maybe ServiceId))]
|
||||
subscribeQueues _ _ [] = pure []
|
||||
subscribeQueues c withEvents qs = do
|
||||
(errs, qs') <- checkQueues c qs
|
||||
atomically $ modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs'))
|
||||
@@ -1632,6 +1643,7 @@ checkQueues c = fmap partitionEithers . mapM checkQueue
|
||||
-- This function expects that all queues belong to one transport session,
|
||||
-- and that they are already added to pending subscriptions.
|
||||
resubscribeSessQueues :: AgentClient -> SMPTransportSession -> [RcvQueueSub] -> AM' ()
|
||||
resubscribeSessQueues _ _ [] = pure ()
|
||||
resubscribeSessQueues c tSess qs = do
|
||||
(errs, qs_) <- checkQueues c qs
|
||||
forM_ (L.nonEmpty qs_) $ \qs' -> void $ subscribeSessQueues_ c True (tSess, qs')
|
||||
@@ -1650,13 +1662,15 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c
|
||||
then Just . S.fromList . map qConnId . M.elems <$> atomically (SS.getActiveSubs tSess $ currentSubs c)
|
||||
else pure Nothing
|
||||
active <- E.uninterruptibleMask_ $ do
|
||||
(active, notices) <- atomically $ do
|
||||
r@(_, notices) <- ifM
|
||||
(active, (serviceQs, notices)) <- atomically $ do
|
||||
r@(_, (_, notices)) <- ifM
|
||||
(activeClientSession c tSess sessId)
|
||||
((True,) <$> processSubResults c tSess sessId rs)
|
||||
((False, []) <$ incSMPServerStat' c userId srv connSubIgnored (length rs))
|
||||
((True,) <$> processSubResults c tSess sessId smpServiceId rs)
|
||||
((False, ([], [])) <$ incSMPServerStat' c userId srv connSubIgnored (length rs))
|
||||
unless (null notices) $ takeTMVar $ clientNoticesLock c
|
||||
pure r
|
||||
unless (null serviceQs) $ void $
|
||||
processRcvServiceAssocs c serviceQs `runReaderT` agentEnv c
|
||||
unless (null notices) $ void $
|
||||
(processClientNotices c tSess notices `runReaderT` agentEnv c)
|
||||
`E.finally` atomically (putTMVar (clientNoticesLock c) ())
|
||||
@@ -1677,6 +1691,13 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c
|
||||
where
|
||||
tSess = transportSession' smp
|
||||
sessId = sessionId $ thParams smp
|
||||
smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp
|
||||
|
||||
processRcvServiceAssocs :: AgentClient -> [RcvQueueSub] -> AM' ()
|
||||
processRcvServiceAssocs c serviceQs =
|
||||
withStore' c (`setRcvServiceAssocs` serviceQs) `catchAllErrors'` \e -> do
|
||||
logError $ "processClientNotices error: " <> tshow e
|
||||
notifySub' c "" $ ERR e
|
||||
|
||||
processClientNotices :: AgentClient -> SMPTransportSession -> [(RcvQueueSub, Maybe ClientNotice)] -> AM' ()
|
||||
processClientNotices c@AgentClient {presetServers} tSess notices = do
|
||||
@@ -1689,10 +1710,35 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do
|
||||
logError $ "processClientNotices error: " <> tshow e
|
||||
notifySub' c "" $ ERR e
|
||||
|
||||
subscribeClientService :: AgentClient -> UserId -> SMPServer -> AM (Int64, IdsHash)
|
||||
subscribeClientService c userId srv =
|
||||
withLogClient c NRMBackground (userId, srv, Nothing) B.empty "SUBS" $
|
||||
(`subscribeService` SMP.SRecipientService) . connectedClient
|
||||
resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSub
|
||||
resubscribeClientService c tSess (ServiceSub _ n idsHash) =
|
||||
withServiceClient c tSess $ \smp _ -> do
|
||||
subscribeClientService_ c tSess smp n idsHash
|
||||
|
||||
subscribeClientService :: AgentClient -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSub
|
||||
subscribeClientService c userId srv n idsHash =
|
||||
withServiceClient c tSess $ \smp smpServiceId -> do
|
||||
let serviceSub = ServiceSub smpServiceId n idsHash
|
||||
atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c
|
||||
subscribeClientService_ c tSess smp n idsHash
|
||||
where
|
||||
tSess = (userId, srv, Nothing)
|
||||
|
||||
withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a
|
||||
withServiceClient c tSess action =
|
||||
withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) ->
|
||||
case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of
|
||||
Just smpServiceId -> action smp smpServiceId
|
||||
Nothing -> throwE PCEServiceUnavailable
|
||||
|
||||
subscribeClientService_ :: AgentClient -> SMPTransportSession -> SMPClient -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub
|
||||
subscribeClientService_ c tSess smp n idsHash = do
|
||||
-- TODO [certs rcv] handle error
|
||||
serviceSub' <- subscribeService smp SMP.SRecipientService n idsHash
|
||||
let sessId = sessionId $ thParams smp
|
||||
atomically $ whenM (activeClientSession c tSess sessId) $
|
||||
SS.setActiveServiceSub tSess sessId serviceSub' $ currentSubs c
|
||||
pure serviceSub'
|
||||
|
||||
activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool
|
||||
activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c)
|
||||
@@ -1762,7 +1808,7 @@ addNewQueueSubscription c rq' tSess sessId = do
|
||||
modifyTVar' (subscrConns c) $ S.insert $ qConnId rq
|
||||
active <- activeClientSession c tSess sessId
|
||||
if active
|
||||
then SS.addActiveSub tSess sessId rq $ currentSubs c
|
||||
then SS.addActiveSub tSess sessId rq' $ currentSubs c
|
||||
else SS.addPendingSub tSess rq $ currentSubs c
|
||||
pure active
|
||||
unless same $ resubscribeSMPSession c tSess
|
||||
@@ -1951,6 +1997,7 @@ releaseGetLock c rq =
|
||||
{-# INLINE releaseGetLock #-}
|
||||
|
||||
releaseGetLocksIO :: SomeRcvQueue q => AgentClient -> [q] -> IO ()
|
||||
releaseGetLocksIO _ [] = pure ()
|
||||
releaseGetLocksIO c rqs = do
|
||||
locks <- readTVarIO $ getMsgLocks c
|
||||
forM_ rqs $ \rq ->
|
||||
@@ -2301,7 +2348,8 @@ withStore c action = do
|
||||
[ E.Handler $ \(e :: SQL.SQLError) ->
|
||||
let se = SQL.sqlError e
|
||||
busy = se == SQL.ErrorBusy || se == SQL.ErrorLocked
|
||||
in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ bshow se,
|
||||
err = tshow se <> ": " <> SQL.sqlErrorDetails e <> ", " <> SQL.sqlErrorContext e
|
||||
in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ encodeUtf8 err,
|
||||
E.Handler $ \(E.SomeException e) -> pure . Left $ SEInternal $ bshow e
|
||||
]
|
||||
#endif
|
||||
|
||||
@@ -314,7 +314,7 @@ runNtfWorker c srv Worker {doWork} =
|
||||
_ -> ((ntfSubConnId sub, INTERNAL "NSACheck - no subscription ID") : errs, subs, subIds)
|
||||
updateSub :: DB.Connection -> NtfServer -> UTCTime -> UTCTime -> (NtfSubscription, NtfSubStatus) -> IO (Maybe SMPServer)
|
||||
updateSub db ntfServer ts nextCheckTs (sub, status)
|
||||
| ntfShouldSubscribe status =
|
||||
| status `elem` subscribeNtfStatuses =
|
||||
let sub' = sub {ntfSubStatus = NASCreated status}
|
||||
in Nothing <$ updateNtfSubscription db sub' (NSANtf NSACheck) nextCheckTs
|
||||
-- ntf server stopped subscribing to this queue
|
||||
|
||||
@@ -53,6 +53,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
|
||||
getSubscriptionServers,
|
||||
getUserServerRcvQueueSubs,
|
||||
unsetQueuesToSubscribe,
|
||||
setRcvServiceAssocs,
|
||||
getConnIds,
|
||||
getConn,
|
||||
getDeletedConn,
|
||||
@@ -401,29 +402,31 @@ deleteUsersWithoutConns db = do
|
||||
pure userIds
|
||||
|
||||
createClientService :: DB.Connection -> UserId -> SMPServer -> (C.KeyHash, TLS.Credential) -> IO ()
|
||||
createClientService db userId srv (kh, (cert, pk)) =
|
||||
createClientService db userId srv (kh, (cert, pk)) = do
|
||||
serverKeyHash_ <- createServer_ db srv
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO client_services
|
||||
(user_id, host, port, service_cert_hash, service_cert, service_priv_key)
|
||||
VALUES (?,?,?,?,?,?)
|
||||
ON CONFLICT (user_id, host, port)
|
||||
(user_id, host, port, server_key_hash, service_cert_hash, service_cert, service_priv_key)
|
||||
VALUES (?,?,?,?,?,?,?)
|
||||
ON CONFLICT (user_id, host, port, server_key_hash)
|
||||
DO UPDATE SET
|
||||
service_cert_hash = EXCLUDED.service_cert_hash,
|
||||
service_cert = EXCLUDED.service_cert,
|
||||
service_priv_key = EXCLUDED.service_priv_key,
|
||||
rcv_service_id = NULL
|
||||
service_id = NULL
|
||||
|]
|
||||
(userId, host srv, port srv, kh, cert, pk)
|
||||
(userId, host srv, port srv, serverKeyHash_, kh, cert, pk)
|
||||
|
||||
-- TODO [certs rcv] get correct service based on key hash of the server
|
||||
getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId))
|
||||
getClientService db userId srv =
|
||||
maybeFirstRow toService $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT service_cert_hash, service_cert, service_priv_key, rcv_service_id
|
||||
SELECT service_cert_hash, service_cert, service_priv_key, service_id
|
||||
FROM client_services
|
||||
WHERE user_id = ? AND host = ? AND port = ?
|
||||
|]
|
||||
@@ -431,19 +434,21 @@ getClientService db userId srv =
|
||||
where
|
||||
toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_)
|
||||
|
||||
getClientServiceServers :: DB.Connection -> UserId -> IO [SMPServer]
|
||||
getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)]
|
||||
getClientServiceServers db userId =
|
||||
map toServer
|
||||
<$> DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT c.host, c.port, s.key_hash
|
||||
SELECT c.host, c.port, s.key_hash, c.service_id, c.service_queue_count, c.service_queue_ids_hash
|
||||
FROM client_services c
|
||||
JOIN servers s ON s.host = c.host AND s.port = c.port
|
||||
WHERE c.user_id = ?
|
||||
|]
|
||||
(Only userId)
|
||||
where
|
||||
toServer (host, port, kh) = SMPServer host port kh
|
||||
toServer (host, port, kh, serviceId, n, Binary idsHash) =
|
||||
(SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash))
|
||||
|
||||
setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO ()
|
||||
setClientServiceId db userId srv serviceId =
|
||||
@@ -451,7 +456,7 @@ setClientServiceId db userId srv serviceId =
|
||||
db
|
||||
[sql|
|
||||
UPDATE client_services
|
||||
SET rcv_service_id = ?
|
||||
SET service_id = ?
|
||||
WHERE user_id = ? AND host = ? AND port = ?
|
||||
|]
|
||||
(serviceId, userId, host srv, port srv)
|
||||
@@ -2099,7 +2104,7 @@ insertRcvQueue_ db connId' rq@RcvQueue {..} subMode serverKeyHash_ = do
|
||||
ntf_public_key, ntf_private_key, ntf_id, rcv_ntf_dh_secret
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
( (host server, port server, rcvId, rcvServiceAssoc, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret)
|
||||
( (host server, port server, rcvId, BI rcvServiceAssoc, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret)
|
||||
:. (sndId, queueMode, status, BI toSubscribe, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)
|
||||
:. (shortLinkId <$> shortLink, shortLinkKey <$> shortLink, linkPrivSigKey <$> shortLink, linkEncFixedData <$> shortLink)
|
||||
:. ntfCredsFields
|
||||
@@ -2248,6 +2253,14 @@ getUserServerRcvQueueSubs db userId srv onlyNeeded =
|
||||
unsetQueuesToSubscribe :: DB.Connection -> IO ()
|
||||
unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1"
|
||||
|
||||
setRcvServiceAssocs :: DB.Connection -> [RcvQueueSub] -> IO ()
|
||||
setRcvServiceAssocs db rqs =
|
||||
#if defined(dbPostgres)
|
||||
DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs)
|
||||
#else
|
||||
DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = " $ map (Only . queueId) rqs
|
||||
#endif
|
||||
|
||||
-- * getConn helpers
|
||||
|
||||
getConnIds :: DB.Connection -> IO [ConnId]
|
||||
@@ -2468,13 +2481,13 @@ rcvQueueQuery =
|
||||
|
||||
toRcvQueue ::
|
||||
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe QueueMode)
|
||||
:. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int, ServiceAssoc)
|
||||
:. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int, BoolInt)
|
||||
:. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret)
|
||||
:. (Maybe SMP.LinkId, Maybe LinkKey, Maybe C.PrivateKeyEd25519, Maybe EncDataBytes) ->
|
||||
RcvQueue
|
||||
toRcvQueue
|
||||
( (userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode)
|
||||
:. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors, rcvServiceAssoc)
|
||||
:. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors, BI rcvServiceAssoc)
|
||||
:. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)
|
||||
:. (shortLinkId_, shortLinkKey_, linkPrivSigKey_, linkEncFixedData_)
|
||||
) =
|
||||
|
||||
@@ -10,6 +10,7 @@ import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250322_short_links
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250702_conn_invitations_remove_cascade_delete
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251009_queue_to_subscribe
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251010_client_notices
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251020_service_certs
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..))
|
||||
|
||||
schemaMigrations :: [(String, Text, Maybe Text)]
|
||||
@@ -19,7 +20,8 @@ schemaMigrations =
|
||||
("20250322_short_links", m20250322_short_links, Just down_m20250322_short_links),
|
||||
("20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete),
|
||||
("20251009_queue_to_subscribe", m20251009_queue_to_subscribe, Just down_m20251009_queue_to_subscribe),
|
||||
("20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices)
|
||||
("20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices),
|
||||
("20251020_service_certs", m20251020_service_certs, Just down_m20251020_service_certs)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251020_service_certs where
|
||||
|
||||
import Data.Text (Text)
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util
|
||||
import Text.RawString.QQ (r)
|
||||
|
||||
m20251020_service_certs :: Text
|
||||
m20251020_service_certs =
|
||||
createXorHashFuncs <> [r|
|
||||
CREATE TABLE client_services(
|
||||
user_id BIGINT NOT NULL REFERENCES users ON UPDATE RESTRICT ON DELETE CASCADE,
|
||||
host TEXT NOT NULL,
|
||||
port TEXT NOT NULL,
|
||||
server_key_hash BYTEA,
|
||||
service_cert BYTEA NOT NULL,
|
||||
service_cert_hash BYTEA NOT NULL,
|
||||
service_priv_key BYTEA NOT NULL,
|
||||
service_id BYTEA,
|
||||
service_queue_count BIGINT NOT NULL DEFAULT 0,
|
||||
service_queue_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000',
|
||||
FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port, server_key_hash);
|
||||
CREATE INDEX idx_server_certs_host_port ON client_services(host, port);
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN rcv_service_assoc SMALLINT NOT NULL DEFAULT 0;
|
||||
|
||||
CREATE FUNCTION update_aggregates(p_conn_id BYTEA, p_host TEXT, p_port TEXT, p_change BIGINT, p_rcv_id BYTEA) RETURNS VOID
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE q_user_id BIGINT;
|
||||
BEGIN
|
||||
SELECT user_id INTO q_user_id FROM connections WHERE conn_id = p_conn_id;
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count + p_change,
|
||||
service_queue_ids_hash = xor_combine(service_queue_ids_hash, public.digest(p_rcv_id, 'md5'))
|
||||
WHERE user_id = q_user_id AND host = p_host AND port = p_port;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_rcv_queue_insert() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN
|
||||
PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_rcv_queue_delete() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN
|
||||
PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id);
|
||||
END IF;
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_rcv_queue_update() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN
|
||||
IF NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0) THEN
|
||||
PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id);
|
||||
END IF;
|
||||
ELSIF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN
|
||||
PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE TRIGGER tr_rcv_queue_insert
|
||||
AFTER INSERT ON rcv_queues
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_insert();
|
||||
|
||||
CREATE TRIGGER tr_rcv_queue_delete
|
||||
AFTER DELETE ON rcv_queues
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_delete();
|
||||
|
||||
CREATE TRIGGER tr_rcv_queue_update
|
||||
AFTER UPDATE ON rcv_queues
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_update();
|
||||
|]
|
||||
|
||||
down_m20251020_service_certs :: Text
|
||||
down_m20251020_service_certs =
|
||||
[r|
|
||||
DROP TRIGGER tr_rcv_queue_insert ON rcv_queues;
|
||||
DROP TRIGGER tr_rcv_queue_delete ON rcv_queues;
|
||||
DROP TRIGGER tr_rcv_queue_update ON rcv_queues;
|
||||
|
||||
DROP FUNCTION on_rcv_queue_insert;
|
||||
DROP FUNCTION on_rcv_queue_delete;
|
||||
DROP FUNCTION on_rcv_queue_update;
|
||||
|
||||
DROP FUNCTION update_aggregates;
|
||||
|
||||
ALTER TABLE rcv_queues DROP COLUMN rcv_service_assoc;
|
||||
|
||||
DROP INDEX idx_server_certs_host_port;
|
||||
DROP INDEX idx_server_certs_user_id_host_port;
|
||||
DROP TABLE client_services;
|
||||
|]
|
||||
<> dropXorHashFuncs
|
||||
@@ -0,0 +1,46 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.Postgres.Migrations.Util where
|
||||
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Text.RawString.QQ (r)
|
||||
|
||||
-- xor_combine is only applied to locally computed md5 hashes (128 bits/16 bytes),
|
||||
-- so it is safe to require that all values are of the same length.
|
||||
createXorHashFuncs :: Text
|
||||
createXorHashFuncs =
|
||||
T.pack
|
||||
[r|
|
||||
CREATE OR REPLACE FUNCTION xor_combine(state BYTEA, value BYTEA) RETURNS BYTEA
|
||||
LANGUAGE plpgsql IMMUTABLE STRICT
|
||||
AS $$
|
||||
DECLARE
|
||||
result BYTEA := state;
|
||||
i INTEGER;
|
||||
len INTEGER := octet_length(value);
|
||||
BEGIN
|
||||
IF octet_length(state) != len THEN
|
||||
RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len;
|
||||
END IF;
|
||||
FOR i IN 0..len-1 LOOP
|
||||
result := set_byte(result, i, get_byte(state, i) # get_byte(value, i));
|
||||
END LOOP;
|
||||
RETURN result;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE AGGREGATE xor_aggregate(BYTEA) (
|
||||
SFUNC = xor_combine,
|
||||
STYPE = BYTEA,
|
||||
INITCOND = '\x00000000000000000000000000000000' -- 16 bytes
|
||||
);
|
||||
|]
|
||||
|
||||
dropXorHashFuncs :: Text
|
||||
dropXorHashFuncs =
|
||||
T.pack
|
||||
[r|
|
||||
DROP AGGREGATE xor_aggregate(BYTEA);
|
||||
DROP FUNCTION xor_combine;
|
||||
|]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,30 +21,32 @@ import Database.PostgreSQL.Simple.SqlQQ (sql)
|
||||
createDBAndUserIfNotExists :: ConnectInfo -> IO ()
|
||||
createDBAndUserIfNotExists ConnectInfo {connectUser = user, connectDatabase = dbName} = do
|
||||
-- connect to the default "postgres" maintenance database
|
||||
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $
|
||||
\postgresDB -> do
|
||||
void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING"
|
||||
-- check if the user exists, create if not
|
||||
[Only userExists] <-
|
||||
PSQL.query
|
||||
postgresDB
|
||||
[sql|
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM pg_catalog.pg_roles
|
||||
WHERE rolname = ?
|
||||
)
|
||||
|]
|
||||
(Only user)
|
||||
unless userExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE USER " <> user)
|
||||
-- check if the database exists, create if not
|
||||
dbExists <- checkDBExists postgresDB dbName
|
||||
unless dbExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE DATABASE " <> dbName <> " OWNER " <> user)
|
||||
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ \db -> do
|
||||
execSQL db "SET client_min_messages TO WARNING"
|
||||
-- check if the user exists, create if not
|
||||
[Only userExists] <-
|
||||
PSQL.query
|
||||
db
|
||||
[sql|
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM pg_catalog.pg_roles
|
||||
WHERE rolname = ?
|
||||
)
|
||||
|]
|
||||
(Only user)
|
||||
unless userExists $ execSQL db $ "CREATE USER " <> user
|
||||
-- check if the database exists, create if not
|
||||
dbExists <- checkDBExists db dbName
|
||||
unless dbExists $ do
|
||||
execSQL db $ "CREATE DATABASE " <> dbName <> " OWNER " <> user
|
||||
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = dbName}) PSQL.close $
|
||||
(`execSQL` "CREATE EXTENSION IF NOT EXISTS pgcrypto")
|
||||
|
||||
checkDBExists :: PSQL.Connection -> String -> IO Bool
|
||||
checkDBExists postgresDB dbName = do
|
||||
checkDBExists db dbName = do
|
||||
[Only dbExists] <-
|
||||
PSQL.query
|
||||
postgresDB
|
||||
db
|
||||
[sql|
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM pg_catalog.pg_database
|
||||
@@ -56,45 +58,45 @@ checkDBExists postgresDB dbName = do
|
||||
|
||||
dropSchema :: ConnectInfo -> String -> IO ()
|
||||
dropSchema connectInfo schema =
|
||||
bracket (PSQL.connect connectInfo) PSQL.close $
|
||||
\db -> do
|
||||
void $ PSQL.execute_ db "SET client_min_messages TO WARNING"
|
||||
void $ PSQL.execute_ db (fromString $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE")
|
||||
bracket (PSQL.connect connectInfo) PSQL.close $ \db -> do
|
||||
execSQL db "SET client_min_messages TO WARNING"
|
||||
execSQL db $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE"
|
||||
|
||||
dropAllSchemasExceptSystem :: ConnectInfo -> IO ()
|
||||
dropAllSchemasExceptSystem connectInfo =
|
||||
bracket (PSQL.connect connectInfo) PSQL.close $
|
||||
\db -> do
|
||||
void $ PSQL.execute_ db "SET client_min_messages TO WARNING"
|
||||
schemaNames :: [Only String] <-
|
||||
PSQL.query_
|
||||
db
|
||||
[sql|
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema')
|
||||
|]
|
||||
forM_ schemaNames $ \(Only schema) ->
|
||||
PSQL.execute_ db (fromString $ "DROP SCHEMA " <> schema <> " CASCADE")
|
||||
bracket (PSQL.connect connectInfo) PSQL.close $ \db -> do
|
||||
execSQL db "SET client_min_messages TO WARNING"
|
||||
schemaNames :: [Only String] <-
|
||||
PSQL.query_
|
||||
db
|
||||
[sql|
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema')
|
||||
|]
|
||||
forM_ schemaNames $ \(Only schema) ->
|
||||
execSQL db $ "DROP SCHEMA " <> schema <> " CASCADE"
|
||||
|
||||
dropDatabaseAndUser :: ConnectInfo -> IO ()
|
||||
dropDatabaseAndUser ConnectInfo {connectUser = user, connectDatabase = dbName} =
|
||||
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $
|
||||
\postgresDB -> do
|
||||
void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING"
|
||||
dbExists <- checkDBExists postgresDB dbName
|
||||
when dbExists $ do
|
||||
void $ PSQL.execute_ postgresDB (fromString $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false")
|
||||
-- terminate all connections to the database
|
||||
_r :: [Only Bool] <-
|
||||
PSQL.query
|
||||
postgresDB
|
||||
[sql|
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE datname = ?
|
||||
AND pid <> pg_backend_pid()
|
||||
|]
|
||||
(Only dbName)
|
||||
void $ PSQL.execute_ postgresDB (fromString $ "DROP DATABASE " <> dbName)
|
||||
void $ PSQL.execute_ postgresDB (fromString $ "DROP USER IF EXISTS " <> user)
|
||||
bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ \db -> do
|
||||
execSQL db "SET client_min_messages TO WARNING"
|
||||
dbExists <- checkDBExists db dbName
|
||||
when dbExists $ do
|
||||
execSQL db $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false"
|
||||
-- terminate all connections to the database
|
||||
_r :: [Only Bool] <-
|
||||
PSQL.query
|
||||
db
|
||||
[sql|
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE datname = ?
|
||||
AND pid <> pg_backend_pid()
|
||||
|]
|
||||
(Only dbName)
|
||||
execSQL db $ "DROP DATABASE " <> dbName
|
||||
execSQL db $ "DROP USER IF EXISTS " <> user
|
||||
|
||||
execSQL :: PSQL.Connection -> String -> IO ()
|
||||
execSQL db = void . PSQL.execute_ db . fromString
|
||||
|
||||
@@ -42,9 +42,15 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.MVar
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (bracketOnError, onException, throwIO)
|
||||
import Control.Monad
|
||||
import Data.Bits (xor)
|
||||
import Data.ByteArray (ScrubbedBytes)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
import Data.Functor (($>))
|
||||
import Data.IORef
|
||||
import Data.Maybe (fromMaybe)
|
||||
@@ -54,17 +60,19 @@ import Database.SQLite.Simple (Query (..))
|
||||
import qualified Database.SQLite.Simple as SQL
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import qualified Database.SQLite3 as SQLite3
|
||||
import Database.SQLite3.Bindings
|
||||
import Foreign.C.Types
|
||||
import Foreign.Ptr
|
||||
import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSchema)
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..))
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Util (SQLiteFunc, createStaticFunction, mkSQLiteFunc)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Util (ifM, safeDecodeUtf8)
|
||||
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
|
||||
import System.FilePath (takeDirectory, takeFileName, (</>))
|
||||
import UnliftIO.Exception (bracketOnError, onException)
|
||||
import UnliftIO.MVar
|
||||
import UnliftIO.STM
|
||||
|
||||
-- * SQLite Store implementation
|
||||
|
||||
@@ -109,9 +117,9 @@ connectDB path key track = do
|
||||
pure db
|
||||
where
|
||||
prepare db = do
|
||||
let exec = SQLite3.exec $ SQL.connectionHandle $ DB.conn db
|
||||
unless (BA.null key) . exec $ "PRAGMA key = " <> keyString key <> ";"
|
||||
exec . fromQuery $
|
||||
let db' = SQL.connectionHandle $ DB.conn db
|
||||
unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";"
|
||||
SQLite3.exec db' . fromQuery $
|
||||
[sql|
|
||||
PRAGMA busy_timeout = 100;
|
||||
PRAGMA foreign_keys = ON;
|
||||
@@ -119,6 +127,21 @@ connectDB path key track = do
|
||||
PRAGMA secure_delete = ON;
|
||||
PRAGMA auto_vacuum = FULL;
|
||||
|]
|
||||
createStaticFunction db' "simplex_xor_md5_combine" 2 True sqliteXorMd5CombinePtr
|
||||
>>= either (throwIO . userError . show) pure
|
||||
|
||||
foreign export ccall "simplex_xor_md5_combine" sqliteXorMd5Combine :: SQLiteFunc
|
||||
|
||||
foreign import ccall "&simplex_xor_md5_combine" sqliteXorMd5CombinePtr :: FunPtr SQLiteFunc
|
||||
|
||||
sqliteXorMd5Combine :: SQLiteFunc
|
||||
sqliteXorMd5Combine = mkSQLiteFunc $ \cxt args -> do
|
||||
idsHash <- SQLite3.funcArgBlob args 0
|
||||
rId <- SQLite3.funcArgBlob args 1
|
||||
SQLite3.funcResultBlob cxt $ xorMd5Combine idsHash rId
|
||||
|
||||
xorMd5Combine :: ByteString -> ByteString -> ByteString
|
||||
xorMd5Combine idsHash rId = B.packZipWith xor idsHash $ C.md5Hash rId
|
||||
|
||||
closeDBStore :: DBStore -> IO ()
|
||||
closeDBStore st@DBStore {dbClosed} =
|
||||
|
||||
@@ -53,6 +53,12 @@ withConnectionPriority DBStore {dbSem, dbConnection} priority action
|
||||
| priority = E.bracket_ signal release $ withMVar dbConnection action
|
||||
| otherwise = lowPriority
|
||||
where
|
||||
-- To debug FK errors, set foreign_keys = OFF in Simplex.Messaging.Agent.Store.SQLite and use action' instead of action
|
||||
-- action' conn = do
|
||||
-- r <- action conn
|
||||
-- violations <- DB.query_ conn "PRAGMA foreign_key_check" :: IO [ (String, Int, String, Int)]
|
||||
-- unless (null violations) $ print violations
|
||||
-- pure r
|
||||
lowPriority = wait >> withMVar dbConnection (\db -> ifM free (Just <$> action db) (pure Nothing)) >>= maybe lowPriority pure
|
||||
signal = atomically $ modifyTVar' dbSem (+ 1)
|
||||
release = atomically $ modifyTVar' dbSem $ \sem -> if sem > 0 then sem - 1 else 0
|
||||
|
||||
@@ -5,7 +5,6 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251020_service_certs w
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
-- TODO move date forward, create migration for postgres
|
||||
m20251020_service_certs :: Query
|
||||
m20251020_service_certs =
|
||||
[sql|
|
||||
@@ -13,27 +12,81 @@ CREATE TABLE client_services(
|
||||
user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
|
||||
host TEXT NOT NULL,
|
||||
port TEXT NOT NULL,
|
||||
server_key_hash BLOB,
|
||||
service_cert BLOB NOT NULL,
|
||||
service_cert_hash BLOB NOT NULL,
|
||||
service_priv_key BLOB NOT NULL,
|
||||
rcv_service_id BLOB,
|
||||
service_id BLOB,
|
||||
service_queue_count INTEGER NOT NULL DEFAULT 0,
|
||||
service_queue_ids_hash BLOB NOT NULL DEFAULT x'00000000000000000000000000000000',
|
||||
FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port);
|
||||
|
||||
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port, server_key_hash);
|
||||
CREATE INDEX idx_server_certs_host_port ON client_services(host, port);
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN rcv_service_assoc INTEGER NOT NULL DEFAULT 0;
|
||||
|
||||
CREATE TRIGGER tr_rcv_queue_insert
|
||||
AFTER INSERT ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count + 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
|
||||
AND host = NEW.host AND port = NEW.port;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER tr_rcv_queue_delete
|
||||
AFTER DELETE ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count - 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
|
||||
AND host = OLD.host AND port = OLD.port;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER tr_rcv_queue_update_remove
|
||||
AFTER UPDATE ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 AND NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0)
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count - 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
|
||||
AND host = OLD.host AND port = OLD.port;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER tr_rcv_queue_update_add
|
||||
AFTER UPDATE ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 AND NOT (OLD.rcv_service_assoc != 0 AND OLD.deleted = 0)
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count + 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
|
||||
AND host = NEW.host AND port = NEW.port;
|
||||
END;
|
||||
|]
|
||||
|
||||
down_m20251020_service_certs :: Query
|
||||
down_m20251020_service_certs =
|
||||
[sql|
|
||||
DROP TRIGGER tr_rcv_queue_insert;
|
||||
DROP TRIGGER tr_rcv_queue_delete;
|
||||
DROP TRIGGER tr_rcv_queue_update_remove;
|
||||
DROP TRIGGER tr_rcv_queue_update_add;
|
||||
|
||||
ALTER TABLE rcv_queues DROP COLUMN rcv_service_assoc;
|
||||
|
||||
DROP INDEX idx_server_certs_host_port;
|
||||
|
||||
DROP INDEX idx_server_certs_user_id_host_port;
|
||||
|
||||
DROP TABLE client_services;
|
||||
|
||||
@@ -455,10 +455,13 @@ CREATE TABLE client_services(
|
||||
user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
|
||||
host TEXT NOT NULL,
|
||||
port TEXT NOT NULL,
|
||||
server_key_hash BLOB,
|
||||
service_cert BLOB NOT NULL,
|
||||
service_cert_hash BLOB NOT NULL,
|
||||
service_priv_key BLOB NOT NULL,
|
||||
rcv_service_id BLOB,
|
||||
service_id BLOB,
|
||||
service_queue_count INTEGER NOT NULL DEFAULT 0,
|
||||
service_queue_ids_hash BLOB NOT NULL DEFAULT x'00000000000000000000000000000000',
|
||||
FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id);
|
||||
@@ -607,6 +610,51 @@ CREATE INDEX idx_rcv_queues_client_notice_id ON rcv_queues(client_notice_id);
|
||||
CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(
|
||||
user_id,
|
||||
host,
|
||||
port
|
||||
port,
|
||||
server_key_hash
|
||||
);
|
||||
CREATE INDEX idx_server_certs_host_port ON client_services(host, port);
|
||||
CREATE TRIGGER tr_rcv_queue_insert
|
||||
AFTER INSERT ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count + 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
|
||||
AND host = NEW.host AND port = NEW.port;
|
||||
END;
|
||||
CREATE TRIGGER tr_rcv_queue_delete
|
||||
AFTER DELETE ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count - 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
|
||||
AND host = OLD.host AND port = OLD.port;
|
||||
END;
|
||||
CREATE TRIGGER tr_rcv_queue_update_remove
|
||||
AFTER UPDATE ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 AND NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0)
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count - 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id)
|
||||
AND host = OLD.host AND port = OLD.port;
|
||||
END;
|
||||
CREATE TRIGGER tr_rcv_queue_update_add
|
||||
AFTER UPDATE ON rcv_queues
|
||||
FOR EACH ROW
|
||||
WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 AND NOT (OLD.rcv_service_assoc != 0 AND OLD.deleted = 0)
|
||||
BEGIN
|
||||
UPDATE client_services
|
||||
SET service_queue_count = service_queue_count + 1,
|
||||
service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id)
|
||||
WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id)
|
||||
AND host = NEW.host AND port = NEW.port;
|
||||
END;
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Util where
|
||||
|
||||
import Control.Exception (SomeException, catch, mask_)
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..))
|
||||
import Database.SQLite3.Bindings
|
||||
import Foreign.C.String
|
||||
import Foreign.Ptr
|
||||
import Foreign.StablePtr
|
||||
|
||||
data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal)
|
||||
|
||||
type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO ()
|
||||
|
||||
mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc
|
||||
mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals)
|
||||
{-# INLINE mkSQLiteFunc #-}
|
||||
|
||||
-- Based on createFunction from Database.SQLite3.Direct, but uses static function pointer to avoid dynamic wrapper that triggers DCL.
|
||||
createStaticFunction :: Database -> ByteString -> CArgCount -> Bool -> FunPtr SQLiteFunc -> IO (Either Error ())
|
||||
createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do
|
||||
u <- newStablePtr $ CFuncPtrs funPtr nullFunPtr nullFunPtr
|
||||
let flags = if isDet then c_SQLITE_DETERMINISTIC else 0
|
||||
B.useAsCString name $ \namePtr ->
|
||||
toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr
|
||||
|
||||
-- Convert a 'CError' to a 'Either Error', in the common case where
|
||||
-- SQLITE_OK signals success and anything else signals an error.
|
||||
--
|
||||
-- Note that SQLITE_OK == 0.
|
||||
toResult :: a -> CError -> Either Error a
|
||||
toResult a (CError 0) = Right a
|
||||
toResult _ code = Left $ decodeError code
|
||||
|
||||
-- call c_sqlite3_result_error in the event of an error
|
||||
catchAsResultError :: Ptr CContext -> IO () -> IO ()
|
||||
catchAsResultError ctx action = catch action $ \exn -> do
|
||||
let msg = show (exn :: SomeException)
|
||||
withCAStringLen msg $ \(ptr, len) ->
|
||||
c_sqlite3_result_error ctx ptr (fromIntegral len)
|
||||
@@ -2,6 +2,7 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Agent.TSessionSubs
|
||||
( TSessionSubs (sessionSubs),
|
||||
@@ -12,7 +13,10 @@ module Simplex.Messaging.Agent.TSessionSubs
|
||||
hasPendingSub,
|
||||
addPendingSub,
|
||||
setSessionId,
|
||||
setPendingServiceSub,
|
||||
setActiveServiceSub,
|
||||
addActiveSub,
|
||||
addActiveSub',
|
||||
batchAddActiveSubs,
|
||||
batchAddPendingSubs,
|
||||
deletePendingSub,
|
||||
@@ -38,13 +42,13 @@ import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isJust)
|
||||
import qualified Data.Set as S
|
||||
import Simplex.Messaging.Agent.Protocol (SMPQueue (..))
|
||||
import Simplex.Messaging.Agent.Store (RcvQueueSub (..), SomeRcvQueue)
|
||||
import Simplex.Messaging.Agent.Store (RcvQueue, RcvQueueSub (..), SomeRcvQueue, StoredRcvQueue (rcvServiceAssoc), rcvQueueSub)
|
||||
import Simplex.Messaging.Client (SMPTransportSession, TransportSessionMode (..))
|
||||
import Simplex.Messaging.Protocol (RecipientId)
|
||||
import Simplex.Messaging.Protocol (RecipientId, ServiceSub (..), queueIdHash)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util (($>>=))
|
||||
import Simplex.Messaging.Util (anyM, ($>>=))
|
||||
|
||||
data TSessionSubs = TSessionSubs
|
||||
{ sessionSubs :: TMap SMPTransportSession SessSubs
|
||||
@@ -53,7 +57,9 @@ data TSessionSubs = TSessionSubs
|
||||
data SessSubs = SessSubs
|
||||
{ subsSessId :: TVar (Maybe SessionId),
|
||||
activeSubs :: TMap RecipientId RcvQueueSub,
|
||||
pendingSubs :: TMap RecipientId RcvQueueSub
|
||||
pendingSubs :: TMap RecipientId RcvQueueSub,
|
||||
activeServiceSub :: TVar (Maybe ServiceSub),
|
||||
pendingServiceSub :: TVar (Maybe ServiceSub)
|
||||
}
|
||||
|
||||
emptyIO :: IO TSessionSubs
|
||||
@@ -72,7 +78,7 @@ getSessSubs :: SMPTransportSession -> TSessionSubs -> STM SessSubs
|
||||
getSessSubs tSess ss = lookupSubs tSess ss >>= maybe new pure
|
||||
where
|
||||
new = do
|
||||
s <- SessSubs <$> newTVar Nothing <*> newTVar M.empty <*> newTVar M.empty
|
||||
s <- SessSubs <$> newTVar Nothing <*> newTVar M.empty <*> newTVar M.empty <*> newTVar Nothing <*> newTVar Nothing
|
||||
TM.insert tSess s $ sessionSubs ss
|
||||
pure s
|
||||
|
||||
@@ -98,8 +104,27 @@ setSessionId tSess sessId ss = do
|
||||
Nothing -> writeTVar (subsSessId s) (Just sessId)
|
||||
Just sessId' -> unless (sessId == sessId') $ void $ setSubsPending_ s $ Just sessId
|
||||
|
||||
addActiveSub :: SMPTransportSession -> SessionId -> RcvQueueSub -> TSessionSubs -> STM ()
|
||||
addActiveSub tSess sessId rq ss = do
|
||||
setPendingServiceSub :: SMPTransportSession -> ServiceSub -> TSessionSubs -> STM ()
|
||||
setPendingServiceSub tSess serviceSub ss = do
|
||||
s <- getSessSubs tSess ss
|
||||
writeTVar (pendingServiceSub s) $ Just serviceSub
|
||||
|
||||
setActiveServiceSub :: SMPTransportSession -> SessionId -> ServiceSub -> TSessionSubs -> STM ()
|
||||
setActiveServiceSub tSess sessId serviceSub ss = do
|
||||
s <- getSessSubs tSess ss
|
||||
sessId' <- readTVar $ subsSessId s
|
||||
if Just sessId == sessId'
|
||||
then do
|
||||
writeTVar (activeServiceSub s) $ Just serviceSub
|
||||
writeTVar (pendingServiceSub s) Nothing
|
||||
else writeTVar (pendingServiceSub s) $ Just serviceSub
|
||||
|
||||
addActiveSub :: SMPTransportSession -> SessionId -> RcvQueue -> TSessionSubs -> STM ()
|
||||
addActiveSub tSess sessId rq = addActiveSub' tSess sessId (rcvQueueSub rq) (rcvServiceAssoc rq)
|
||||
{-# INLINE addActiveSub #-}
|
||||
|
||||
addActiveSub' :: SMPTransportSession -> SessionId -> RcvQueueSub -> Bool -> TSessionSubs -> STM ()
|
||||
addActiveSub' tSess sessId rq serviceAssoc ss = do
|
||||
s <- getSessSubs tSess ss
|
||||
sessId' <- readTVar $ subsSessId s
|
||||
let rId = rcvId rq
|
||||
@@ -107,10 +132,13 @@ addActiveSub tSess sessId rq ss = do
|
||||
then do
|
||||
TM.insert rId rq $ activeSubs s
|
||||
TM.delete rId $ pendingSubs s
|
||||
when serviceAssoc $
|
||||
let updateServiceSub (ServiceSub serviceId n idsHash) = ServiceSub serviceId (n + 1) (idsHash <> queueIdHash rId)
|
||||
in modifyTVar' (activeServiceSub s) (updateServiceSub <$>)
|
||||
else TM.insert rId rq $ pendingSubs s
|
||||
|
||||
batchAddActiveSubs :: SMPTransportSession -> SessionId -> [RcvQueueSub] -> TSessionSubs -> STM ()
|
||||
batchAddActiveSubs tSess sessId rqs ss = do
|
||||
batchAddActiveSubs :: SMPTransportSession -> SessionId -> ([RcvQueueSub], [RcvQueueSub]) -> TSessionSubs -> STM ()
|
||||
batchAddActiveSubs tSess sessId (rqs, serviceRQs) ss = do
|
||||
s <- getSessSubs tSess ss
|
||||
sessId' <- readTVar $ subsSessId s
|
||||
let qs = M.fromList $ map (\rq -> (rcvId rq, rq)) rqs
|
||||
@@ -118,6 +146,12 @@ batchAddActiveSubs tSess sessId rqs ss = do
|
||||
then do
|
||||
TM.union qs $ activeSubs s
|
||||
modifyTVar' (pendingSubs s) (`M.difference` qs)
|
||||
serviceSub_ <- readTVar $ activeServiceSub s
|
||||
forM_ serviceSub_ $ \(ServiceSub serviceId n idsHash) -> do
|
||||
unless (null serviceRQs) $ do
|
||||
let idsHash' = idsHash <> mconcat (map (queueIdHash . rcvId) serviceRQs)
|
||||
n' = n + fromIntegral (length serviceRQs)
|
||||
writeTVar (activeServiceSub s) $ Just $ ServiceSub serviceId n' idsHash'
|
||||
else TM.union qs $ pendingSubs s
|
||||
|
||||
batchAddPendingSubs :: SMPTransportSession -> [RcvQueueSub] -> TSessionSubs -> STM ()
|
||||
@@ -143,11 +177,15 @@ batchDeleteSubs tSess rqs = lookupSubs tSess >=> mapM_ (\s -> delete (activeSubs
|
||||
delete = (`modifyTVar'` (`M.withoutKeys` rIds))
|
||||
|
||||
hasPendingSubs :: SMPTransportSession -> TSessionSubs -> STM Bool
|
||||
hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (fmap (not . null) . readTVar . pendingSubs)
|
||||
hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (\s -> anyM [hasSubs s, hasServiceSub s])
|
||||
where
|
||||
hasSubs = fmap (not . null) . readTVar . pendingSubs
|
||||
hasServiceSub = fmap isJust . readTVar . pendingServiceSub
|
||||
|
||||
getPendingSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
|
||||
getPendingSubs = getSubs_ pendingSubs
|
||||
{-# INLINE getPendingSubs #-}
|
||||
getPendingSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub)
|
||||
getPendingSubs tSess = lookupSubs tSess >=> maybe (pure (M.empty, Nothing)) get
|
||||
where
|
||||
get s = liftM2 (,) (readTVar $ pendingSubs s) (readTVar $ pendingServiceSub s)
|
||||
|
||||
getActiveSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
|
||||
getActiveSubs = getSubs_ activeSubs
|
||||
@@ -156,7 +194,7 @@ getActiveSubs = getSubs_ activeSubs
|
||||
getSubs_ :: (SessSubs -> TMap RecipientId RcvQueueSub) -> SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
|
||||
getSubs_ subs tSess = lookupSubs tSess >=> maybe (pure M.empty) (readTVar . subs)
|
||||
|
||||
setSubsPending :: TransportSessionMode -> SMPTransportSession -> SessionId -> TSessionSubs -> STM (Map RecipientId RcvQueueSub)
|
||||
setSubsPending :: TransportSessionMode -> SMPTransportSession -> SessionId -> TSessionSubs -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub)
|
||||
setSubsPending mode tSess@(uId, srv, connId_) sessId tss@(TSessionSubs ss)
|
||||
| entitySession == isJust connId_ =
|
||||
TM.lookup tSess ss >>= withSessSubs (`setSubsPending_` Nothing)
|
||||
@@ -166,17 +204,17 @@ setSubsPending mode tSess@(uId, srv, connId_) sessId tss@(TSessionSubs ss)
|
||||
entitySession = mode == TSMEntity
|
||||
sessEntId = if entitySession then Just else const Nothing
|
||||
withSessSubs run = \case
|
||||
Nothing -> pure M.empty
|
||||
Nothing -> pure (M.empty, Nothing)
|
||||
Just s -> do
|
||||
sessId' <- readTVar $ subsSessId s
|
||||
if Just sessId == sessId' then run s else pure M.empty
|
||||
if Just sessId == sessId' then run s else pure (M.empty, Nothing)
|
||||
setPendingChangeMode s = do
|
||||
subs <- M.union <$> readTVar (activeSubs s) <*> readTVar (pendingSubs s)
|
||||
unless (null subs) $
|
||||
forM_ subs $ \rq -> addPendingSub (uId, srv, sessEntId (connId rq)) rq tss
|
||||
pure subs
|
||||
(subs,) <$> setServiceSubPending_ s
|
||||
|
||||
setSubsPending_ :: SessSubs -> Maybe SessionId -> STM (Map RecipientId RcvQueueSub)
|
||||
setSubsPending_ :: SessSubs -> Maybe SessionId -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub)
|
||||
setSubsPending_ s sessId_ = do
|
||||
writeTVar (subsSessId s) sessId_
|
||||
let as = activeSubs s
|
||||
@@ -184,7 +222,15 @@ setSubsPending_ s sessId_ = do
|
||||
unless (null subs) $ do
|
||||
writeTVar as M.empty
|
||||
modifyTVar' (pendingSubs s) $ M.union subs
|
||||
pure subs
|
||||
(subs,) <$> setServiceSubPending_ s
|
||||
|
||||
setServiceSubPending_ :: SessSubs -> STM (Maybe ServiceSub)
|
||||
setServiceSubPending_ s = do
|
||||
serviceSub_ <- readTVar $ activeServiceSub s
|
||||
forM_ serviceSub_ $ \serviceSub -> do
|
||||
writeTVar (activeServiceSub s) Nothing
|
||||
writeTVar (pendingServiceSub s) $ Just serviceSub
|
||||
pure serviceSub_
|
||||
|
||||
updateClientNotices :: SMPTransportSession -> [(RecipientId, Maybe Int64)] -> TSessionSubs -> STM ()
|
||||
updateClientNotices tSess noticeIds ss = do
|
||||
|
||||
@@ -909,18 +909,18 @@ nsubResponse_ = \case
|
||||
{-# INLINE nsubResponse_ #-}
|
||||
|
||||
-- This command is always sent in background request mode
|
||||
subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO (Int64, IdsHash)
|
||||
subscribeService c party = case smpClientService c of
|
||||
subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub
|
||||
subscribeService c party n idsHash = case smpClientService c of
|
||||
Just THClientService {serviceId, serviceKey} -> do
|
||||
liftIO $ enablePings c
|
||||
sendSMPCommand c NRMBackground (Just (C.APrivateAuthKey C.SEd25519 serviceKey)) serviceId subCmd >>= \case
|
||||
SOKS n idsHash -> pure (n, idsHash)
|
||||
SOKS n' idsHash' -> pure $ ServiceSub serviceId n' idsHash'
|
||||
r -> throwE $ unexpectedResponse r
|
||||
where
|
||||
subCmd :: Command p
|
||||
subCmd = case party of
|
||||
SRecipientService -> SUBS
|
||||
SNotifierService -> NSUBS
|
||||
SRecipientService -> SUBS n idsHash
|
||||
SNotifierService -> NSUBS n idsHash
|
||||
Nothing -> throwE PCEServiceUnavailable
|
||||
|
||||
smpClientService :: SMPClient -> Maybe THClientService
|
||||
|
||||
@@ -45,7 +45,6 @@ import Crypto.Random (ChaChaDRG)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Constraint (Dict (..))
|
||||
import Data.Int (Int64)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
@@ -69,10 +68,12 @@ import Simplex.Messaging.Protocol
|
||||
ProtocolServer (..),
|
||||
QueueId,
|
||||
SMPServer,
|
||||
ServiceSub (..),
|
||||
SParty (..),
|
||||
ServiceParty,
|
||||
serviceParty,
|
||||
partyServiceRole
|
||||
partyServiceRole,
|
||||
queueIdsHash,
|
||||
)
|
||||
import Simplex.Messaging.Session
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
@@ -91,14 +92,14 @@ data SMPClientAgentEvent
|
||||
| CADisconnected SMPServer (NonEmpty QueueId)
|
||||
| CASubscribed SMPServer (Maybe ServiceId) (NonEmpty QueueId)
|
||||
| CASubError SMPServer (NonEmpty (QueueId, SMPClientError))
|
||||
| CAServiceDisconnected SMPServer (ServiceId, Int64)
|
||||
| CAServiceSubscribed SMPServer (ServiceId, Int64) Int64
|
||||
| CAServiceSubError SMPServer (ServiceId, Int64) SMPClientError
|
||||
| CAServiceDisconnected SMPServer ServiceSub
|
||||
| CAServiceSubscribed {subServer :: SMPServer, expected :: ServiceSub, subscribed :: ServiceSub}
|
||||
| CAServiceSubError SMPServer ServiceSub SMPClientError
|
||||
-- CAServiceUnavailable is used when service ID in pending subscription is different from the current service in connection.
|
||||
-- This will require resubscribing to all queues associated with this service ID individually, creating new associations.
|
||||
-- It may happen if, for example, SMP server deletes service information (e.g. via downgrade and upgrade)
|
||||
-- and assigns different service ID to the service certificate.
|
||||
| CAServiceUnavailable SMPServer (ServiceId, Int64)
|
||||
| CAServiceUnavailable SMPServer ServiceSub
|
||||
|
||||
data SMPClientAgentConfig = SMPClientAgentConfig
|
||||
{ smpCfg :: ProtocolClientConfig SMPVersion,
|
||||
@@ -142,11 +143,11 @@ data SMPClientAgent p = SMPClientAgent
|
||||
-- Only one service subscription can exist per server with this agent.
|
||||
-- With correctly functioning SMP server, queue and service subscriptions can't be
|
||||
-- active at the same time.
|
||||
activeServiceSubs :: TMap SMPServer (TVar (Maybe ((ServiceId, Int64), SessionId))),
|
||||
activeServiceSubs :: TMap SMPServer (TVar (Maybe (ServiceSub, SessionId))),
|
||||
activeQueueSubs :: TMap SMPServer (TMap QueueId (SessionId, C.APrivateAuthKey)),
|
||||
-- Pending service subscriptions can co-exist with pending queue subscriptions
|
||||
-- on the same SMP server during subscriptions being transitioned from per-queue to service.
|
||||
pendingServiceSubs :: TMap SMPServer (TVar (Maybe (ServiceId, Int64))),
|
||||
pendingServiceSubs :: TMap SMPServer (TVar (Maybe ServiceSub)),
|
||||
pendingQueueSubs :: TMap SMPServer (TMap QueueId C.APrivateAuthKey),
|
||||
smpSubWorkers :: TMap SMPServer (SessionVar (Async ())),
|
||||
workerSeq :: TVar Int
|
||||
@@ -256,7 +257,7 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random
|
||||
removeClientAndSubs smp >>= serverDown
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeClientAndSubs :: SMPClient -> IO (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey))
|
||||
removeClientAndSubs :: SMPClient -> IO (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey))
|
||||
removeClientAndSubs smp = do
|
||||
-- Looking up subscription vars outside of STM transaction to reduce re-evaluation.
|
||||
-- It is possible because these vars are never removed, they are only added.
|
||||
@@ -287,7 +288,7 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random
|
||||
then pure Nothing
|
||||
else Just subs <$ addSubs_ (pendingQueueSubs ca) srv subs
|
||||
|
||||
serverDown :: (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> IO ()
|
||||
serverDown :: (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) -> IO ()
|
||||
serverDown (sSub, qSubs) = do
|
||||
mapM_ (notify ca . CAServiceDisconnected srv) sSub
|
||||
let qIds = L.nonEmpty . M.keys =<< qSubs
|
||||
@@ -317,7 +318,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
|
||||
loop
|
||||
ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
|
||||
noPending (sSub, qSubs) = isNothing sSub && maybe True M.null qSubs
|
||||
getPending :: Monad m => (forall a. SMPServer -> TMap SMPServer a -> m (Maybe a)) -> (forall a. TVar a -> m a) -> m (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey))
|
||||
getPending :: Monad m => (forall a. SMPServer -> TMap SMPServer a -> m (Maybe a)) -> (forall a. TVar a -> m a) -> m (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey))
|
||||
getPending lkup rd = do
|
||||
sSub <- lkup srv (pendingServiceSubs ca) $>>= rd
|
||||
qSubs <- lkup srv (pendingQueueSubs ca) >>= mapM rd
|
||||
@@ -329,7 +330,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
|
||||
whenM (isEmptyTMVar $ sessionVar v) retry
|
||||
removeSessVar v srv smpSubWorkers
|
||||
|
||||
reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO ()
|
||||
reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO ()
|
||||
reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv (sSub_, qSubs_) =
|
||||
withSMP ca srv $ \smp -> liftIO $ case serviceParty agentParty of
|
||||
Just Dict -> resubscribe smp
|
||||
@@ -430,7 +431,7 @@ smpSubscribeQueues ca smp srv subs = do
|
||||
let acc@(_, _, (qOks, sQs), notPending) = foldr (groupSub pending) (False, [], ([], []), []) (L.zip subs rs)
|
||||
unless (null qOks) $ addActiveSubs ca srv qOks
|
||||
unless (null sQs) $ forM_ smpServiceId $ \serviceId ->
|
||||
updateActiveServiceSub ca srv ((serviceId, fromIntegral $ length sQs), sessId)
|
||||
updateActiveServiceSub ca srv (ServiceSub serviceId (fromIntegral $ length sQs) (queueIdsHash sQs), sessId)
|
||||
unless (null notPending) $ removePendingSubs ca srv notPending
|
||||
pure acc
|
||||
sessId = sessionId $ thParams smp
|
||||
@@ -454,24 +455,24 @@ smpSubscribeQueues ca smp srv subs = do
|
||||
notify_ :: (SMPServer -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO ()
|
||||
notify_ evt qs = mapM_ (notify ca . evt srv) $ L.nonEmpty qs
|
||||
|
||||
subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> ServiceSub -> IO ()
|
||||
subscribeServiceNtfs = subscribeService_
|
||||
{-# INLINE subscribeServiceNtfs #-}
|
||||
|
||||
subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> ServiceSub -> IO ()
|
||||
subscribeService_ ca srv serviceSub = do
|
||||
atomically $ setPendingServiceSub ca srv $ Just serviceSub
|
||||
runExceptT (getSMPServerClient' ca srv) >>= \case
|
||||
Right smp -> smpSubscribeService ca smp srv serviceSub
|
||||
Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that
|
||||
|
||||
smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService smp of
|
||||
smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> ServiceSub -> IO ()
|
||||
smpSubscribeService ca smp srv serviceSub@(ServiceSub serviceId n idsHash) = case smpClientService smp of
|
||||
Just service | serviceAvailable service -> subscribe
|
||||
_ -> notifyUnavailable
|
||||
where
|
||||
subscribe = do
|
||||
r <- runExceptT $ subscribeService smp $ agentParty ca
|
||||
r <- runExceptT $ subscribeService smp (agentParty ca) n idsHash
|
||||
ok <-
|
||||
atomically $
|
||||
ifM
|
||||
@@ -479,15 +480,15 @@ smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService
|
||||
(True <$ processSubscription r)
|
||||
(pure False)
|
||||
if ok
|
||||
then case r of -- TODO [certs rcv] compare hash
|
||||
Right (n, _idsHash) -> notify ca $ CAServiceSubscribed srv serviceSub n
|
||||
then case r of
|
||||
Right serviceSub' -> notify ca $ CAServiceSubscribed srv serviceSub serviceSub'
|
||||
Left e
|
||||
| smpClientServiceError e -> notifyUnavailable
|
||||
| temporaryClientError e -> reconnectClient ca srv
|
||||
| otherwise -> notify ca $ CAServiceSubError srv serviceSub e
|
||||
else reconnectClient ca srv
|
||||
processSubscription = mapM_ $ \(n, _idsHash) -> do -- TODO [certs rcv] validate hash here?
|
||||
setActiveServiceSub ca srv $ Just ((serviceId, n), sessId)
|
||||
processSubscription = mapM_ $ \serviceSub' -> do -- TODO [certs rcv] validate hash here?
|
||||
setActiveServiceSub ca srv $ Just (serviceSub', sessId)
|
||||
setPendingServiceSub ca srv Nothing
|
||||
serviceAvailable THClientService {serviceRole, serviceId = serviceId'} =
|
||||
serviceId == serviceId' && partyServiceRole (agentParty ca) == serviceRole
|
||||
@@ -529,11 +530,11 @@ addSubs_ subs srv ss =
|
||||
Just m -> TM.union ss m
|
||||
_ -> TM.insertM srv (newTVar ss) subs
|
||||
|
||||
setActiveServiceSub :: SMPClientAgent p -> SMPServer -> Maybe ((ServiceId, Int64), SessionId) -> STM ()
|
||||
setActiveServiceSub :: SMPClientAgent p -> SMPServer -> Maybe (ServiceSub, SessionId) -> STM ()
|
||||
setActiveServiceSub = setServiceSub_ activeServiceSubs
|
||||
{-# INLINE setActiveServiceSub #-}
|
||||
|
||||
setPendingServiceSub :: SMPClientAgent p -> SMPServer -> Maybe (ServiceId, Int64) -> STM ()
|
||||
setPendingServiceSub :: SMPClientAgent p -> SMPServer -> Maybe ServiceSub -> STM ()
|
||||
setPendingServiceSub = setServiceSub_ pendingServiceSubs
|
||||
{-# INLINE setPendingServiceSub #-}
|
||||
|
||||
@@ -548,12 +549,12 @@ setServiceSub_ subsSel ca srv sub =
|
||||
Just v -> writeTVar v sub
|
||||
Nothing -> TM.insertM srv (newTVar sub) (subsSel ca)
|
||||
|
||||
updateActiveServiceSub :: SMPClientAgent p -> SMPServer -> ((ServiceId, Int64), SessionId) -> STM ()
|
||||
updateActiveServiceSub ca srv sub@((serviceId', n'), sessId') =
|
||||
updateActiveServiceSub :: SMPClientAgent p -> SMPServer -> (ServiceSub, SessionId) -> STM ()
|
||||
updateActiveServiceSub ca srv sub@(ServiceSub serviceId' n' idsHash', sessId') =
|
||||
TM.lookup srv (activeServiceSubs ca) >>= \case
|
||||
Just v -> modifyTVar' v $ \case
|
||||
Just ((serviceId, n), sessId) | serviceId == serviceId' && sessId == sessId' ->
|
||||
Just ((serviceId, n + n'), sessId)
|
||||
Just (ServiceSub serviceId n idsHash, sessId) | serviceId == serviceId' && sessId == sessId' ->
|
||||
Just (ServiceSub serviceId (n + n') (idsHash <> idsHash'), sessId)
|
||||
_ -> Just sub
|
||||
Nothing -> TM.insertM srv (newTVar $ Just sub) (activeServiceSubs ca)
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ module Simplex.Messaging.Crypto
|
||||
sha512Hash,
|
||||
sha3_256,
|
||||
sha3_384,
|
||||
md5Hash,
|
||||
|
||||
-- * Message padding / un-padding
|
||||
canPad,
|
||||
@@ -216,7 +217,7 @@ import Crypto.Cipher.AES (AES256)
|
||||
import qualified Crypto.Cipher.Types as AES
|
||||
import qualified Crypto.Cipher.XSalsa as XSalsa
|
||||
import qualified Crypto.Error as CE
|
||||
import Crypto.Hash (Digest, SHA3_256, SHA3_384, SHA256 (..), SHA512 (..), hash, hashDigestSize)
|
||||
import Crypto.Hash (Digest, MD5, SHA3_256, SHA3_384, SHA256 (..), SHA512 (..), hash, hashDigestSize)
|
||||
import qualified Crypto.KDF.HKDF as H
|
||||
import qualified Crypto.MAC.Poly1305 as Poly1305
|
||||
import qualified Crypto.PubKey.Curve25519 as X25519
|
||||
@@ -1024,6 +1025,9 @@ sha3_384 :: ByteString -> ByteString
|
||||
sha3_384 = BA.convert . (hash :: ByteString -> Digest SHA3_384)
|
||||
{-# INLINE sha3_384 #-}
|
||||
|
||||
md5Hash :: ByteString -> ByteString
|
||||
md5Hash = BA.convert . (hash :: ByteString -> Digest MD5)
|
||||
|
||||
-- | AEAD-GCM encryption with associated data.
|
||||
--
|
||||
-- Used as part of double ratchet encryption.
|
||||
|
||||
@@ -489,17 +489,9 @@ data NtfSubStatus
|
||||
NSErr ByteString
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
ntfShouldSubscribe :: NtfSubStatus -> Bool
|
||||
ntfShouldSubscribe = \case
|
||||
NSNew -> True
|
||||
NSPending -> True
|
||||
NSActive -> True
|
||||
NSInactive -> True
|
||||
NSEnd -> False
|
||||
NSDeleted -> False
|
||||
NSAuth -> False
|
||||
NSService -> True
|
||||
NSErr _ -> False
|
||||
-- if these statuses change, the queue ID hashes for services need to be updated in a new migration (see m20250830_queue_ids_hash)
|
||||
subscribeNtfStatuses :: [NtfSubStatus]
|
||||
subscribeNtfStatuses = [NSNew, NSPending, NSActive, NSInactive]
|
||||
|
||||
instance Encoding NtfSubStatus where
|
||||
smpEncode = \case
|
||||
|
||||
@@ -62,7 +62,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessag
|
||||
import Simplex.Messaging.Notifications.Server.Store.Postgres
|
||||
import Simplex.Messaging.Notifications.Server.Store.Types
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut)
|
||||
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceSub (..), SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server
|
||||
import Simplex.Messaging.Server.Control (CPClientRole (..))
|
||||
@@ -257,9 +257,9 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
|
||||
srvSubscribers <- getSMPWorkerMetrics a smpSubscribers
|
||||
srvClients <- getSMPWorkerMetrics a smpClients
|
||||
srvSubWorkers <- getSMPWorkerMetrics a smpSubWorkers
|
||||
ntfActiveServiceSubs <- getSMPServiceSubMetrics a activeServiceSubs $ snd . fst
|
||||
ntfActiveServiceSubs <- getSMPServiceSubMetrics a activeServiceSubs $ smpQueueCount . fst
|
||||
ntfActiveQueueSubs <- getSMPSubMetrics a activeQueueSubs
|
||||
ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs snd
|
||||
ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs smpQueueCount
|
||||
ntfPendingQueueSubs <- getSMPSubMetrics a pendingQueueSubs
|
||||
smpSessionCount <- M.size <$> readTVarIO smpSessions
|
||||
apnsPushQLength <- atomically $ lengthTBQueue pushQ
|
||||
@@ -452,13 +452,13 @@ resubscribe NtfSubscriber {smpAgent = ca} = do
|
||||
counts <- mapConcurrently (subscribeSrvSubs ca st batchSize) srvs
|
||||
logNote $ "Completed all SMP resubscriptions for " <> tshow (length srvs) <> " servers (" <> tshow (sum counts) <> " subscriptions)"
|
||||
|
||||
subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int
|
||||
subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe ServiceSub) -> IO Int
|
||||
subscribeSrvSubs ca st batchSize (srv, srvId, service_) = do
|
||||
let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv)
|
||||
logNote $ "Starting SMP resubscriptions for " <> srvStr
|
||||
forM_ service_ $ \(serviceId, n) -> do
|
||||
logNote $ "Subscribing service to " <> srvStr <> " with " <> tshow n <> " associated queues"
|
||||
subscribeServiceNtfs ca srv (serviceId, n)
|
||||
forM_ service_ $ \serviceSub -> do
|
||||
logNote $ "Subscribing service to " <> srvStr <> " with " <> tshow (smpQueueCount serviceSub) <> " associated queues"
|
||||
subscribeServiceNtfs ca srv serviceSub
|
||||
n <- subscribeLoop 0 Nothing
|
||||
logNote $ "Completed SMP resubscriptions for " <> srvStr <> " (" <> tshow n <> " subscriptions)"
|
||||
pure n
|
||||
@@ -576,7 +576,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} =
|
||||
-- TODO [certs] resubscribe queues with statuses NSErr and NSService
|
||||
CAServiceDisconnected srv serviceSub ->
|
||||
logNote $ "SMP server service disconnected " <> showService srv serviceSub
|
||||
CAServiceSubscribed srv serviceSub@(_, expected) n
|
||||
CAServiceSubscribed srv serviceSub@(ServiceSub _ expected _) (ServiceSub _ n _) -- TODO [certs rcv] compare hash
|
||||
| expected == n -> logNote msg
|
||||
| otherwise -> logWarn $ msg <> ", confirmed subs: " <> tshow n
|
||||
where
|
||||
@@ -593,7 +593,8 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} =
|
||||
void $ subscribeSrvSubs ca st batchSize (srv, srvId, Nothing)
|
||||
Left e -> logError $ "SMP server update and resubscription error " <> tshow e
|
||||
where
|
||||
showService srv (serviceId, n) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs"
|
||||
-- TODO [certs rcv] compare hash
|
||||
showService srv (ServiceSub serviceId n _idsHash) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs"
|
||||
|
||||
logSubErrors :: SMPServer -> NonEmpty (SMP.NotifierId, NtfSubStatus) -> Int -> IO ()
|
||||
logSubErrors srv subs updated = forM_ (L.group $ L.sort $ L.map snd subs) $ \ss -> do
|
||||
|
||||
@@ -17,6 +17,7 @@ import Simplex.Messaging.Server.Stats
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
|
||||
-- TODO [certs rcv] track service subscriptions and count/hash diffs for own and other servers + prometheus
|
||||
data NtfServerStats = NtfServerStats
|
||||
{ fromTime :: IORef UTCTime,
|
||||
tknCreated :: IORef Int,
|
||||
|
||||
@@ -6,13 +6,15 @@ module Simplex.Messaging.Notifications.Server.Store.Migrations where
|
||||
|
||||
import Data.List (sortOn)
|
||||
import Data.Text (Text)
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util
|
||||
import Simplex.Messaging.Agent.Store.Shared
|
||||
import Text.RawString.QQ (r)
|
||||
|
||||
ntfServerSchemaMigrations :: [(String, Text, Maybe Text)]
|
||||
ntfServerSchemaMigrations =
|
||||
[ ("20250417_initial", m20250417_initial, Nothing),
|
||||
("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert)
|
||||
("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert),
|
||||
("20250830_queue_ids_hash", m20250830_queue_ids_hash, Just down_m20250830_queue_ids_hash)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
@@ -101,3 +103,125 @@ ALTER TABLE smp_servers DROP COLUMN ntf_service_id;
|
||||
|
||||
ALTER TABLE subscriptions DROP COLUMN ntf_service_assoc;
|
||||
|]
|
||||
|
||||
m20250830_queue_ids_hash :: Text
|
||||
m20250830_queue_ids_hash =
|
||||
createXorHashFuncs
|
||||
<> [r|
|
||||
ALTER TABLE smp_servers
|
||||
ADD COLUMN smp_notifier_count BIGINT NOT NULL DEFAULT 0,
|
||||
ADD COLUMN smp_notifier_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000';
|
||||
|
||||
CREATE FUNCTION should_subscribe_status(p_status TEXT) RETURNS BOOLEAN
|
||||
LANGUAGE plpgsql IMMUTABLE STRICT
|
||||
AS $$
|
||||
BEGIN
|
||||
RETURN p_status IN ('NEW', 'PENDING', 'ACTIVE', 'INACTIVE');
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION update_all_aggregates() RETURNS VOID
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
WITH acc AS (
|
||||
SELECT
|
||||
s.smp_server_id,
|
||||
count(smp_notifier_id) as notifier_count,
|
||||
xor_aggregate(public.digest(s.smp_notifier_id, 'md5')) AS notifier_hash
|
||||
FROM subscriptions s
|
||||
WHERE s.ntf_service_assoc = true AND should_subscribe_status(s.status)
|
||||
GROUP BY s.smp_server_id
|
||||
)
|
||||
UPDATE smp_servers srv
|
||||
SET smp_notifier_count = COALESCE(acc.notifier_count, 0),
|
||||
smp_notifier_ids_hash = COALESCE(acc.notifier_hash, '\x00000000000000000000000000000000')
|
||||
FROM acc
|
||||
WHERE srv.smp_server_id = acc.smp_server_id;
|
||||
END;
|
||||
$$;
|
||||
|
||||
SELECT update_all_aggregates();
|
||||
|
||||
CREATE FUNCTION update_aggregates(p_server_id BIGINT, p_change BIGINT, p_notifier_id BYTEA) RETURNS VOID
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
UPDATE smp_servers
|
||||
SET smp_notifier_count = smp_notifier_count + p_change,
|
||||
smp_notifier_ids_hash = xor_combine(smp_notifier_ids_hash, public.digest(p_notifier_id, 'md5'))
|
||||
WHERE smp_server_id = p_server_id;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_subscription_insert() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
|
||||
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_subscription_delete() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
|
||||
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
|
||||
END IF;
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_subscription_update() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
|
||||
IF NOT (NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status)) THEN
|
||||
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
|
||||
END IF;
|
||||
ELSIF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
|
||||
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE TRIGGER tr_subscriptions_insert
|
||||
AFTER INSERT ON subscriptions
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_subscription_insert();
|
||||
|
||||
CREATE TRIGGER tr_subscriptions_delete
|
||||
AFTER DELETE ON subscriptions
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_subscription_delete();
|
||||
|
||||
CREATE TRIGGER tr_subscriptions_update
|
||||
AFTER UPDATE ON subscriptions
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_subscription_update();
|
||||
|]
|
||||
|
||||
down_m20250830_queue_ids_hash :: Text
|
||||
down_m20250830_queue_ids_hash =
|
||||
[r|
|
||||
DROP TRIGGER tr_subscriptions_insert ON subscriptions;
|
||||
DROP TRIGGER tr_subscriptions_delete ON subscriptions;
|
||||
DROP TRIGGER tr_subscriptions_update ON subscriptions;
|
||||
|
||||
DROP FUNCTION on_subscription_insert;
|
||||
DROP FUNCTION on_subscription_delete;
|
||||
DROP FUNCTION on_subscription_update;
|
||||
|
||||
DROP FUNCTION update_aggregates;
|
||||
DROP FUNCTION update_all_aggregates;
|
||||
|
||||
DROP FUNCTION should_subscribe_status;
|
||||
|
||||
ALTER TABLE smp_servers
|
||||
DROP COLUMN smp_notifier_count,
|
||||
DROP COLUMN smp_notifier_ids_hash;
|
||||
|]
|
||||
<> dropXorHashFuncs
|
||||
|
||||
@@ -64,7 +64,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubDat
|
||||
import Simplex.Messaging.Notifications.Server.Store.Migrations
|
||||
import Simplex.Messaging.Notifications.Server.Store.Types
|
||||
import Simplex.Messaging.Notifications.Server.StoreLog
|
||||
import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer)
|
||||
import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), IdsHash (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, ServiceSub (..), pattern SMPServer)
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_)
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..))
|
||||
import Simplex.Messaging.Server.StoreLog (openWriteStoreLog)
|
||||
@@ -239,7 +239,7 @@ updateTknCronInterval st tknId cronInt =
|
||||
|
||||
-- Reads servers that have subscriptions that need subscribing.
|
||||
-- It is executed on server start, and it is supposed to crash on database error
|
||||
getUsedSMPServers :: NtfPostgresStore -> IO [(SMPServer, Int64, Maybe (ServiceId, Int64))]
|
||||
getUsedSMPServers :: NtfPostgresStore -> IO [(SMPServer, Int64, Maybe ServiceSub)]
|
||||
getUsedSMPServers st =
|
||||
withTransaction (dbStore st) $ \db ->
|
||||
map rowToSrvSubs <$>
|
||||
@@ -247,25 +247,17 @@ getUsedSMPServers st =
|
||||
db
|
||||
[sql|
|
||||
SELECT
|
||||
p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id,
|
||||
SUM(CASE WHEN s.ntf_service_assoc THEN s.subs_count ELSE 0 END) :: BIGINT as service_subs_count
|
||||
FROM smp_servers p
|
||||
JOIN (
|
||||
SELECT
|
||||
smp_server_id,
|
||||
ntf_service_assoc,
|
||||
COUNT(1) as subs_count
|
||||
FROM subscriptions
|
||||
WHERE status IN ?
|
||||
GROUP BY smp_server_id, ntf_service_assoc
|
||||
) s ON s.smp_server_id = p.smp_server_id
|
||||
GROUP BY p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id
|
||||
smp_host, smp_port, smp_keyhash, smp_server_id,
|
||||
ntf_service_id, smp_notifier_count, smp_notifier_ids_hash
|
||||
FROM smp_servers
|
||||
WHERE EXISTS (SELECT 1 FROM subscriptions WHERE status IN ?)
|
||||
|]
|
||||
(Only (In [NSNew, NSPending, NSActive, NSInactive]))
|
||||
(Only (In subscribeNtfStatuses))
|
||||
where
|
||||
rowToSrvSubs :: SMPServerRow :. (Int64, Maybe ServiceId, Int64) -> (SMPServer, Int64, Maybe (ServiceId, Int64))
|
||||
rowToSrvSubs ((host, port, kh) :. (srvId, serviceId_, subsCount)) =
|
||||
(SMPServer host port kh, srvId, (,subsCount) <$> serviceId_)
|
||||
rowToSrvSubs :: SMPServerRow :. (Int64, Maybe ServiceId, Int64, IdsHash) -> (SMPServer, Int64, Maybe ServiceSub)
|
||||
rowToSrvSubs ((host, port, kh) :. (srvId, serviceId_, n, idsHash)) =
|
||||
let service_ = (\serviceId -> ServiceSub serviceId n idsHash) <$> serviceId_
|
||||
in (SMPServer host port kh, srvId, service_)
|
||||
|
||||
getServerNtfSubscriptions :: NtfPostgresStore -> Int64 -> Maybe NtfSubscriptionId -> Int -> IO (Either ErrorType [ServerNtfSub])
|
||||
getServerNtfSubscriptions st srvId afterSubId_ count =
|
||||
@@ -273,9 +265,9 @@ getServerNtfSubscriptions st srvId afterSubId_ count =
|
||||
subs <-
|
||||
map toServerNtfSub <$> case afterSubId_ of
|
||||
Nothing ->
|
||||
DB.query db (query <> orderLimit) (srvId, statusIn, count)
|
||||
DB.query db (query <> orderLimit) (srvId, In subscribeNtfStatuses, count)
|
||||
Just afterSubId ->
|
||||
DB.query db (query <> " AND subscription_id > ?" <> orderLimit) (srvId, statusIn, afterSubId, count)
|
||||
DB.query db (query <> " AND subscription_id > ?" <> orderLimit) (srvId, In subscribeNtfStatuses, afterSubId, count)
|
||||
void $
|
||||
DB.executeMany
|
||||
db
|
||||
@@ -296,7 +288,6 @@ getServerNtfSubscriptions st srvId afterSubId_ count =
|
||||
WHERE smp_server_id = ? AND NOT ntf_service_assoc AND status IN ?
|
||||
|]
|
||||
orderLimit = " ORDER BY subscription_id LIMIT ?"
|
||||
statusIn = In [NSNew, NSPending, NSActive, NSInactive]
|
||||
toServerNtfSub (ntfSubId, notifierId, notifierKey) = (ntfSubId, (notifierId, notifierKey))
|
||||
|
||||
-- Returns token and subscription.
|
||||
|
||||
@@ -15,6 +15,123 @@ SET row_security = off;
|
||||
CREATE SCHEMA ntf_server;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION ntf_server.on_subscription_delete() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
|
||||
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
|
||||
END IF;
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION ntf_server.on_subscription_insert() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
|
||||
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION ntf_server.on_subscription_update() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN
|
||||
IF NOT (NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status)) THEN
|
||||
PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id);
|
||||
END IF;
|
||||
ELSIF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN
|
||||
PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION ntf_server.should_subscribe_status(p_status text) RETURNS boolean
|
||||
LANGUAGE plpgsql IMMUTABLE STRICT
|
||||
AS $$
|
||||
BEGIN
|
||||
RETURN p_status IN ('NEW', 'PENDING', 'ACTIVE', 'INACTIVE');
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION ntf_server.update_aggregates(p_server_id bigint, p_change bigint, p_notifier_id bytea) RETURNS void
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
UPDATE smp_servers
|
||||
SET smp_notifier_count = smp_notifier_count + p_change,
|
||||
smp_notifier_ids_hash = xor_combine(smp_notifier_ids_hash, public.digest(p_notifier_id, 'md5'))
|
||||
WHERE smp_server_id = p_server_id;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION ntf_server.update_all_aggregates() RETURNS void
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
WITH acc AS (
|
||||
SELECT
|
||||
s.smp_server_id,
|
||||
count(smp_notifier_id) as notifier_count,
|
||||
xor_aggregate(public.digest(s.smp_notifier_id, 'md5')) AS notifier_hash
|
||||
FROM subscriptions s
|
||||
WHERE s.ntf_service_assoc = true AND should_subscribe_status(s.status)
|
||||
GROUP BY s.smp_server_id
|
||||
)
|
||||
UPDATE smp_servers srv
|
||||
SET smp_notifier_count = COALESCE(acc.notifier_count, 0),
|
||||
smp_notifier_ids_hash = COALESCE(acc.notifier_hash, '\x00000000000000000000000000000000')
|
||||
FROM acc
|
||||
WHERE srv.smp_server_id = acc.smp_server_id;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION ntf_server.xor_combine(state bytea, value bytea) RETURNS bytea
|
||||
LANGUAGE plpgsql IMMUTABLE STRICT
|
||||
AS $$
|
||||
DECLARE
|
||||
result BYTEA := state;
|
||||
i INTEGER;
|
||||
len INTEGER := octet_length(value);
|
||||
BEGIN
|
||||
IF octet_length(state) != len THEN
|
||||
RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len;
|
||||
END IF;
|
||||
FOR i IN 0..len-1 LOOP
|
||||
result := set_byte(result, i, get_byte(state, i) # get_byte(value, i));
|
||||
END LOOP;
|
||||
RETURN result;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE AGGREGATE ntf_server.xor_aggregate(bytea) (
|
||||
SFUNC = ntf_server.xor_combine,
|
||||
STYPE = bytea,
|
||||
INITCOND = '\x00000000000000000000000000000000'
|
||||
);
|
||||
|
||||
|
||||
SET default_table_access_method = heap;
|
||||
|
||||
|
||||
@@ -53,7 +170,9 @@ CREATE TABLE ntf_server.smp_servers (
|
||||
smp_host text NOT NULL,
|
||||
smp_port text NOT NULL,
|
||||
smp_keyhash bytea NOT NULL,
|
||||
ntf_service_id bytea
|
||||
ntf_service_id bytea,
|
||||
smp_notifier_count bigint DEFAULT 0 NOT NULL,
|
||||
smp_notifier_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL
|
||||
);
|
||||
|
||||
|
||||
@@ -158,6 +277,18 @@ CREATE INDEX idx_tokens_status_cron_interval_sent_at ON ntf_server.tokens USING
|
||||
|
||||
|
||||
|
||||
CREATE TRIGGER tr_subscriptions_delete AFTER DELETE ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_delete();
|
||||
|
||||
|
||||
|
||||
CREATE TRIGGER tr_subscriptions_insert AFTER INSERT ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_insert();
|
||||
|
||||
|
||||
|
||||
CREATE TRIGGER tr_subscriptions_update AFTER UPDATE ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_update();
|
||||
|
||||
|
||||
|
||||
ALTER TABLE ONLY ntf_server.last_notifications
|
||||
ADD CONSTRAINT last_notifications_subscription_id_fkey FOREIGN KEY (subscription_id) REFERENCES ntf_server.subscriptions(subscription_id) ON UPDATE RESTRICT ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -140,7 +140,10 @@ module Simplex.Messaging.Protocol
|
||||
RcvMessage (..),
|
||||
MsgId,
|
||||
MsgBody,
|
||||
IdsHash,
|
||||
IdsHash (..),
|
||||
ServiceSub (..),
|
||||
queueIdsHash,
|
||||
queueIdHash,
|
||||
MaxMessageLen,
|
||||
MaxRcvMessageLen,
|
||||
EncRcvMsgBody (..),
|
||||
@@ -223,6 +226,8 @@ import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.Bits (xor)
|
||||
import qualified Data.ByteString as BS
|
||||
import qualified Data.ByteString.Base64 as B64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -232,6 +237,7 @@ import Data.Constraint (Dict (..))
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind
|
||||
import Data.List (foldl')
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Maybe (isJust, isNothing)
|
||||
@@ -241,7 +247,7 @@ import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
import Data.Time.Clock.System (SystemTime (..), systemToUTCTime)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word16)
|
||||
import Data.Word (Word8, Word16)
|
||||
import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
|
||||
import qualified GHC.TypeLits as TE
|
||||
import qualified GHC.TypeLits as Type
|
||||
@@ -548,7 +554,8 @@ data Command (p :: Party) where
|
||||
NEW :: NewQueueReq -> Command Creator
|
||||
SUB :: Command Recipient
|
||||
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
|
||||
SUBS :: Command RecipientService
|
||||
-- Parameters are expected queue count and hash of all subscribed queues, it allows to monitor "state drift" on the server
|
||||
SUBS :: Int64 -> IdsHash -> Command RecipientService
|
||||
KEY :: SndPublicAuthKey -> Command Recipient
|
||||
RKEY :: NonEmpty RcvPublicAuthKey -> Command Recipient
|
||||
LSET :: LinkId -> QueueLinkData -> Command Recipient
|
||||
@@ -572,7 +579,7 @@ data Command (p :: Party) where
|
||||
-- SMP notification subscriber commands
|
||||
NSUB :: Command Notifier
|
||||
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
|
||||
NSUBS :: Command NotifierService
|
||||
NSUBS :: Int64 -> IdsHash -> Command NotifierService
|
||||
PRXY :: SMPServer -> Maybe BasicAuth -> Command ProxiedClient -- request a relay server connection by URI
|
||||
-- Transmission to proxy:
|
||||
-- - entity ID: ID of the session with relay returned in PKEY (response to PRXY)
|
||||
@@ -698,7 +705,7 @@ data BrokerMsg where
|
||||
LNK :: SenderId -> QueueLinkData -> BrokerMsg
|
||||
-- | Service subscription success - confirms when queue was associated with the service
|
||||
SOK :: Maybe ServiceId -> BrokerMsg
|
||||
-- | The number of queues subscribed with SUBS command
|
||||
-- | The number of queues and XOR-hash of their IDs subscribed with SUBS command
|
||||
SOKS :: Int64 -> IdsHash -> BrokerMsg
|
||||
-- MSG v1/2 has to be supported for encoding/decoding
|
||||
-- v1: MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg
|
||||
@@ -1460,7 +1467,42 @@ type MsgId = ByteString
|
||||
-- | SMP message body.
|
||||
type MsgBody = ByteString
|
||||
|
||||
type IdsHash = ByteString
|
||||
data ServiceSub = ServiceSub
|
||||
{ serviceId :: ServiceId,
|
||||
smpQueueCount :: Int64,
|
||||
smpQueueIdsHash :: IdsHash
|
||||
}
|
||||
|
||||
newtype IdsHash = IdsHash {unIdsHash :: BS.ByteString}
|
||||
deriving (Eq, Show)
|
||||
deriving newtype (Encoding, FromField)
|
||||
|
||||
instance ToField IdsHash where
|
||||
toField (IdsHash s) = toField (Binary s)
|
||||
{-# INLINE toField #-}
|
||||
|
||||
instance Semigroup IdsHash where
|
||||
(IdsHash s1) <> (IdsHash s2) = IdsHash $! BS.pack $ BS.zipWith xor s1 s2
|
||||
|
||||
instance Monoid IdsHash where
|
||||
mempty = IdsHash $ BS.replicate 16 0
|
||||
mconcat ss =
|
||||
let !s' = BS.pack $ foldl' (\ !r (IdsHash s) -> zipWith xor' r (BS.unpack s)) (replicate 16 0) ss -- to prevent packing/unpacking in <> on each step with default mappend
|
||||
in IdsHash s'
|
||||
|
||||
xor' :: Word8 -> Word8 -> Word8
|
||||
xor' x y = let !r = xor x y in r
|
||||
|
||||
noIdsHash ::IdsHash
|
||||
noIdsHash = IdsHash B.empty
|
||||
{-# INLINE noIdsHash #-}
|
||||
|
||||
queueIdsHash :: [QueueId] -> IdsHash
|
||||
queueIdsHash = mconcat . map queueIdHash
|
||||
|
||||
queueIdHash :: QueueId -> IdsHash
|
||||
queueIdHash = IdsHash . C.md5Hash . unEntityId
|
||||
{-# INLINE queueIdHash #-}
|
||||
|
||||
data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock
|
||||
|
||||
@@ -1695,7 +1737,9 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
|
||||
new = e (NEW_, ' ', rKey, dhKey)
|
||||
auth = maybe "" (e . ('A',)) auth_
|
||||
SUB -> e SUB_
|
||||
SUBS -> e SUBS_
|
||||
SUBS n idsHash
|
||||
| v >= rcvServiceSMPVersion -> e (SUBS_, ' ', n, idsHash)
|
||||
| otherwise -> e SUBS_
|
||||
KEY k -> e (KEY_, ' ', k)
|
||||
RKEY ks -> e (RKEY_, ' ', ks)
|
||||
LSET lnkId d -> e (LSET_, ' ', lnkId, d)
|
||||
@@ -1711,7 +1755,9 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
|
||||
SEND flags msg -> e (SEND_, ' ', flags, ' ', Tail msg)
|
||||
PING -> e PING_
|
||||
NSUB -> e NSUB_
|
||||
NSUBS -> e NSUBS_
|
||||
NSUBS n idsHash
|
||||
| v >= rcvServiceSMPVersion -> e (NSUBS_, ' ', n, idsHash)
|
||||
| otherwise -> e NSUBS_
|
||||
LKEY k -> e (LKEY_, ' ', k)
|
||||
LGET -> e LGET_
|
||||
PRXY host auth_ -> e (PRXY_, ' ', host, auth_)
|
||||
@@ -1802,7 +1848,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
|
||||
OFF_ -> pure OFF
|
||||
DEL_ -> pure DEL
|
||||
QUE_ -> pure QUE
|
||||
CT SRecipientService SUBS_ -> pure $ Cmd SRecipientService SUBS
|
||||
CT SRecipientService SUBS_
|
||||
| v >= rcvServiceSMPVersion -> Cmd SRecipientService <$> (SUBS <$> _smpP <*> smpP)
|
||||
| otherwise -> pure $ Cmd SRecipientService $ SUBS (-1) noIdsHash
|
||||
CT SSender tag ->
|
||||
Cmd SSender <$> case tag of
|
||||
SKEY_ -> SKEY <$> _smpP
|
||||
@@ -1819,7 +1867,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
|
||||
PFWD_ -> PFWD <$> _smpP <*> smpP <*> (EncTransmission . unTail <$> smpP)
|
||||
PRXY_ -> PRXY <$> _smpP <*> smpP
|
||||
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
|
||||
CT SNotifierService NSUBS_ -> pure $ Cmd SNotifierService NSUBS
|
||||
CT SNotifierService NSUBS_
|
||||
| v >= rcvServiceSMPVersion -> Cmd SNotifierService <$> (NSUBS <$> _smpP <*> smpP)
|
||||
| otherwise -> pure $ Cmd SNotifierService $ NSUBS (-1) noIdsHash
|
||||
|
||||
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
@@ -1901,7 +1951,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
|
||||
SOK_ -> SOK <$> _smpP
|
||||
SOKS_
|
||||
| v >= rcvServiceSMPVersion -> SOKS <$> _smpP <*> smpP
|
||||
| otherwise -> SOKS <$> _smpP <*> pure B.empty
|
||||
| otherwise -> SOKS <$> _smpP <*> pure noIdsHash
|
||||
NID_ -> NID <$> _smpP <*> smpP
|
||||
NMSG_ -> NMSG <$> _smpP <*> smpP
|
||||
PKEY_ -> PKEY <$> _smpP <*> smpP <*> smpP
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
@@ -1247,7 +1248,7 @@ verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, entId, comma
|
||||
vc SCreator (NEW NewQueueReq {rcvAuthKey = k}) = verifiedWith k
|
||||
vc SRecipient SUB = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
|
||||
vc SRecipient _ = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
|
||||
vc SRecipientService SUBS = verifyServiceCmd
|
||||
vc SRecipientService SUBS {} = verifyServiceCmd
|
||||
vc SSender (SKEY k) = verifySecure k
|
||||
-- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command
|
||||
vc SSender SEND {} = verifyQueue $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified q_ else VRFailed AUTH
|
||||
@@ -1255,7 +1256,7 @@ verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, entId, comma
|
||||
vc SSenderLink (LKEY k) = verifySecure k
|
||||
vc SSenderLink LGET = verifyQueue $ \q -> if isContactQueue (snd q) then VRVerified q_ else VRFailed AUTH
|
||||
vc SNotifier NSUB = verifyQueue $ \q -> maybe dummyVerify (\n -> verifiedWith $ notifierKey n) (notifier $ snd q)
|
||||
vc SNotifierService NSUBS = verifyServiceCmd
|
||||
vc SNotifierService NSUBS {} = verifyServiceCmd
|
||||
vc SProxiedClient _ = VRVerified Nothing
|
||||
vc SProxyService (RFWD _) = VRVerified Nothing
|
||||
checkRole = case (service, partyClientRole p) of
|
||||
@@ -1465,8 +1466,8 @@ client
|
||||
Cmd SNotifier NSUB -> response . (corrId,entId,) <$> case q_ of
|
||||
Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds
|
||||
_ -> pure $ ERR INTERNAL
|
||||
Cmd SNotifierService NSUBS -> response . (corrId,entId,) <$> case clntServiceId of
|
||||
Just serviceId -> subscribeServiceNotifications serviceId
|
||||
Cmd SNotifierService (NSUBS n idsHash) -> response . (corrId,entId,) <$> case clntServiceId of
|
||||
Just serviceId -> subscribeServiceNotifications serviceId (n, idsHash)
|
||||
Nothing -> pure $ ERR INTERNAL
|
||||
Cmd SCreator (NEW nqr@NewQueueReq {auth_}) ->
|
||||
response <$> ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
|
||||
@@ -1495,8 +1496,8 @@ client
|
||||
OFF -> response <$> maybe (pure $ err INTERNAL) suspendQueue_ q_
|
||||
DEL -> response <$> maybe (pure $ err INTERNAL) delQueueAndMsgs q_
|
||||
QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr
|
||||
Cmd SRecipientService SUBS -> response . (corrId,entId,) <$> case clntServiceId of
|
||||
Just serviceId -> subscribeServiceMessages serviceId
|
||||
Cmd SRecipientService (SUBS n idsHash)-> response . (corrId,entId,) <$> case clntServiceId of
|
||||
Just serviceId -> subscribeServiceMessages serviceId (n, idsHash)
|
||||
Nothing -> pure $ ERR INTERNAL -- it's "internal" because it should never get to this branch
|
||||
where
|
||||
createQueue :: NewQueueReq -> M s (Transmission BrokerMsg)
|
||||
@@ -1795,9 +1796,9 @@ client
|
||||
TM.insert entId sub $ clientSubs clnt
|
||||
pure (False, Just sub)
|
||||
|
||||
subscribeServiceMessages :: ServiceId -> M s BrokerMsg
|
||||
subscribeServiceMessages serviceId =
|
||||
sharedSubscribeService SRecipientService serviceId subscribers serviceSubscribed serviceSubsCount >>= \case
|
||||
subscribeServiceMessages :: ServiceId -> (Int64, IdsHash) -> M s BrokerMsg
|
||||
subscribeServiceMessages serviceId expected =
|
||||
sharedSubscribeService SRecipientService serviceId expected subscribers serviceSubscribed serviceSubsCount rcvServices >>= \case
|
||||
Left e -> pure $ ERR e
|
||||
Right (hasSub, (count, idsHash)) -> do
|
||||
unless hasSub $ forkClient clnt "deliverServiceMessages" $ liftIO $ deliverServiceMessages count
|
||||
@@ -1806,7 +1807,7 @@ client
|
||||
deliverServiceMessages expectedCnt = do
|
||||
(qCnt, _msgCnt, _dupCnt, _errCnt) <- foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, 0)
|
||||
atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, SALL)]
|
||||
-- TODO [cert rcv] compare with expected
|
||||
-- TODO [certs rcv] compare with expected
|
||||
logNote $ "Service subscriptions for " <> tshow serviceId <> " (" <> tshow qCnt <> " queues)"
|
||||
deliverQueueMsg :: (Int, Int, Int, Int) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int, Int, Int, Int)
|
||||
deliverQueueMsg (!qCnt, !msgCnt, !dupCnt, !errCnt) rId = \case
|
||||
@@ -1831,25 +1832,33 @@ client
|
||||
TM.insert rId sub $ subscriptions clnt
|
||||
pure $ Just sub
|
||||
|
||||
subscribeServiceNotifications :: ServiceId -> M s BrokerMsg
|
||||
subscribeServiceNotifications serviceId =
|
||||
either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount
|
||||
subscribeServiceNotifications :: ServiceId -> (Int64, IdsHash) -> M s BrokerMsg
|
||||
subscribeServiceNotifications serviceId expected =
|
||||
either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId expected ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount ntfServices
|
||||
|
||||
sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> M s (Either ErrorType (Bool, (Int64, IdsHash)))
|
||||
sharedSubscribeService party serviceId srvSubscribers clientServiceSubscribed clientServiceSubs = do
|
||||
sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> (Int64, IdsHash) -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> (ServerStats -> ServiceStats) -> M s (Either ErrorType (Bool, (Int64, IdsHash)))
|
||||
sharedSubscribeService party serviceId (count, idsHash) srvSubscribers clientServiceSubscribed clientServiceSubs servicesSel = do
|
||||
subscribed <- readTVarIO $ clientServiceSubscribed clnt
|
||||
stats <- asks serverStats
|
||||
liftIO $ runExceptT $
|
||||
(subscribed,)
|
||||
<$> if subscribed
|
||||
then (,B.empty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash
|
||||
then (,mempty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash
|
||||
else do
|
||||
count' <- ExceptT $ getServiceQueueCount @(StoreQueue s) (queueStore ms) party serviceId
|
||||
(count', idsHash') <- ExceptT $ getServiceQueueCountHash @(StoreQueue s) (queueStore ms) party serviceId
|
||||
incCount <- atomically $ do
|
||||
writeTVar (clientServiceSubscribed clnt) True
|
||||
count <- swapTVar (clientServiceSubs clnt) count'
|
||||
pure $ count' - count
|
||||
currCount <- swapTVar (clientServiceSubs clnt) count' -- TODO [certs rcv] maintain IDs hash here?
|
||||
pure $ count' - currCount
|
||||
let incSrvStat sel n = liftIO $ atomicModifyIORef'_ (sel $ servicesSel stats) (+ n)
|
||||
diff = fromIntegral $ count' - count
|
||||
if -- TODO [certs rcv] account for not provided counts/hashes (expected n = -1)
|
||||
| diff == 0 && idsHash == idsHash' -> incSrvStat srvSubOk 1
|
||||
| diff > 0 -> incSrvStat srvSubMore 1 >> incSrvStat srvSubMoreTotal diff
|
||||
| diff < 0 -> incSrvStat srvSubFewer 1 >> incSrvStat srvSubFewerTotal (- diff)
|
||||
| otherwise -> incSrvStat srvSubDiff 1
|
||||
atomically $ writeTQueue (subQ srvSubscribers) (CSService serviceId incCount, clientId)
|
||||
pure (count', B.empty) -- TODO [certs rcv] get IDs hash
|
||||
pure (count', idsHash')
|
||||
|
||||
acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
|
||||
acknowledgeMsg msgId q qr =
|
||||
|
||||
@@ -355,8 +355,8 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where
|
||||
{-# INLINE setQueueService #-}
|
||||
getQueueNtfServices = withQS (getQueueNtfServices @(JournalQueue s))
|
||||
{-# INLINE getQueueNtfServices #-}
|
||||
getServiceQueueCount = withQS (getServiceQueueCount @(JournalQueue s))
|
||||
{-# INLINE getServiceQueueCount #-}
|
||||
getServiceQueueCountHash = withQS (getServiceQueueCountHash @(JournalQueue s))
|
||||
{-# INLINE getServiceQueueCountHash #-}
|
||||
|
||||
makeQueue_ :: JournalMsgStore s -> RecipientId -> QueueRec -> Lock -> IO (JournalQueue s)
|
||||
makeQueue_ JournalMsgStore {sharedLock} rId qr queueLock = do
|
||||
|
||||
@@ -21,6 +21,7 @@ import Simplex.Messaging.Transport (simplexMQVersion)
|
||||
import Simplex.Messaging.Transport.Server (SocketStats (..))
|
||||
import Simplex.Messaging.Util (tshow)
|
||||
|
||||
-- TODO [certs rcv] add service subscriptions and count/hash diffs
|
||||
data ServerMetrics = ServerMetrics
|
||||
{ statsData :: ServerStatsData,
|
||||
activeQueueCounts :: PeriodStatCounts,
|
||||
|
||||
@@ -65,6 +65,7 @@ data ServiceRec = ServiceRec
|
||||
serviceCert :: X.CertificateChain,
|
||||
serviceCertHash :: XV.Fingerprint, -- SHA512 hash of long-term service client certificate. See comment for ClientHandshake.
|
||||
serviceCreatedAt :: SystemDate
|
||||
-- entitiesHash :: IdsHash -- a xor-hash of all associated entities
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
|
||||
@@ -524,15 +524,11 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs'
|
||||
in ((serviceId, sNtfs) : ssNtfs, restNtfs)
|
||||
|
||||
getServiceQueueCount :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64)
|
||||
getServiceQueueCount st party serviceId =
|
||||
E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCount" st $ \db ->
|
||||
maybeFirstRow' 0 fromOnly $
|
||||
DB.query db query (Only serviceId)
|
||||
where
|
||||
query = case party of
|
||||
SRecipientService -> "SELECT count(1) FROM msg_queues WHERE rcv_service_id = ? AND deleted_at IS NULL"
|
||||
SNotifierService -> "SELECT count(1) FROM msg_queues WHERE ntf_service_id = ? AND deleted_at IS NULL"
|
||||
getServiceQueueCountHash :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash))
|
||||
getServiceQueueCountHash st party serviceId =
|
||||
E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCountHash" st $ \db ->
|
||||
maybeFirstRow' (0, mempty) id $
|
||||
DB.query db ("SELECT queue_count, queue_ids_hash FROM services WHERE service_id = ? AND service_role = ?") (serviceId, partyServiceRole party)
|
||||
|
||||
batchInsertServices :: [STMService] -> PostgresQueueStore q -> IO Int64
|
||||
batchInsertServices services' toStore =
|
||||
@@ -793,6 +789,10 @@ instance ToField C.APublicAuthKey where toField = toField . Binary . C.encodePub
|
||||
|
||||
instance FromField C.APublicAuthKey where fromField = blobFieldDecoder C.decodePubKey
|
||||
|
||||
instance ToField IdsHash where toField (IdsHash s) = toField (Binary s)
|
||||
|
||||
deriving newtype instance FromField IdsHash
|
||||
|
||||
instance ToField EncDataBytes where toField (EncDataBytes s) = toField (Binary s)
|
||||
|
||||
deriving newtype instance FromField EncDataBytes
|
||||
|
||||
@@ -7,6 +7,7 @@ module Simplex.Messaging.Server.QueueStore.Postgres.Migrations where
|
||||
import Data.List (sortOn)
|
||||
import Data.Text (Text)
|
||||
import Simplex.Messaging.Agent.Store.Shared
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util
|
||||
import Text.RawString.QQ (r)
|
||||
|
||||
serverSchemaMigrations :: [(String, Text, Maybe Text)]
|
||||
@@ -15,7 +16,8 @@ serverSchemaMigrations =
|
||||
("20250319_updated_index", m20250319_updated_index, Just down_m20250319_updated_index),
|
||||
("20250320_short_links", m20250320_short_links, Just down_m20250320_short_links),
|
||||
("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs),
|
||||
("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages)
|
||||
("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages),
|
||||
("20250915_queue_ids_hash", m20250915_queue_ids_hash, Just down_m20250915_queue_ids_hash)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
@@ -447,3 +449,139 @@ ALTER TABLE msg_queues
|
||||
|
||||
DROP TABLE messages;
|
||||
|]
|
||||
|
||||
m20250915_queue_ids_hash :: Text
|
||||
m20250915_queue_ids_hash =
|
||||
createXorHashFuncs
|
||||
<> [r|
|
||||
ALTER TABLE services
|
||||
ADD COLUMN queue_count BIGINT NOT NULL DEFAULT 0,
|
||||
ADD COLUMN queue_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000';
|
||||
|
||||
CREATE FUNCTION update_all_aggregates() RETURNS VOID
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
WITH acc AS (
|
||||
SELECT
|
||||
s.service_id,
|
||||
count(1) as q_count,
|
||||
xor_aggregate(public.digest(CASE WHEN s.service_role = 'M' THEN q.recipient_id ELSE COALESCE(q.notifier_id, '\x00000000000000000000000000000000') END, 'md5')) AS q_ids_hash
|
||||
FROM services s
|
||||
JOIN msg_queues q ON (s.service_id = q.rcv_service_id AND s.service_role = 'M') OR (s.service_id = q.ntf_service_id AND s.service_role = 'N')
|
||||
WHERE q.deleted_at IS NULL
|
||||
GROUP BY s.service_id
|
||||
)
|
||||
UPDATE services s
|
||||
SET queue_count = COALESCE(acc.q_count, 0),
|
||||
queue_ids_hash = COALESCE(acc.q_ids_hash, '\x00000000000000000000000000000000')
|
||||
FROM acc
|
||||
WHERE s.service_id = acc.service_id;
|
||||
END;
|
||||
$$;
|
||||
|
||||
SELECT update_all_aggregates();
|
||||
|
||||
CREATE FUNCTION update_aggregates(p_service_id BYTEA, p_role TEXT, p_queue_id BYTEA, p_change BIGINT) RETURNS VOID
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
UPDATE services
|
||||
SET queue_count = queue_count + p_change,
|
||||
queue_ids_hash = xor_combine(queue_ids_hash, public.digest(p_queue_id, 'md5'))
|
||||
WHERE service_id = p_service_id AND service_role = p_role;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_queue_insert() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF NEW.rcv_service_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
|
||||
END IF;
|
||||
IF NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_queue_delete() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.deleted_at IS NULL THEN
|
||||
IF OLD.rcv_service_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
|
||||
END IF;
|
||||
IF OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
|
||||
END IF;
|
||||
END IF;
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION on_queue_update() RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.deleted_at IS NULL AND OLD.rcv_service_id IS NOT NULL THEN
|
||||
IF NOT (NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL) THEN
|
||||
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
|
||||
ELSIF OLD.rcv_service_id IS DISTINCT FROM NEW.rcv_service_id THEN
|
||||
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
|
||||
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
|
||||
END IF;
|
||||
ELSIF NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
|
||||
END IF;
|
||||
|
||||
IF OLD.deleted_at IS NULL AND OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
|
||||
IF NOT (NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL) THEN
|
||||
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
|
||||
ELSIF OLD.ntf_service_id IS DISTINCT FROM NEW.ntf_service_id OR OLD.notifier_id IS DISTINCT FROM NEW.notifier_id THEN
|
||||
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
|
||||
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
|
||||
END IF;
|
||||
ELSIF NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE TRIGGER tr_queue_insert
|
||||
AFTER INSERT ON msg_queues
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_queue_insert();
|
||||
|
||||
CREATE TRIGGER tr_queue_delete
|
||||
AFTER DELETE ON msg_queues
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_queue_delete();
|
||||
|
||||
CREATE TRIGGER tr_queue_update
|
||||
AFTER UPDATE ON msg_queues
|
||||
FOR EACH ROW EXECUTE PROCEDURE on_queue_update();
|
||||
|]
|
||||
|
||||
down_m20250915_queue_ids_hash :: Text
|
||||
down_m20250915_queue_ids_hash =
|
||||
[r|
|
||||
DROP TRIGGER tr_queue_insert ON msg_queues;
|
||||
DROP TRIGGER tr_queue_delete ON msg_queues;
|
||||
DROP TRIGGER tr_queue_update ON msg_queues;
|
||||
|
||||
DROP FUNCTION on_queue_insert;
|
||||
DROP FUNCTION on_queue_delete;
|
||||
DROP FUNCTION on_queue_update;
|
||||
|
||||
DROP FUNCTION update_aggregates;
|
||||
|
||||
DROP FUNCTION update_all_aggregates;
|
||||
|
||||
ALTER TABLE services
|
||||
DROP COLUMN queue_count,
|
||||
DROP COLUMN queue_ids_hash;
|
||||
|]
|
||||
<> dropXorHashFuncs
|
||||
|
||||
@@ -104,6 +104,71 @@ $$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.on_queue_delete() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.deleted_at IS NULL THEN
|
||||
IF OLD.rcv_service_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
|
||||
END IF;
|
||||
IF OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
|
||||
END IF;
|
||||
END IF;
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.on_queue_insert() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF NEW.rcv_service_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
|
||||
END IF;
|
||||
IF NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.on_queue_update() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF OLD.deleted_at IS NULL AND OLD.rcv_service_id IS NOT NULL THEN
|
||||
IF NOT (NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL) THEN
|
||||
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
|
||||
ELSIF OLD.rcv_service_id IS DISTINCT FROM NEW.rcv_service_id THEN
|
||||
PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1);
|
||||
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
|
||||
END IF;
|
||||
ELSIF NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1);
|
||||
END IF;
|
||||
|
||||
IF OLD.deleted_at IS NULL AND OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN
|
||||
IF NOT (NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL) THEN
|
||||
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
|
||||
ELSIF OLD.ntf_service_id IS DISTINCT FROM NEW.ntf_service_id OR OLD.notifier_id IS DISTINCT FROM NEW.notifier_id THEN
|
||||
PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1);
|
||||
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
|
||||
END IF;
|
||||
ELSIF NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN
|
||||
PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1);
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.try_del_msg(p_recipient_id bytea, p_msg_id bytea) RETURNS TABLE(r_msg_id bytea, r_msg_ts bigint, r_msg_quota boolean, r_msg_ntf_flag boolean, r_msg_body bytea)
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
@@ -225,6 +290,43 @@ $$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.update_aggregates(p_service_id bytea, p_role text, p_queue_id bytea, p_change bigint) RETURNS void
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
UPDATE services
|
||||
SET queue_count = queue_count + p_change,
|
||||
queue_ids_hash = xor_combine(queue_ids_hash, public.digest(p_queue_id, 'md5'))
|
||||
WHERE service_id = p_service_id AND service_role = p_role;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.update_all_aggregates() RETURNS void
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
WITH acc AS (
|
||||
SELECT
|
||||
s.service_id,
|
||||
count(1) as q_count,
|
||||
xor_aggregate(public.digest(CASE WHEN s.service_role = 'M' THEN q.recipient_id ELSE COALESCE(q.notifier_id, '\x00000000000000000000000000000000') END, 'md5')) AS q_ids_hash
|
||||
FROM services s
|
||||
JOIN msg_queues q ON (s.service_id = q.rcv_service_id AND s.service_role = 'M') OR (s.service_id = q.ntf_service_id AND s.service_role = 'N')
|
||||
WHERE q.deleted_at IS NULL
|
||||
GROUP BY s.service_id
|
||||
)
|
||||
UPDATE services s
|
||||
SET queue_count = COALESCE(acc.q_count, 0),
|
||||
queue_ids_hash = COALESCE(acc.q_ids_hash, '\x00000000000000000000000000000000')
|
||||
FROM acc
|
||||
WHERE s.service_id = acc.service_id;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.write_message(p_recipient_id bytea, p_msg_id bytea, p_msg_ts bigint, p_msg_quota boolean, p_msg_ntf_flag boolean, p_msg_body bytea, p_quota integer) RETURNS TABLE(quota_written boolean, was_empty boolean)
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
@@ -256,6 +358,34 @@ END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.xor_combine(state bytea, value bytea) RETURNS bytea
|
||||
LANGUAGE plpgsql IMMUTABLE STRICT
|
||||
AS $$
|
||||
DECLARE
|
||||
result BYTEA := state;
|
||||
i INTEGER;
|
||||
len INTEGER := octet_length(value);
|
||||
BEGIN
|
||||
IF octet_length(state) != len THEN
|
||||
RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len;
|
||||
END IF;
|
||||
FOR i IN 0..len-1 LOOP
|
||||
result := set_byte(result, i, get_byte(state, i) # get_byte(value, i));
|
||||
END LOOP;
|
||||
RETURN result;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE AGGREGATE smp_server.xor_aggregate(bytea) (
|
||||
SFUNC = smp_server.xor_combine,
|
||||
STYPE = bytea,
|
||||
INITCOND = '\x00000000000000000000000000000000'
|
||||
);
|
||||
|
||||
|
||||
SET default_table_access_method = heap;
|
||||
|
||||
|
||||
@@ -320,7 +450,9 @@ CREATE TABLE smp_server.services (
|
||||
service_role text NOT NULL,
|
||||
service_cert bytea NOT NULL,
|
||||
service_cert_hash bytea NOT NULL,
|
||||
created_at bigint NOT NULL
|
||||
created_at bigint NOT NULL,
|
||||
queue_count bigint DEFAULT 0 NOT NULL,
|
||||
queue_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL
|
||||
);
|
||||
|
||||
|
||||
@@ -390,6 +522,18 @@ CREATE INDEX idx_services_service_role ON smp_server.services USING btree (servi
|
||||
|
||||
|
||||
|
||||
CREATE TRIGGER tr_queue_delete AFTER DELETE ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_delete();
|
||||
|
||||
|
||||
|
||||
CREATE TRIGGER tr_queue_insert AFTER INSERT ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_insert();
|
||||
|
||||
|
||||
|
||||
CREATE TRIGGER tr_queue_update AFTER UPDATE ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_update();
|
||||
|
||||
|
||||
|
||||
ALTER TABLE ONLY smp_server.messages
|
||||
ADD CONSTRAINT messages_recipient_id_fkey FOREIGN KEY (recipient_id) REFERENCES smp_server.msg_queues(recipient_id) ON UPDATE RESTRICT ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ where
|
||||
import qualified Control.Exception as E
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bitraversable (bimapM)
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
@@ -62,8 +63,8 @@ data STMQueueStore q = STMQueueStore
|
||||
|
||||
data STMService = STMService
|
||||
{ serviceRec :: ServiceRec,
|
||||
serviceRcvQueues :: TVar (Set RecipientId),
|
||||
serviceNtfQueues :: TVar (Set NotifierId)
|
||||
serviceRcvQueues :: TVar (Set RecipientId, IdsHash), -- TODO [certs rcv] get/maintain hash
|
||||
serviceNtfQueues :: TVar (Set NotifierId, IdsHash) -- TODO [certs rcv] get/maintain hash
|
||||
}
|
||||
|
||||
setStoreLog :: STMQueueStore q -> StoreLog 'WriteMode -> IO ()
|
||||
@@ -113,7 +114,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
}
|
||||
where
|
||||
serviceCount role = M.foldl' (\ !n s -> if serviceRole (serviceRec s) == role then n + 1 else n) 0
|
||||
serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size <$> readTVarIO (serviceSel s)) 0
|
||||
serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size . fst <$> readTVarIO (serviceSel s)) 0
|
||||
|
||||
addQueue_ :: STMQueueStore q -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q)
|
||||
addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData, rcvServiceId} = do
|
||||
@@ -304,8 +305,8 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
TM.insert fp newSrvId serviceCerts
|
||||
pure $ Right (newSrvId, True)
|
||||
newSTMService = do
|
||||
serviceRcvQueues <- newTVar S.empty
|
||||
serviceNtfQueues <- newTVar S.empty
|
||||
serviceRcvQueues <- newTVar (S.empty, mempty)
|
||||
serviceNtfQueues <- newTVar (S.empty, mempty)
|
||||
pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues}
|
||||
|
||||
setQueueService :: (PartyI p, ServiceParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
@@ -331,7 +332,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
let !q' = Just q {notifier = Just nc {ntfServiceId = serviceId}}
|
||||
updateServiceQueues serviceNtfQueues nId prevNtfSrvId
|
||||
writeTVar qr q' $> Right ()
|
||||
updateServiceQueues :: (STMService -> TVar (Set QueueId)) -> QueueId -> Maybe ServiceId -> STM ()
|
||||
updateServiceQueues :: (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> Maybe ServiceId -> STM ()
|
||||
updateServiceQueues serviceSel qId prevSrvId = do
|
||||
mapM_ (removeServiceQueue st serviceSel qId) prevSrvId
|
||||
mapM_ (addServiceQueue st serviceSel qId) serviceId
|
||||
@@ -346,16 +347,16 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
pure $ Right (ssNtfs', deleteNtfs)
|
||||
where
|
||||
addService (ssNtfs, ntfs') (serviceId, s) = do
|
||||
snIds <- readTVarIO $ serviceNtfQueues s
|
||||
(snIds, _) <- readTVarIO $ serviceNtfQueues s
|
||||
let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs'
|
||||
pure ((Just serviceId, sNtfs) : ssNtfs, restNtfs)
|
||||
|
||||
getServiceQueueCount :: (PartyI p, ServiceParty p) => STMQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64)
|
||||
getServiceQueueCount st party serviceId =
|
||||
getServiceQueueCountHash :: (PartyI p, ServiceParty p) => STMQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash))
|
||||
getServiceQueueCountHash st party serviceId =
|
||||
TM.lookupIO serviceId (services st) >>=
|
||||
maybe (pure $ Left AUTH) (fmap (Right . fromIntegral . S.size) . readTVarIO . serviceSel)
|
||||
maybe (pure $ Left AUTH) (fmap (Right . first (fromIntegral . S.size)) . readTVarIO . serviceSel)
|
||||
where
|
||||
serviceSel :: STMService -> TVar (Set QueueId)
|
||||
serviceSel :: STMService -> TVar (Set QueueId, IdsHash)
|
||||
serviceSel = case party of
|
||||
SRecipientService -> serviceRcvQueues
|
||||
SNotifierService -> serviceNtfQueues
|
||||
@@ -366,7 +367,7 @@ foldRcvServiceQueues st serviceId f acc =
|
||||
Nothing -> pure acc
|
||||
Just s ->
|
||||
readTVarIO (serviceRcvQueues s)
|
||||
>>= foldM (\a -> get >=> maybe (pure a) (f a)) acc
|
||||
>>= foldM (\a -> get >=> maybe (pure a) (f a)) acc . fst
|
||||
where
|
||||
get rId = TM.lookupIO rId (queues st) $>>= \q -> (q,) <$$> readTVarIO (queueRec q)
|
||||
|
||||
@@ -379,16 +380,23 @@ setStatus qr status =
|
||||
Just q -> (Right (), Just q {status})
|
||||
Nothing -> (Left AUTH, Nothing)
|
||||
|
||||
addServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM ()
|
||||
addServiceQueue st serviceSel qId serviceId =
|
||||
TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.insert qId))
|
||||
addServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM ()
|
||||
addServiceQueue = setServiceQueues_ S.insert
|
||||
{-# INLINE addServiceQueue #-}
|
||||
|
||||
removeServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM ()
|
||||
removeServiceQueue st serviceSel qId serviceId =
|
||||
TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.delete qId))
|
||||
removeServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM ()
|
||||
removeServiceQueue = setServiceQueues_ S.delete
|
||||
{-# INLINE removeServiceQueue #-}
|
||||
|
||||
setServiceQueues_ :: (QueueId -> Set QueueId -> Set QueueId) -> STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM ()
|
||||
setServiceQueues_ updateSet st serviceSel qId serviceId =
|
||||
TM.lookup serviceId (services st) >>= mapM_ (\v -> modifyTVar' (serviceSel v) update)
|
||||
where
|
||||
update (s, idsHash) =
|
||||
let !s' = updateSet qId s
|
||||
!idsHash' = queueIdHash qId <> idsHash
|
||||
in (s', idsHash')
|
||||
|
||||
removeNotifier :: STMQueueStore q -> NtfCreds -> STM ()
|
||||
removeNotifier st NtfCreds {notifierId = nId, ntfServiceId} = do
|
||||
TM.delete nId $ notifiers st
|
||||
|
||||
@@ -47,7 +47,7 @@ class StoreQueueClass q => QueueStoreClass q s where
|
||||
getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId)
|
||||
setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]))
|
||||
getServiceQueueCount :: (PartyI p, ServiceParty p) => s -> SParty p -> ServiceId -> IO (Either ErrorType Int64)
|
||||
getServiceQueueCountHash :: (PartyI p, ServiceParty p) => s -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash))
|
||||
|
||||
data EntityCounts = EntityCounts
|
||||
{ queueCount :: Int,
|
||||
|
||||
@@ -821,7 +821,15 @@ data ServiceStats = ServiceStats
|
||||
srvSubCount :: IORef Int,
|
||||
srvSubDuplicate :: IORef Int,
|
||||
srvSubQueues :: IORef Int,
|
||||
srvSubEnd :: IORef Int
|
||||
srvSubEnd :: IORef Int,
|
||||
-- counts of subscriptions
|
||||
srvSubOk :: IORef Int, -- server has the same queues as expected
|
||||
srvSubMore :: IORef Int, -- server has more queues than expected
|
||||
srvSubFewer :: IORef Int, -- server has fewer queues than expected
|
||||
srvSubDiff :: IORef Int, -- server has the same count, but different queues than expected (based on xor-hash)
|
||||
-- adds actual deviations
|
||||
srvSubMoreTotal :: IORef Int, -- server has more queues than expected, adds diff
|
||||
srvSubFewerTotal :: IORef Int
|
||||
}
|
||||
|
||||
data ServiceStatsData = ServiceStatsData
|
||||
@@ -832,7 +840,13 @@ data ServiceStatsData = ServiceStatsData
|
||||
_srvSubCount :: Int,
|
||||
_srvSubDuplicate :: Int,
|
||||
_srvSubQueues :: Int,
|
||||
_srvSubEnd :: Int
|
||||
_srvSubEnd :: Int,
|
||||
_srvSubOk :: Int,
|
||||
_srvSubMore :: Int,
|
||||
_srvSubFewer :: Int,
|
||||
_srvSubDiff :: Int,
|
||||
_srvSubMoreTotal :: Int,
|
||||
_srvSubFewerTotal :: Int
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
@@ -846,7 +860,13 @@ newServiceStatsData =
|
||||
_srvSubCount = 0,
|
||||
_srvSubDuplicate = 0,
|
||||
_srvSubQueues = 0,
|
||||
_srvSubEnd = 0
|
||||
_srvSubEnd = 0,
|
||||
_srvSubOk = 0,
|
||||
_srvSubMore = 0,
|
||||
_srvSubFewer = 0,
|
||||
_srvSubDiff = 0,
|
||||
_srvSubMoreTotal = 0,
|
||||
_srvSubFewerTotal = 0
|
||||
}
|
||||
|
||||
newServiceStats :: IO ServiceStats
|
||||
@@ -859,6 +879,12 @@ newServiceStats = do
|
||||
srvSubDuplicate <- newIORef 0
|
||||
srvSubQueues <- newIORef 0
|
||||
srvSubEnd <- newIORef 0
|
||||
srvSubOk <- newIORef 0
|
||||
srvSubMore <- newIORef 0
|
||||
srvSubFewer <- newIORef 0
|
||||
srvSubDiff <- newIORef 0
|
||||
srvSubMoreTotal <- newIORef 0
|
||||
srvSubFewerTotal <- newIORef 0
|
||||
pure
|
||||
ServiceStats
|
||||
{ srvAssocNew,
|
||||
@@ -868,7 +894,13 @@ newServiceStats = do
|
||||
srvSubCount,
|
||||
srvSubDuplicate,
|
||||
srvSubQueues,
|
||||
srvSubEnd
|
||||
srvSubEnd,
|
||||
srvSubOk,
|
||||
srvSubMore,
|
||||
srvSubFewer,
|
||||
srvSubDiff,
|
||||
srvSubMoreTotal,
|
||||
srvSubFewerTotal
|
||||
}
|
||||
|
||||
getServiceStatsData :: ServiceStats -> IO ServiceStatsData
|
||||
@@ -881,6 +913,12 @@ getServiceStatsData s = do
|
||||
_srvSubDuplicate <- readIORef $ srvSubDuplicate s
|
||||
_srvSubQueues <- readIORef $ srvSubQueues s
|
||||
_srvSubEnd <- readIORef $ srvSubEnd s
|
||||
_srvSubOk <- readIORef $ srvSubOk s
|
||||
_srvSubMore <- readIORef $ srvSubMore s
|
||||
_srvSubFewer <- readIORef $ srvSubFewer s
|
||||
_srvSubDiff <- readIORef $ srvSubDiff s
|
||||
_srvSubMoreTotal <- readIORef $ srvSubMoreTotal s
|
||||
_srvSubFewerTotal <- readIORef $ srvSubFewerTotal s
|
||||
pure
|
||||
ServiceStatsData
|
||||
{ _srvAssocNew,
|
||||
@@ -890,7 +928,13 @@ getServiceStatsData s = do
|
||||
_srvSubCount,
|
||||
_srvSubDuplicate,
|
||||
_srvSubQueues,
|
||||
_srvSubEnd
|
||||
_srvSubEnd,
|
||||
_srvSubOk,
|
||||
_srvSubMore,
|
||||
_srvSubFewer,
|
||||
_srvSubDiff,
|
||||
_srvSubMoreTotal,
|
||||
_srvSubFewerTotal
|
||||
}
|
||||
|
||||
getResetServiceStatsData :: ServiceStats -> IO ServiceStatsData
|
||||
@@ -903,6 +947,12 @@ getResetServiceStatsData s = do
|
||||
_srvSubDuplicate <- atomicSwapIORef (srvSubDuplicate s) 0
|
||||
_srvSubQueues <- atomicSwapIORef (srvSubQueues s) 0
|
||||
_srvSubEnd <- atomicSwapIORef (srvSubEnd s) 0
|
||||
_srvSubOk <- atomicSwapIORef (srvSubOk s) 0
|
||||
_srvSubMore <- atomicSwapIORef (srvSubMore s) 0
|
||||
_srvSubFewer <- atomicSwapIORef (srvSubFewer s) 0
|
||||
_srvSubDiff <- atomicSwapIORef (srvSubDiff s) 0
|
||||
_srvSubMoreTotal <- atomicSwapIORef (srvSubMoreTotal s) 0
|
||||
_srvSubFewerTotal <- atomicSwapIORef (srvSubFewerTotal s) 0
|
||||
pure
|
||||
ServiceStatsData
|
||||
{ _srvAssocNew,
|
||||
@@ -912,7 +962,13 @@ getResetServiceStatsData s = do
|
||||
_srvSubCount,
|
||||
_srvSubDuplicate,
|
||||
_srvSubQueues,
|
||||
_srvSubEnd
|
||||
_srvSubEnd,
|
||||
_srvSubOk,
|
||||
_srvSubMore,
|
||||
_srvSubFewer,
|
||||
_srvSubDiff,
|
||||
_srvSubMoreTotal,
|
||||
_srvSubFewerTotal
|
||||
}
|
||||
|
||||
-- this function is not thread safe, it is used on server start only
|
||||
@@ -926,6 +982,12 @@ setServiceStats s d = do
|
||||
writeIORef (srvSubDuplicate s) $! _srvSubDuplicate d
|
||||
writeIORef (srvSubQueues s) $! _srvSubQueues d
|
||||
writeIORef (srvSubEnd s) $! _srvSubEnd d
|
||||
writeIORef (srvSubOk s) $! _srvSubOk d
|
||||
writeIORef (srvSubMore s) $! _srvSubMore d
|
||||
writeIORef (srvSubFewer s) $! _srvSubFewer d
|
||||
writeIORef (srvSubDiff s) $! _srvSubDiff d
|
||||
writeIORef (srvSubMoreTotal s) $! _srvSubMoreTotal d
|
||||
writeIORef (srvSubFewerTotal s) $! _srvSubFewerTotal d
|
||||
|
||||
instance StrEncoding ServiceStatsData where
|
||||
strEncode ServiceStatsData {_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd} =
|
||||
@@ -963,7 +1025,13 @@ instance StrEncoding ServiceStatsData where
|
||||
_srvSubCount,
|
||||
_srvSubDuplicate,
|
||||
_srvSubQueues,
|
||||
_srvSubEnd
|
||||
_srvSubEnd,
|
||||
_srvSubOk = 0,
|
||||
_srvSubMore = 0,
|
||||
_srvSubFewer = 0,
|
||||
_srvSubDiff = 0,
|
||||
_srvSubMoreTotal = 0,
|
||||
_srvSubFewerTotal = 0
|
||||
}
|
||||
|
||||
data TimeBuckets = TimeBuckets
|
||||
|
||||
@@ -61,7 +61,7 @@ readQueueStore tty mkQ f st = readLogLines tty f $ \_ -> processLine
|
||||
Left e -> logError $ errPfx <> tshow e
|
||||
where
|
||||
errPfx = "STORE: getCreateService, stored service " <> decodeLatin1 (strEncode serviceId) <> ", "
|
||||
QueueService rId (ASP party) serviceId -> withQueue rId "QueueService" $ \q -> setQueueService st q party serviceId
|
||||
QueueService qId (ASP party) serviceId -> withQueue qId "QueueService" $ \q -> setQueueService st q party serviceId
|
||||
printError :: String -> IO ()
|
||||
printError e = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s
|
||||
withQueue :: forall a. RecipientId -> T.Text -> (q -> IO (Either ErrorType a)) -> IO ()
|
||||
|
||||
Reference in New Issue
Block a user