Files
simplexmq/src/Simplex/Messaging/Server.hs
Evgeny 07eaf9157b smp server: allow getting and deleting short links for the old contact queues (#1549)
* smp server: allow getting and deleting short links for the old contact queues

* fix verification of legacy contact queues

* test
2025-05-25 17:03:02 +01:00

2026 lines
106 KiB
Haskell

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module : Simplex.Messaging.Server
-- Copyright : (c) simplex.chat
-- License : AGPL-3
--
-- Maintainer : chat@simplex.chat
-- Stability : experimental
-- Portability : non-portable
--
-- This module defines the SMP protocol server with in-memory persistence
-- and an optional append-only log of SMP queue records.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
module Simplex.Messaging.Server
( runSMPServer,
runSMPServerBlocking,
importMessages,
exportMessages,
printMessageStats,
disconnectTransport,
verifyCmdAuthorization,
dummyVerifyCmd,
randomId,
AttachHTTP,
MessageStats (..),
)
where
import Control.Concurrent.STM (throwSTM)
import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Control.Monad.Trans.Except
import Control.Monad.STM (retry)
import Data.Bifunctor (first)
import Data.ByteString.Base64 (encode)
import qualified Data.ByteString.Builder as BLD
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Dynamic (toDyn)
import Data.Either (fromRight, partitionEithers)
import Data.Functor (($>))
import Data.IORef
import Data.Int (Int64)
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import Data.List (foldl', intercalate, mapAccumR)
import Data.List.NonEmpty (NonEmpty (..), (<|))
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing)
import Data.Semigroup (Sum (..))
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1)
import qualified Data.Text.IO as T
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
import Data.Time.Format.ISO8601 (iso8601Show)
import Data.Type.Equality
import Data.Typeable (cast)
import qualified Data.X509 as X
import GHC.Conc.Signal
import GHC.IORef (atomicSwapIORef)
import GHC.Stats (getRTSStats)
import GHC.TypeLits (KnownNat)
import Network.Socket (ServiceName, Socket, socketToHandle)
import qualified Network.TLS as TLS
import Numeric.Natural (Natural)
import Simplex.Messaging.Agent.Lock
import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPTransmission, nonBlockingWriteTBQueue, smpProxyError, temporaryClientError)
import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.Control
import Simplex.Messaging.Server.Env.STM as Env
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.Server.MsgStore.Journal (JournalMsgStore, JournalQueue)
import Simplex.Messaging.Server.MsgStore.STM
import Simplex.Messaging.Server.MsgStore.Types
import Simplex.Messaging.Server.NtfStore
import Simplex.Messaging.Server.Prometheus
import Simplex.Messaging.Server.QueueStore
import Simplex.Messaging.Server.QueueStore.QueueInfo
import Simplex.Messaging.Server.QueueStore.Types
import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Server.StoreLog (foldLogLines)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Buffer (trimCR)
import Simplex.Messaging.Transport.Server
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import System.Environment (lookupEnv)
import System.Exit (exitFailure, exitSuccess)
import System.IO (hPrint, hPutStrLn, hSetNewlineMode, universalNewlineMode)
import System.Mem.Weak (deRefWeak)
import UnliftIO (timeout)
import UnliftIO.Concurrent
import UnliftIO.Directory (doesFileExist, renameFile)
import UnliftIO.Exception
import UnliftIO.IO
import UnliftIO.STM
#if MIN_VERSION_base(4,18,0)
import Data.List (sort)
import GHC.Conc (listThreads, threadStatus)
import GHC.Conc.Sync (threadLabel)
#endif
-- | Starts an SMP server with the given configuration and an optional handler
-- for TLS connections that negotiate an HTTP ALPN.
--
-- This is 'runSMPServerBlocking' with an internally created start signal that
-- nobody observes; call 'runSMPServerBlocking' directly to learn when the
-- server is ready to accept connections.
--
-- See a full server here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-server/Main.hs
runSMPServer :: ServerConfig -> Maybe AttachHTTP -> IO ()
runSMPServer cfg attachHTTP_ =
  newEmptyTMVarIO >>= \startSignal -> runSMPServerBlocking startSignal cfg attachHTTP_
-- | Starts an SMP server and reports its listening state through @started@.
--
-- @True@ is signalled once the server is ready to accept TCP requests, and
-- @False@ once it is disconnected from the TCP socket after the server thread is killed.
runSMPServerBlocking :: TMVar Bool -> ServerConfig -> Maybe AttachHTTP -> IO ()
runSMPServerBlocking started cfg attachHTTP_ = do
  env <- newEnv cfg
  runReaderT (smpServer started cfg attachHTTP_) env
-- | Server monad: read access to the server 'Env' on top of 'IO'.
type M a = ReaderT Env IO a
-- | Handler for a TLS connection that negotiated one of the HTTP ALPNs
-- ("h2", "http/1.1"); it receives the accepted socket and the TLS context.
type AttachHTTP = Socket -> TLS.Context -> IO ()
smpServer :: TMVar Bool -> ServerConfig -> Maybe AttachHTTP -> M ()
smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOptions} attachHTTP_ = do
s <- asks server
pa <- asks proxyAgent
msgStats_ <- processServerMessages startOptions
ntfStats <- restoreServerNtfs
liftIO $ mapM_ (printMessageStats "messages") msgStats_
liftIO $ printMessageStats "notifications" ntfStats
restoreServerStats msgStats_ ntfStats
when (maintenance startOptions) $ do
liftIO $ putStrLn "Server started in 'maintenance' mode, exiting"
stopServer s
liftIO $ exitSuccess
raceAny_
( serverThread "server subscribers" s subscribers subscriptions cancelSub
: serverThread "server ntfSubscribers" s ntfSubscribers ntfSubscriptions (\_ -> pure ())
: deliverNtfsThread s
: sendPendingEvtsThread s
: receiveFromProxyAgent pa
: expireNtfsThread cfg
: sigIntHandlerThread
: map runServer transports
<> expireMessagesThread_ cfg
<> serverStatsThread_ cfg
<> prometheusMetricsThread_ cfg
<> controlPortThread_ cfg
)
`finally` stopServer s
where
-- Runs the TLS listener for one configured (port, transport, addHTTP) entry.
-- The port's socket state is registered in 'sockets' (read by the control port and metrics).
-- When HTTP credentials and an 'AttachHTTP' handler are both available and the entry has
-- addHTTP set, the listener also advertises the HTTP ALPNs: connections that negotiate one of
-- them go to the HTTP handler, every other connection runs the SMP protocol via 'runClient'.
runServer :: (ServiceName, ASrvTransport, AddHTTP) -> M ()
runServer (tcpPort, ATransport t, addHTTP) = do
  smpCreds@(srvCert, srvKey) <- asks tlsServerCreds
  httpCreds_ <- asks httpServerCreds
  ss <- liftIO newSocketState
  asks sockets >>= atomically . (`modifyTVar'` ((tcpPort, ss) :))
  -- the TLS private key is also used as the signing key for the SMP handshake
  srvSignKey <- either fail pure $ fromTLSPrivKey srvKey
  env <- ask
  liftIO $ case (httpCreds_, attachHTTP_) of
    (Just httpCreds, Just attachHTTP) | addHTTP ->
      runTransportServerState_ ss started tcpPort defaultSupportedParamsHTTPS chooseCreds (Just combinedALPNs) tCfg $ \s h ->
        -- only a TLS connection with a negotiated HTTP ALPN is handed to the HTTP handler
        case cast h of
          Just (TLS {tlsContext} :: TLS 'TServer) | maybe False (`elem` httpALPN) (getSessionALPN h) -> labelMyThread "https client" >> attachHTTP s tlsContext
          _ -> runClient srvCert srvSignKey t h `runReaderT` env
      where
        -- no host name -> SMP credentials, any host name -> HTTP credentials
        -- NOTE(review): the argument is presumably the client-requested (SNI) host; confirm in Transport.Server
        chooseCreds = maybe smpCreds (\_host -> httpCreds)
        combinedALPNs = supportedSMPHandshakes <> httpALPN
        httpALPN :: [ALPN]
        httpALPN = ["h2", "http/1.1"]
    -- SMP-only listener
    _ ->
      runTransportServerState ss started tcpPort defaultSupportedParams smpCreds (Just supportedSMPHandshakes) tCfg $ \h -> runClient srvCert srvSignKey t h `runReaderT` env
-- Converts the server's TLS private key into the signing key used for the SMP handshake;
-- a conversion failure is returned as Left (runServer turns it into `fail`).
fromTLSPrivKey pk = C.privKey =<< C.x509ToPrivate (pk, [])
-- Installs a SIGINT handler and returns only after SIGINT has been received.
-- Because this thread is one of the threads passed to raceAny_ in smpServer,
-- its return presumably ends the race, after which stopServer runs via `finally`.
sigIntHandlerThread :: M ()
sigIntHandlerThread = do
  gotSigInt <- newEmptyTMVarIO
  let sigINT = 2 -- CONST_SIGINT value
      onSigInt _ptr = atomically . void $ tryPutTMVar gotSigInt ()
  void . liftIO $ setHandler sigINT (Just (onSigInt, toDyn ()))
  atomically $ readTMVar gotSigInt
  logNote "Received SIGINT, stopping server..."
-- Marks the server as inactive, then, holding the server's saving lock, saves the
-- server state (draining messages) and closes the proxy's SMP client agent.
stopServer :: Server -> M ()
stopServer s = do
  activeVar <- asks serverActive
  atomically $ writeTVar activeVar False
  logNote "Saving server state..."
  withLock' (savingLock s) "final" $ do
    saveServer True
    closeServer
  logNote "Server stopped"
-- Saves messages (drainMsgs is passed through to saveServerMessages) and closes the
-- message store, then saves notifications and server statistics.
saveServer :: Bool -> M ()
saveServer drainMsgs = do
  ams@(AMS _ _ ms) <- asks msgStore
  liftIO $ do
    saveServerMessages drainMsgs ams
    closeMsgStore ms
  saveServerNtfs
  saveServerStats
-- Closes the SMP client agent held by the server's proxy agent.
closeServer :: M ()
closeServer = do
  agent <- asks (smpAgent . proxyAgent)
  liftIO $ closeSMPClientAgent agent
-- | Applies subscription changes for one set of subscribers (SMP or notification).
--
-- Each item read from @subQ@ is (queue ID, client ID, subscribed flag). The change is recorded
-- in the server-wide subscription state in one STM transaction. If a different, still connected
-- client was recorded for the same queue, an event is queued for it (END when subscribing,
-- DELD otherwise), the queue is removed from that client's own subscriptions map, and the
-- removed value (if any) is passed to @unsub@.
serverThread ::
  forall s.
  String ->
  Server ->
  (Server -> ServerSubscribers) ->
  (forall st. Client st -> TMap QueueId s) ->
  (s -> IO ()) ->
  M ()
serverThread label srv srvSubscribers clientSubs unsub = do
  labelMyThread label
  liftIO . forever $ do
    -- Reading clients outside of `updateSubscribers` transaction to avoid transaction re-evaluation on each new connected client.
    -- In case client disconnects during the transaction (its `connected` property is read),
    -- the transaction will still be re-evaluated, and the client won't be stored as subscribed.
    sub@(_, clntId, _) <- atomically $ readTQueue subQ
    c_ <- getServerClient clntId srv
    atomically (updateSubscribers c_ sub)
      $>>= endPreviousSubscriptions
      >>= mapM_ unsub
  where
    ServerSubscribers {subQ, queueSubscribers, subClients, pendingEvents} = srvSubscribers srv
    -- Updates queueSubscribers/subClients; returns the previously recorded client together with
    -- the event it must receive, or Nothing when no other client has to be notified.
    updateSubscribers :: Maybe AClient -> (QueueId, ClientId, Subscribed) -> STM (Maybe ((QueueId, BrokerMsg), AClient))
    updateSubscribers c_ (qId, clntId, subscribed) = updateSub $>>= clientToBeNotified
      where
        updateSub = case c_ of
          Just c@(AClient _ _ Client {connected}) -> ifM (readTVar connected) (updateSubConnected c) updateSubDisconnected
          Nothing -> updateSubDisconnected
        updateSubConnected c
          | subscribed = do
              modifyTVar' subClients $ IS.insert clntId -- add client to server's subscribed clients
              upsertSubscribedClient qId c queueSubscribers
          | otherwise = do
              removeWhenNoSubs c
              lookupDeleteSubscribedClient qId queueSubscribers
        -- do not insert client if it is already disconnected, but send END to any other client
        updateSubDisconnected = lookupDeleteSubscribedClient qId queueSubscribers
        -- the client that made the change is never notified; another client only while it is connected
        clientToBeNotified ac@(AClient _ _ Client {clientId, connected})
          | clntId == clientId = pure Nothing
          | otherwise = (\yes -> if yes then Just ((qId, subEvt), ac) else Nothing) <$> readTVar connected
          where
            subEvt = if subscribed then END else DELD
    -- Adds the event to the client's pending events (delivered by sendPendingEvtsThread),
    -- then removes the queue from the client's subscriptions, returning the removed value.
    endPreviousSubscriptions :: ((QueueId, BrokerMsg), AClient) -> IO (Maybe s)
    endPreviousSubscriptions (evt@(qId, _), ac@(AClient _ _ c)) = do
      atomically $ modifyTVar' pendingEvents $ IM.alter (Just . maybe [evt] (evt <|)) (clientId c)
      atomically $ do
        sub <- TM.lookupDelete qId (clientSubs c)
        removeWhenNoSubs ac $> sub
    -- remove client from server's subscribed clients
    removeWhenNoSubs (AClient _ _ c) = whenM (null <$> readTVar (clientSubs c)) $ modifyTVar' subClients $ IS.delete (clientId c)
-- | Periodically delivers pending message notifications (NMSG) to notification subscribers.
--
-- After each 'ntfDeliveryInterval' delay (passed to 'threadDelay'), the thread visits every
-- client ID in @subClients@ of 'ntfSubscribers'. For a client that is still connected and
-- registered, the pending notifications of the queues it subscribed to are removed from the
-- notification store and written to its @sndQ@ as a single batch.
deliverNtfsThread :: Server -> M ()
deliverNtfsThread srv@Server {ntfSubscribers} = do
  ntfInt <- asks $ ntfDeliveryInterval . config
  NtfStore ns <- asks ntfStore
  stats <- asks serverStats
  liftIO $ forever $ do
    threadDelay ntfInt
    cIds <- IS.toList <$> readTVarIO (subClients ntfSubscribers)
    forM_ cIds $ \cId -> getServerClient cId srv >>= mapM_ (deliverNtfs ns stats)
  where
    deliverNtfs ns stats (AClient _ _ Client {clientId, ntfSubscriptions, sndQ, connected}) =
      whenM (currentClient readTVarIO) $ do
        subs <- readTVarIO ntfSubscriptions
        -- Store entries of the queues this client is subscribed to, in ascending key order.
        -- `M.intersection` costs O(m * log (n/m + 1)) (m <= n are the two map sizes), whereas
        -- filtering the whole store with a membership test per entry costs O(n * log m).
        ntfQs <- M.assocs . (`M.intersection` subs) <$> readTVarIO ns
        tryAny (atomically $ flushSubscribedNtfs ntfQs) >>= \case
          Right len -> updateNtfStats len
          Left e -> logDebug $ "NOTIFICATIONS: cancelled for client #" <> tshow clientId <> ", reason: " <> tshow e
      where
        -- Empties the notification lists of the given queues and writes them to sndQ as one batch;
        -- returns the number of notifications. If the client is no longer current or sndQ is full,
        -- the transaction is aborted by throwSTM, so the lists are left untouched ("ntfs kept").
        flushSubscribedNtfs :: [(NotifierId, TVar [MsgNtf])] -> STM Int
        flushSubscribedNtfs ntfQs = do
          ts_ <- foldM addNtfs [] ntfQs
          forM_ (L.nonEmpty ts_) $ \ts -> do
            let cancelNtfs s = throwSTM $ userError $ s <> ", " <> show (length ts_) <> " ntfs kept"
            unlessM (currentClient readTVar) $ cancelNtfs "not current client"
            whenM (isFullTBQueue sndQ) $ cancelNtfs "sending queue full"
            writeTBQueue sndQ ts
          pure $ length ts_
        -- the client is connected and still registered as a notification subscriber
        currentClient :: Monad m => (forall a. TVar a -> m a) -> m Bool
        currentClient rd = (&&) <$> rd connected <*> (IS.member clientId <$> rd (subClients ntfSubscribers))
        addNtfs :: [Transmission BrokerMsg] -> (NotifierId, TVar [MsgNtf]) -> STM [Transmission BrokerMsg]
        addNtfs acc (nId, v) =
          readTVar v >>= \case
            [] -> pure acc
            ntfs -> do
              writeTVar v []
              pure $ foldl' (\acc' ntf -> nmsg nId ntf : acc') acc ntfs -- reverses, to order by time
        nmsg nId MsgNtf {ntfNonce, ntfEncMeta} = (CorrId "", nId, NMSG ntfNonce ntfEncMeta)
        updateNtfStats 0 = pure ()
        updateNtfStats len = liftIO $ do
          atomicModifyIORef'_ (ntfCount stats) (subtract len)
          atomicModifyIORef'_ (msgNtfs stats) (+ len)
          atomicModifyIORef'_ (msgNtfsB stats) (+ (len `div` 80 + 1)) -- up to 80 NMSG in the batch
-- | Periodically sends the END/DELD events accumulated in @pendingEvents@ by the
-- subscription threads, for both SMP and notification subscribers.
--
-- After each 'pendingENDInterval' delay, all pending events are taken at once and, per
-- client, written as one batch to the client's @sndQ@ if that client is still connected.
sendPendingEvtsThread :: Server -> M ()
sendPendingEvtsThread srv@Server {subscribers, ntfSubscribers} = do
  endInt <- asks $ pendingENDInterval . config
  stats <- asks serverStats
  let flushPending = sendPending subscribers stats >> sendPending ntfSubscribers stats
  liftIO . forever $ threadDelay endInt >> flushPending
  where
    -- takes every pending event of this subscriber set and hands them to their clients
    sendPending ServerSubscribers {pendingEvents} stats = do
      pending <- atomically $ swapTVar pendingEvents IM.empty
      forM_ (IM.toList pending) $ \(cId, evts) ->
        getServerClient cId srv >>= mapM_ (enqueueEvts stats evts)
    enqueueEvts stats evts (AClient _ _ Client {connected, sndQ}) = do
      isConnected <- readTVarIO connected
      when isConnected $ do
        nonBlockingWriteTBQueue sndQ $ L.map (\(qId, evt) -> (CorrId "", qId, evt)) evts
        -- this accounts for both END and DELD events; evts is a NonEmpty, so n >= 1
        let n = L.length evts
        atomicModifyIORef'_ (qSubEnd stats) (+ n)
        atomicModifyIORef'_ (qSubEndB stats) (+ (n `div` 255 + 1)) -- up to 255 ENDs or DELDs in the batch
-- | Logs the connection and subscription events emitted by the proxy's SMP client agent.
-- Events are read from @agentQ@ one at a time, forever.
receiveFromProxyAgent :: ProxyAgent -> M ()
receiveFromProxyAgent ProxyAgent {smpAgent = SMPClientAgent {agentQ}} =
  forever $ do
    evt <- atomically $ readTBQueue agentQ
    case evt of
      CAConnected srv -> logInfo $ "SMP server connected " <> srvHost srv
      CADisconnected srv [] -> logInfo $ "SMP server disconnected " <> srvHost srv
      CADisconnected srv subs -> logError $ "SMP server disconnected " <> srvHost srv <> " / subscriptions: " <> tshow (length subs)
      CASubscribed srv _ subs -> logError $ "SMP server subscribed " <> srvHost srv <> " / subscriptions: " <> tshow (length subs)
      CASubError srv _ errs -> logError $ "SMP server subscription errors " <> srvHost srv <> " / errors: " <> tshow (length errs)
  where
    -- server host(s) rendered as text for log lines
    srvHost = decodeLatin1 . strEncode . host
-- | The message expiration thread; the list is empty when 'messageExpiration' is not configured.
expireMessagesThread_ :: ServerConfig -> [M ()]
expireMessagesThread_ ServerConfig {messageExpiration} =
  maybe [] (\msgExp -> [expireMessagesThread msgExp]) messageExpiration
-- | Every @checkInterval@ seconds, compacts the queue store and expires messages older than @ttl@.
--
-- After a successful run, the stored message count is overwritten and the expired count is
-- added to the server statistics. A failed run is logged and retried on the next iteration.
expireMessagesThread :: ExpirationConfig -> M ()
expireMessagesThread ExpirationConfig {checkInterval, ttl} = do
  AMS _ _ ms <- asks msgStore
  stats <- asks serverStats
  labelMyThread "expireMessagesThread"
  liftIO . forever $ expire ms stats (checkInterval * 1000000)
  where
    expire :: forall s. MsgStoreClass s => s -> ServerStats -> Int64 -> IO ()
    expire ms stats interval = do
      threadDelay' interval
      logNote "Started expiring messages..."
      removed <- compactQueues @(StoreQueue s) $ queueStore ms
      when (removed > 0) $
        logNote $ "Removed " <> tshow removed <> " old deleted queues from the database."
      now <- systemSeconds <$> getSystemTime
      tryAny (expireOldMessages False ms now ttl) >>= either onError onStats
      where
        onStats msgStats@MessageStats {storedMsgsCount = stored, expiredMsgsCount = expired} = do
          atomicWriteIORef (msgCount stats) stored
          atomicModifyIORef'_ (msgExpired stats) (+ expired)
          printMessageStats "STORE: messages" msgStats
        onError e = logError $ "STORE: withAllMsgQueues, error expiring messages, " <> tshow e
-- | Every @checkInterval@ seconds of the notification expiration config, deletes notifications
-- older than the configured cut-off and updates the expired and current notification counters.
expireNtfsThread :: ServerConfig -> M ()
expireNtfsThread ServerConfig {notificationExpiration = expCfg} = do
  ns <- asks ntfStore
  stats <- asks serverStats
  labelMyThread "expireNtfsThread"
  let delayMicros = checkInterval expCfg * 1000000
  liftIO . forever $ do
    threadDelay' delayMicros
    expired <- deleteExpiredNtfs ns =<< expireBeforeEpoch expCfg
    when (expired > 0) $ do
      atomicModifyIORef'_ (msgNtfExpired stats) (+ expired)
      atomicModifyIORef'_ (ntfCount stats) (subtract expired)
-- | The statistics logging thread; the list is empty when 'logStatsInterval' is not set.
serverStatsThread_ :: ServerConfig -> [M ()]
serverStatsThread_ ServerConfig {logStatsInterval, logStatsStartTime, serverStatsLogFile} =
  case logStatsInterval of
    Just interval -> [logServerStats logStatsStartTime interval serverStatsLogFile]
    Nothing -> []
-- | Appends one CSV line of server statistics to @statsFilePath@, then waits @logInterval@ seconds, forever.
--
-- The first line is written at @startAt@ seconds past midnight UTC (later today if that time is
-- still ahead, otherwise tomorrow). Counters read with 'atomicSwapIORef' are reset to 0 for the
-- next period. Queue, message and notification counts are read without being reset.
-- NOTE(review): the column order (including the "0" placeholder columns) presumably forms a
-- file format read by external tools — keep it stable and only append new columns.
logServerStats :: Int64 -> Int64 -> FilePath -> M ()
logServerStats startAt logInterval statsFilePath = do
  labelMyThread "logServerStats"
  -- seconds until `startAt` seconds past midnight UTC (negative if that time already passed today)
  initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
  liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
  liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
  ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedAllB, qDeletedNew, qDeletedSecured, qSub, qSubAllB, qSubAuth, qSubDuplicate, qSubProhibited, qSubEnd, qSubEndB, ntfCreated, ntfDeleted, ntfDeletedB, ntfSub, ntfSubB, ntfSubAuth, ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgRecvGet, msgGet, msgGetNoMsg, msgGetAuth, msgGetDuplicate, msgGetProhibited, msgExpired, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, ntfCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv}
    <- asks serverStats
  AMS _ _ (st :: s) <- asks msgStore
  let interval = 1000000 * logInterval
  forever $ do
    withFile statsFilePath AppendMode $ \h -> liftIO $ do
      hSetBuffering h LineBuffering
      ts <- getCurrentTime
      -- period counters: read and reset to 0 (fromTime is reset to the current time)
      fromTime' <- atomicSwapIORef fromTime ts
      qCreated' <- atomicSwapIORef qCreated 0
      qSecured' <- atomicSwapIORef qSecured 0
      qDeletedAll' <- atomicSwapIORef qDeletedAll 0
      qDeletedAllB' <- atomicSwapIORef qDeletedAllB 0
      qDeletedNew' <- atomicSwapIORef qDeletedNew 0
      qDeletedSecured' <- atomicSwapIORef qDeletedSecured 0
      qSub' <- atomicSwapIORef qSub 0
      qSubAllB' <- atomicSwapIORef qSubAllB 0
      qSubAuth' <- atomicSwapIORef qSubAuth 0
      qSubDuplicate' <- atomicSwapIORef qSubDuplicate 0
      qSubProhibited' <- atomicSwapIORef qSubProhibited 0
      qSubEnd' <- atomicSwapIORef qSubEnd 0
      qSubEndB' <- atomicSwapIORef qSubEndB 0
      ntfCreated' <- atomicSwapIORef ntfCreated 0
      ntfDeleted' <- atomicSwapIORef ntfDeleted 0
      ntfDeletedB' <- atomicSwapIORef ntfDeletedB 0
      ntfSub' <- atomicSwapIORef ntfSub 0
      ntfSubB' <- atomicSwapIORef ntfSubB 0
      ntfSubAuth' <- atomicSwapIORef ntfSubAuth 0
      ntfSubDuplicate' <- atomicSwapIORef ntfSubDuplicate 0
      msgSent' <- atomicSwapIORef msgSent 0
      msgSentAuth' <- atomicSwapIORef msgSentAuth 0
      msgSentQuota' <- atomicSwapIORef msgSentQuota 0
      msgSentLarge' <- atomicSwapIORef msgSentLarge 0
      msgRecv' <- atomicSwapIORef msgRecv 0
      msgRecvGet' <- atomicSwapIORef msgRecvGet 0
      msgGet' <- atomicSwapIORef msgGet 0
      msgGetNoMsg' <- atomicSwapIORef msgGetNoMsg 0
      msgGetAuth' <- atomicSwapIORef msgGetAuth 0
      msgGetDuplicate' <- atomicSwapIORef msgGetDuplicate 0
      msgGetProhibited' <- atomicSwapIORef msgGetProhibited 0
      msgExpired' <- atomicSwapIORef msgExpired 0
      ps <- liftIO $ periodStatCounts activeQueues ts
      msgSentNtf' <- atomicSwapIORef msgSentNtf 0
      msgRecvNtf' <- atomicSwapIORef msgRecvNtf 0
      psNtf <- liftIO $ periodStatCounts activeQueuesNtf ts
      msgNtfs' <- atomicSwapIORef (msgNtfs ss) 0
      msgNtfsB' <- atomicSwapIORef (msgNtfsB ss) 0
      msgNtfNoSub' <- atomicSwapIORef (msgNtfNoSub ss) 0
      msgNtfLost' <- atomicSwapIORef (msgNtfLost ss) 0
      msgNtfExpired' <- atomicSwapIORef (msgNtfExpired ss) 0
      pRelays' <- getResetProxyStatsData pRelays
      pRelaysOwn' <- getResetProxyStatsData pRelaysOwn
      pMsgFwds' <- getResetProxyStatsData pMsgFwds
      pMsgFwdsOwn' <- getResetProxyStatsData pMsgFwdsOwn
      pMsgFwdsRecv' <- atomicSwapIORef pMsgFwdsRecv 0
      -- current counts: read without reset
      qCount' <- readIORef qCount
      msgCount' <- readIORef msgCount
      ntfCount' <- readIORef ntfCount
      -- Queried on every iteration: reading the store counts once before the loop
      -- would repeat the startup values in every line of the log.
      QueueCounts {queueCount, notifierCount} <- queueCounts @(StoreQueue s) $ queueStore st
      hPutStrLn h $
        intercalate
          ","
          ( [ iso8601Show $ utctDay fromTime',
              show qCreated',
              show qSecured',
              show qDeletedAll',
              show msgSent',
              show msgRecv',
              dayCount ps,
              weekCount ps,
              monthCount ps,
              show msgSentNtf',
              show msgRecvNtf',
              dayCount psNtf,
              weekCount psNtf,
              monthCount psNtf,
              show qCount',
              show msgCount',
              show msgExpired',
              show qDeletedNew',
              show qDeletedSecured'
            ]
              <> showProxyStats pRelays'
              <> showProxyStats pRelaysOwn'
              <> showProxyStats pMsgFwds'
              <> showProxyStats pMsgFwdsOwn'
              <> [ show pMsgFwdsRecv',
                   show qSub',
                   show qSubAuth',
                   show qSubDuplicate',
                   show qSubProhibited',
                   show msgSentAuth',
                   show msgSentQuota',
                   show msgSentLarge',
                   show msgNtfs',
                   show msgNtfNoSub',
                   show msgNtfLost',
                   "0", -- qSubNoMsg' is removed for performance.
                   -- Use qSubAllB for the approximate number of all subscriptions.
                   -- Average observed batch size is 25-30 subscriptions.
                   show msgRecvGet',
                   show msgGet',
                   show msgGetNoMsg',
                   show msgGetAuth',
                   show msgGetDuplicate',
                   show msgGetProhibited',
                   "0", -- dayCount psSub; psSub is removed to reduce memory usage
                   "0", -- weekCount psSub
                   "0", -- monthCount psSub
                   show queueCount,
                   show ntfCreated',
                   show ntfDeleted',
                   show ntfSub',
                   show ntfSubAuth',
                   show ntfSubDuplicate',
                   show notifierCount,
                   show qDeletedAllB',
                   show qSubAllB',
                   show qSubEnd',
                   show qSubEndB',
                   show ntfDeletedB',
                   show ntfSubB',
                   show msgNtfsB',
                   show msgNtfExpired',
                   show ntfCount'
                 ]
          )
    liftIO $ threadDelay' interval
  where
    showProxyStats ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} =
      [show _pRequests, show _pSuccesses, show _pErrorsConnect, show _pErrorsCompat, show _pErrorsOther]
-- | The Prometheus metrics thread; the list is empty when 'prometheusInterval' is not set.
prometheusMetricsThread_ :: ServerConfig -> [M ()]
prometheusMetricsThread_ ServerConfig {prometheusInterval, prometheusMetricsFile} =
  case prometheusInterval of
    Nothing -> []
    Just interval -> [savePrometheusMetrics interval prometheusMetricsFile]
-- | Every @saveInterval@ seconds, overwrites @metricsFile@ with Prometheus metrics built from
-- the server statistics and real-time metrics. RTS options are read once, from the environment
-- variable named by 'rtsOptionsEnv' (a placeholder text is used when it is unset).
savePrometheusMetrics :: Int -> FilePath -> M ()
savePrometheusMetrics saveInterval metricsFile = do
  labelMyThread "savePrometheusMetrics"
  liftIO . putStrLn $ "Prometheus metrics saved every " <> show saveInterval <> " seconds to " <> metricsFile
  AMS _ _ st <- asks msgStore
  ss <- asks serverStats
  env <- ask
  rtsOpts <- liftIO $ maybe ("set " <> rtsOptionsEnv) T.pack <$> lookupEnv (T.unpack rtsOptionsEnv)
  let delayMicros = saveInterval * 1000000
      writeMetrics = do
        ts <- getCurrentTime
        sm <- getServerMetrics st ss rtsOpts
        rtm <- getRealTimeMetrics env
        T.writeFile metricsFile $ prometheusMetrics sm rtm ts
  liftIO . forever $ threadDelay delayMicros >> writeMetrics
-- | Snapshot of the server statistics together with the current queue and notifier counts
-- of the message store's queue store.
getServerMetrics :: forall s. MsgStoreClass s => s -> ServerStats -> Text -> IO ServerMetrics
getServerMetrics st ss rtsOptions = do
  sd <- getServerStatsData ss
  QueueCounts {queueCount, notifierCount} <- queueCounts @(StoreQueue s) $ queueStore st
  pure
    ServerMetrics
      { statsData = sd,
        activeQueueCounts = periodStatDataCounts (_activeQueues sd),
        activeNtfCounts = periodStatDataCounts (_activeQueuesNtf sd),
        queueCount,
        notifierCount,
        rtsOptions
      }
-- | Collects the current socket statistics, thread count (0 when base < 4.18), number of
-- connected clients, subscriber metrics for SMP and notifications, and loaded queue counts.
getRealTimeMetrics :: Env -> IO RealTimeMetrics
getRealTimeMetrics Env {sockets, msgStore = AMS _ _ ms, server = srv@Server {subscribers, ntfSubscribers}} = do
  socketStats <- readTVarIO sockets >>= mapM (traverse getSocketStats)
#if MIN_VERSION_base(4,18,0)
  threadsCount <- length <$> listThreads
#else
  let threadsCount = 0
#endif
  clientsCount <- IM.size <$> getServerClients srv
  smpSubs <- subscribersMetrics subscribers
  ntfSubs <- subscribersMetrics ntfSubscribers
  loadedCounts <- loadedQueueCounts ms
  pure RealTimeMetrics {socketStats, threadsCount, clientsCount, smpSubs, ntfSubs, loadedCounts}
  where
    -- number of subscribed queues and number of subscribed clients
    subscribersMetrics ServerSubscribers {queueSubscribers, subClients} = do
      qSubs <- getSubscribedClients queueSubscribers
      subClnts <- readTVarIO subClients
      pure RTSubscriberMetrics {subsCount = M.size qSubs, subClientsCount = IS.size subClnts}
-- | Runs the SMP server handshake on an accepted connection with a key pair generated for this
-- connection, and serves the client if the handshake succeeds within 'smpHandshakeTimeout'.
-- A failed or timed-out handshake simply returns.
runClient :: Transport c => X.CertificateChain -> C.APrivateSignKey -> TProxy c 'TServer -> c 'TServer -> M ()
runClient srvCert srvSignKey tp h = do
  kh <- asks serverIdentity
  ks <- atomically . C.generateKeyPair =<< asks random
  ServerConfig {smpServerVRange, smpHandshakeTimeout} <- asks config
  labelMyThread $ "smp handshake for " <> transportName tp
  handshake <- liftIO . timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake srvCert srvSignKey h ks kh smpServerVRange
  case handshake of
    Just (Right th) -> runClientTransport th
    _ -> pure ()
-- | The control port server thread; the list is empty when 'controlPort' is not set.
controlPortThread_ :: ServerConfig -> [M ()]
controlPortThread_ ServerConfig {controlPort} = maybe [] (pure . runCPServer) controlPort
runCPServer :: ServiceName -> M ()
runCPServer port = do
srv <- asks server
cpStarted <- newEmptyTMVarIO
u <- askUnliftIO
liftIO $ do
labelMyThread "control port server"
runLocalTCPServer cpStarted port $ runCPClient u srv
where
runCPClient :: UnliftIO (ReaderT Env IO) -> Server -> Socket -> IO ()
runCPClient u srv sock = do
labelMyThread "control port client"
h <- socketToHandle sock ReadWriteMode
hSetBuffering h LineBuffering
hSetNewlineMode h universalNewlineMode
hPutStrLn h "SMP server control port\n'help' for supported commands"
role <- newTVarIO CPRNone
cpLoop h role
where
cpLoop h role = do
s <- trimCR <$> B.hGetLine h
case strDecode s of
Right CPQuit -> hClose h
Right cmd -> logCmd s cmd >> processCP h role cmd >> cpLoop h role
Left err -> hPutStrLn h ("error: " <> err) >> cpLoop h role
logCmd s cmd = when shouldLog $ logWarn $ "ControlPort: " <> tshow s
where
shouldLog = case cmd of
CPAuth _ -> False
CPHelp -> False
CPQuit -> False
CPSkip -> False
_ -> True
processCP h role = \case
CPAuth auth -> atomically $ writeTVar role $! newRole cfg
where
newRole ServerConfig {controlPortUserAuth = user, controlPortAdminAuth = admin}
| Just auth == admin = CPRAdmin
| Just auth == user = CPRUser
| otherwise = CPRNone
CPSuspend -> withAdminRole $ hPutStrLn h "suspend not implemented"
CPResume -> withAdminRole $ hPutStrLn h "resume not implemented"
CPClients -> withAdminRole $ do
cls <- getServerClients srv
hPutStrLn h "clientId,sessionId,connected,createdAt,rcvActiveAt,sndActiveAt,age,subscriptions"
forM_ (IM.toList cls) $ \(cid, (AClient _ _ Client {sessionId, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions})) -> do
connected' <- bshow <$> readTVarIO connected
rcvActiveAt' <- strEncode <$> readTVarIO rcvActiveAt
sndActiveAt' <- strEncode <$> readTVarIO sndActiveAt
now <- liftIO getSystemTime
let age = systemSeconds now - systemSeconds createdAt
subscriptions' <- bshow . M.size <$> readTVarIO subscriptions
hPutStrLn h . B.unpack $ B.intercalate "," [bshow cid, encode sessionId, connected', strEncode createdAt, rcvActiveAt', sndActiveAt', bshow age, subscriptions']
CPStats -> withUserRole $ do
ss <- unliftIO u $ asks serverStats
AMS _ _ (st :: s) <- unliftIO u $ asks msgStore
QueueCounts {queueCount, notifierCount} <- queueCounts @(StoreQueue s) $ queueStore st
let getStat :: (ServerStats -> IORef a) -> IO a
getStat var = readIORef (var ss)
putStat :: Show a => String -> (ServerStats -> IORef a) -> IO ()
putStat label var = getStat var >>= \v -> hPutStrLn h $ label <> ": " <> show v
putProxyStat :: String -> (ServerStats -> ProxyStats) -> IO ()
putProxyStat label var = do
ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- getProxyStatsData $ var ss
hPutStrLn h $ label <> ": requests=" <> show _pRequests <> ", successes=" <> show _pSuccesses <> ", errorsConnect=" <> show _pErrorsConnect <> ", errorsCompat=" <> show _pErrorsCompat <> ", errorsOther=" <> show _pErrorsOther
putStat "fromTime" fromTime
putStat "qCreated" qCreated
putStat "qSecured" qSecured
putStat "qDeletedAll" qDeletedAll
putStat "qDeletedAllB" qDeletedAllB
putStat "qDeletedNew" qDeletedNew
putStat "qDeletedSecured" qDeletedSecured
getStat (day . activeQueues) >>= \v -> hPutStrLn h $ "daily active queues: " <> show (IS.size v)
-- removed to reduce memory usage
-- getStat (day . subscribedQueues) >>= \v -> hPutStrLn h $ "daily subscribed queues: " <> show (S.size v)
putStat "qSub" qSub
putStat "qSubAllB" qSubAllB
putStat "qSubEnd" qSubEnd
putStat "qSubEndB" qSubEndB
subs <- (,,) <$> getStat qSubAuth <*> getStat qSubDuplicate <*> getStat qSubProhibited
hPutStrLn h $ "other SUB events (auth, duplicate, prohibited): " <> show subs
putStat "msgSent" msgSent
putStat "msgRecv" msgRecv
putStat "msgRecvGet" msgRecvGet
putStat "msgGet" msgGet
putStat "msgGetNoMsg" msgGetNoMsg
gets <- (,,) <$> getStat msgGetAuth <*> getStat msgGetDuplicate <*> getStat msgGetProhibited
hPutStrLn h $ "other GET events (auth, duplicate, prohibited): " <> show gets
putStat "msgSentNtf" msgSentNtf
putStat "msgRecvNtf" msgRecvNtf
putStat "msgNtfs" msgNtfs
putStat "msgNtfsB" msgNtfsB
putStat "msgNtfExpired" msgNtfExpired
putStat "qCount" qCount
hPutStrLn h $ "qCount 2: " <> show queueCount
hPutStrLn h $ "notifiers: " <> show notifierCount
putStat "msgCount" msgCount
putStat "ntfCount" ntfCount
readTVarIO role >>= \case
CPRAdmin -> do
NtfStore ns <- unliftIO u $ asks ntfStore
ntfCount2 <- liftIO . foldM (\(!n) q -> (n +) . length <$> readTVarIO q) 0 =<< readTVarIO ns
hPutStrLn h $ "ntfCount 2: " <> show ntfCount2
_ -> pure ()
putProxyStat "pRelays" pRelays
putProxyStat "pRelaysOwn" pRelaysOwn
putProxyStat "pMsgFwds" pMsgFwds
putProxyStat "pMsgFwdsOwn" pMsgFwdsOwn
putStat "pMsgFwdsRecv" pMsgFwdsRecv
CPStatsRTS -> getRTSStats >>= hPrint h
CPThreads -> withAdminRole $ do
#if MIN_VERSION_base(4,18,0)
threads <- liftIO listThreads
hPutStrLn h $ "Threads: " <> show (length threads)
forM_ (sort threads) $ \tid -> do
label <- threadLabel tid
status <- threadStatus tid
hPutStrLn h $ show tid <> " (" <> show status <> ") " <> fromMaybe "" label
#else
hPutStrLn h "Not available on GHC 8.10"
#endif
CPSockets -> withUserRole $ unliftIO u (asks sockets) >>= readTVarIO >>= mapM_ putSockets
where
putSockets (tcpPort, socketsState) = do
ss <- getSocketStats socketsState
hPutStrLn h $ "Sockets for port " <> tcpPort <> ":"
hPutStrLn h $ "accepted: " <> show (socketsAccepted ss)
hPutStrLn h $ "closed: " <> show (socketsClosed ss)
hPutStrLn h $ "active: " <> show (socketsActive ss)
hPutStrLn h $ "leaked: " <> show (socketsLeaked ss)
CPSocketThreads -> withAdminRole $ do
#if MIN_VERSION_base(4,18,0)
unliftIO u (asks sockets) >>= readTVarIO >>= mapM_ putSocketThreads
where
putSocketThreads (tcpPort, (_, _, active')) = do
active <- readTVarIO active'
forM_ (IM.toList active) $ \(sid, tid') ->
deRefWeak tid' >>= \case
Nothing -> hPutStrLn h $ intercalate "," [tcpPort, show sid, "", "gone", ""]
Just tid -> do
label <- threadLabel tid
status <- threadStatus tid
hPutStrLn h $ intercalate "," [tcpPort, show sid, show tid, show status, fromMaybe "" label]
#else
hPutStrLn h "Not available on GHC 8.10"
#endif
CPServerInfo -> readTVarIO role >>= \case
CPRNone -> do
logError "Unauthorized control port command"
hPutStrLn h "AUTH"
r -> do
#if MIN_VERSION_base(4,18,0)
threads <- liftIO listThreads
hPutStrLn h $ "Threads: " <> show (length threads)
#else
hPutStrLn h "Threads: not available on GHC 8.10"
#endif
let Server {subscribers, ntfSubscribers} = srv
activeClients <- getServerClients srv
hPutStrLn h $ "Clients: " <> show (IM.size activeClients)
when (r == CPRAdmin) $ do
clQs <- clientTBQueueLengths' activeClients
hPutStrLn h $ "Client queues (rcvQ, sndQ, msgQ): " <> show clQs
(smpSubCnt, smpSubCntByGroup, smpClCnt, smpClQs) <- countClientSubs subscriptions (Just countSMPSubs) activeClients
hPutStrLn h $ "SMP subscriptions (via clients): " <> show smpSubCnt
hPutStrLn h $ "SMP subscriptions (by group: NoSub, SubPending, SubThread, ProhibitSub): " <> show smpSubCntByGroup
hPutStrLn h $ "SMP subscribed clients (via clients): " <> show smpClCnt
hPutStrLn h $ "SMP subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show smpClQs
(ntfSubCnt, _, ntfClCnt, ntfClQs) <- countClientSubs ntfSubscriptions Nothing activeClients
hPutStrLn h $ "Ntf subscriptions (via clients): " <> show ntfSubCnt
hPutStrLn h $ "Ntf subscribed clients (via clients): " <> show ntfClCnt
hPutStrLn h $ "Ntf subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show ntfClQs
putSubscribersInfo "SMP" subscribers False
putSubscribersInfo "Ntf" ntfSubscribers True
where
putSubscribersInfo :: String -> ServerSubscribers -> Bool -> IO ()
putSubscribersInfo protoName ServerSubscribers {queueSubscribers, subClients} showIds = do
activeSubs <- getSubscribedClients queueSubscribers
hPutStrLn h $ protoName <> " subscriptions: " <> show (M.size activeSubs)
clnts <- countSubClients activeSubs
hPutStrLn h $ protoName <> " subscribed clients: " <> show (IS.size clnts) <> (if showIds then " " <> show (IS.toList clnts) else "")
clnts' <- readTVarIO subClients
hPutStrLn h $ protoName <> " subscribed clients count 2: " <> show (IS.size clnts') <> (if showIds then " " <> show clnts' else "")
where
countSubClients :: M.Map QueueId (TVar (Maybe AClient)) -> IO IS.IntSet
countSubClients = foldM (\ !s c -> maybe s ((`IS.insert` s) . clientId') <$> readTVarIO c) IS.empty
countClientSubs :: (forall s. Client s -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap AClient -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
countClientSubs subSel countSubs_ = foldM addSubs (0, (0, 0, 0, 0), 0, (0, 0, 0))
where
addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> AClient -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) acl@(AClient _ _ cl) = do
subs <- readTVarIO $ subSel cl
cnts' <- case countSubs_ of
Nothing -> pure cnts
Just countSubs -> do
(c1', c2', c3', c4') <- countSubs subs
pure (c1 + c1', c2 + c2', c3 + c3', c4 + c4')
let cnt = M.size subs
clCnt' = if cnt == 0 then clCnt else clCnt + 1
qs' <- if cnt == 0 then pure qs else addQueueLengths qs acl
pure (subCnt + cnt, cnts', clCnt', qs')
clientTBQueueLengths' :: Foldable t => t AClient -> IO (Natural, Natural, Natural)
clientTBQueueLengths' = foldM addQueueLengths (0, 0, 0)
addQueueLengths (!rl, !sl, !ml) (AClient _ _ cl) = do
(rl', sl', ml') <- queueLengths cl
pure (rl + rl', sl + sl', ml + ml')
queueLengths Client {rcvQ, sndQ, msgQ} = do
rl <- atomically $ lengthTBQueue rcvQ
sl <- atomically $ lengthTBQueue sndQ
ml <- atomically $ lengthTBQueue msgQ
pure (rl, sl, ml)
countSMPSubs :: M.Map QueueId Sub -> IO (Int, Int, Int, Int)
countSMPSubs = foldM countSubs (0, 0, 0, 0)
where
countSubs (c1, c2, c3, c4) Sub {subThread} = case subThread of
ServerSub t -> do
st <- readTVarIO t
pure $ case st of
NoSub -> (c1 + 1, c2, c3, c4)
SubPending -> (c1, c2 + 1, c3, c4)
SubThread _ -> (c1, c2, c3 + 1, c4)
ProhibitSub -> pure (c1, c2, c3, c4 + 1)
CPDelete sId -> withUserRole $ unliftIO u $ do
AMS _ _ st <- asks msgStore
r <- liftIO $ runExceptT $ do
q <- ExceptT $ getQueue st SSender sId
ExceptT $ deleteQueueSize st q
case r of
Left e -> liftIO $ hPutStrLn h $ "error: " <> show e
Right (qr, numDeleted) -> do
updateDeletedStats qr
liftIO $ hPutStrLn h $ "ok, " <> show numDeleted <> " messages deleted"
CPStatus sId -> withUserRole $ unliftIO u $ do
AMS _ _ st <- asks msgStore
q <- liftIO $ getQueueRec st SSender sId
liftIO $ hPutStrLn h $ case q of
Left e -> "error: " <> show e
Right (_, QueueRec {queueMode, status, updatedAt}) ->
"status: " <> show status <> ", updatedAt: " <> show updatedAt <> ", queueMode: " <> show queueMode
CPBlock sId info -> withUserRole $ unliftIO u $ do
AMS _ _ (st :: s) <- asks msgStore
r <- liftIO $ runExceptT $ do
q <- ExceptT $ getQueue st SSender sId
ExceptT $ blockQueue (queueStore st) q info
case r of
Left e -> liftIO $ hPutStrLn h $ "error: " <> show e
Right () -> do
incStat . qBlocked =<< asks serverStats
liftIO $ hPutStrLn h "ok"
CPUnblock sId -> withUserRole $ unliftIO u $ do
AMS _ _ (st :: s) <- asks msgStore
r <- liftIO $ runExceptT $ do
q <- ExceptT $ getQueue st SSender sId
ExceptT $ unblockQueue (queueStore st) q
liftIO $ hPutStrLn h $ case r of
Left e -> "error: " <> show e
Right () -> "ok"
CPSave -> withAdminRole $ withLock' (savingLock srv) "control" $ do
hPutStrLn h "saving server state..."
unliftIO u $ saveServer False
hPutStrLn h "server state saved!"
CPHelp -> hPutStrLn h "commands: stats, stats-rts, clients, sockets, socket-threads, threads, server-info, delete, save, help, quit"
CPQuit -> pure ()
CPSkip -> pure ()
where
withUserRole action = readTVarIO role >>= \case
CPRAdmin -> action
CPRUser -> action
_ -> do
logError "Unauthorized control port command"
hPutStrLn h "AUTH"
withAdminRole action = readTVarIO role >>= \case
CPRAdmin -> action
_ -> do
logError "Unauthorized control port command"
hPutStrLn h "AUTH"
-- | Serve one accepted SMP client connection.
--
-- Allocates the next client ID from the server-wide sequence and creates the
-- per-client state, using the configured queue size.
-- It then runs the client's worker threads.
-- 'clientDisconnected' always runs once the workers stop, whether normally or by exception.
runClientTransport :: Transport c => THandleSMP c 'TServer -> M ()
runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessionId}} = do
  qSize <- asks $ tbqSize . config
  connectedAt <- liftIO getSystemTime
  idSeq <- asks clientSeq
  cId <- atomically $ stateTVar idSeq $ \next -> (next, next + 1)
  AMS qt mt ms <- asks msgStore
  c <- liftIO $ newClient qt mt cId qSize thVersion sessionId connectedAt
  runClientThreads qt mt ms c `finally` clientDisconnected c
  where
    -- Registers the client with the server. Only if registration succeeds does it race
    -- the client's workers: the sender, the message sender, the command processor and the receiver.
    -- When inactive-client expiration is configured, an idle-connection watchdog joins the race.
    runClientThreads :: MsgStoreClass (MsgStore qs ms) => SQSType qs -> SMSType ms -> MsgStore qs ms -> Client (MsgStore qs ms) -> M ()
    runClientThreads qt mt ms c = do
      srv <- asks server
      registered <- liftIO $ insertServerClient (AClient qt mt c) srv
      when registered $ do
        expCfg_ <- asks $ inactiveClientExpiration . config
        th <- newMVar h -- put TH under a fair lock to interleave messages and command responses
        labelMyThread . B.unpack $ "client $" <> encode sessionId
        raceAny_ $
          [liftIO $ send th c, liftIO $ sendMsg th c, client thParams srv ms c, receive h ms c]
            <> maybe [] (\expCfg -> [liftIO $ disconnectTransport h (rcvActiveAt c) (sndActiveAt c) expCfg (noSubscriptions c srv)]) expCfg_
    -- True when the client ID is in neither the SMP nor the notification subscriber set.
    -- The second set is only read when the first check does not already decide.
    noSubscriptions Client {clientId} Server {subscribers, ntfSubscribers} =
      ifM
        (IS.member clientId <$> readTVarIO (subClients subscribers))
        (pure False)
        (not . IS.member clientId <$> readTVarIO (subClients ntfSubscribers))
-- | Cleanup run after a client's worker threads have stopped.
--
-- Marks the client as disconnected and takes ownership of its SMP and notification
-- subscription maps, then cancels the SMP subscription threads.
-- While the server is still active, it also removes the client from both subscriber
-- registries and from the server's client map.
-- Finally it kills any threads the client forked (see 'forkClient').
clientDisconnected :: Client s -> M ()
clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, connected, sessionId, endThreads} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disc"
  -- separate transactions are fine here: the client is already gone,
  -- so none of these variables will be modified concurrently by it
  atomically $ writeTVar connected False
  smpSubs <- atomically $ swapTVar subscriptions M.empty
  ntfSubs <- atomically $ swapTVar ntfSubscriptions M.empty
  liftIO $ mapM_ cancelSub smpSubs
  serverRunning <- asks serverActive >>= readTVarIO
  when serverRunning $ do
    srv@Server {subscribers, ntfSubscribers} <- asks server
    liftIO $ do
      removeFrom subscribers smpSubs
      removeFrom ntfSubscribers ntfSubs
      deleteServerClient clientId srv
  threads <- atomically $ swapTVar endThreads IM.empty
  liftIO $ forM_ threads $ deRefWeak >=> mapM_ killThread
  where
    -- Unregisters this client from every queue it was subscribed to in the given
    -- registry, and from that registry's set of subscribed client IDs.
    removeFrom :: ServerSubscribers -> M.Map QueueId a -> IO ()
    removeFrom ServerSubscribers {queueSubscribers, subClients} subs = do
      forM_ (M.keys subs) $ \qId -> deleteSubcribedClient qId c queueSubscribers
      atomically $ modifyTVar' subClients $ IS.delete clientId
-- | Stop the delivery thread behind a subscription, if there is one still alive.
-- Subscriptions that are prohibited (GET-mode) or have no running thread are left as is.
cancelSub :: Sub -> IO ()
cancelSub Sub {subThread} = case subThread of
  ProhibitSub -> pure ()
  ServerSub st -> do
    subState <- readTVarIO st
    case subState of
      SubThread weakTid -> deRefWeak weakTid >>= mapM_ killThread
      _ -> pure ()
-- | Inbound half of a client connection.
--
-- Reads batches of transmissions from the transport. Once the server is no longer
-- active, it throws @userError "server stopped"@.
-- It records the receive time for inactivity tracking and authorizes every command.
-- Parse and authorization errors go straight to the send queue, while authorized
-- commands go to the command queue.
receive :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> s -> Client s -> M ()
receive h@THandle {params = THandleParams {thAuth}} ms Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
  active <- asks serverActive
  stats <- asks serverStats
  forever $ do
    batch <- L.toList <$> liftIO (tGet h)
    running <- readTVarIO active
    unless running $ throwIO $ userError "server stopped"
    now <- liftIO getSystemTime
    atomically $ writeTVar rcvActiveAt $! now
    (errs, cmds) <- partitionEithers <$> mapM (authorizeCmd stats) batch
    updateBatchStats stats cmds
    forM_ (L.nonEmpty errs) $ atomically . writeTBQueue sndQ
    forM_ (L.nonEmpty cmds) $ atomically . writeTBQueue rcvQ
  where
    -- Batch counters are chosen by the first authorized command of the batch only.
    updateBatchStats :: ServerStats -> [(Maybe (StoreQueue s, QueueRec), Transmission Cmd)] -> M ()
    updateBatchStats _ [] = pure ()
    updateBatchStats stats ((_, (_, _, Cmd _ cmd)) : _) = do
      let sel_ = case cmd of
            SUB -> Just qSubAllB
            DEL -> Just qDeletedAllB
            NSUB -> Just ntfSubB
            NDEL -> Just ntfDeletedB
            _ -> Nothing
      forM_ sel_ $ \sel -> incStat $ sel stats
    -- Left: an error response for the client. Right: the command, together with the queue
    -- found during verification. Counts authorization failures for some command types.
    authorizeCmd :: ServerStats -> SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd))
    authorizeCmd stats (tAuth, authorized, (corrId, entId, cmdOrError)) = case cmdOrError of
      Left e -> pure $ Left (corrId, entId, ERR e)
      Right cmd -> do
        let auth_ = (,C.cbNonce (bs corrId)) <$> thAuth
        verifyTransmission ms auth_ tAuth authorized entId cmd >>= \case
          VRVerified q_ -> pure $ Right (q_, (corrId, entId, cmd))
          VRFailed -> do
            case cmd of
              Cmd _ SEND {} -> incStat $ msgSentAuth stats
              Cmd _ SUB -> incStat $ qSubAuth stats
              Cmd _ NSUB -> incStat $ ntfSubAuth stats
              Cmd _ GET -> incStat $ msgGetAuth stats
              _ -> pure ()
            pure $ Left (corrId, entId, ERR AUTH)
-- | Outbound half of a client connection: writes queued response batches to the transport.
--
-- A batch of at most two transmissions is sent as is.
-- In a larger batch (e.g. batched subscriptions), every MSG is first replaced by an OK
-- for its correlation ID, so that all SUBs in the batch are answered at once and
-- client timeouts are reduced.
-- The extracted messages, with an empty correlation ID, are then handed to the message
-- queue and sent individually by 'sendMsg'. That way they carry no response timeout and
-- can interleave with responses to other requests.
send :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO ()
send th c@Client {sndQ, msgQ, sessionId} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send"
  forever $ do
    ts <- atomically $ readTBQueue sndQ
    if L.length ts > 2
      then do
        let (msgs_, replies) = mapAccumR deferMsg [] ts
        tSend th c replies
        forM_ (L.nonEmpty msgs_) $ atomically . writeTBQueue msgQ
      else tSend th c ts
  where
    -- Moves a MSG into the accumulator (keeping batch order) and leaves an OK in its place.
    deferMsg :: [Transmission BrokerMsg] -> Transmission BrokerMsg -> ([Transmission BrokerMsg], Transmission BrokerMsg)
    deferMsg acc t@(corrId, entId, cmd) = case cmd of
      MSG {} -> ((CorrId "", entId, cmd) : acc, (corrId, entId, OK))
      _ -> (acc, t)
-- | Sends messages deferred by 'send', each one as its own single-transmission batch.
sendMsg :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO ()
sendMsg th c@Client {msgQ, sessionId} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " sendMsg"
  forever $ do
    msgs <- atomically $ readTBQueue msgQ
    forM_ msgs $ tSend th c . pure
-- | Encodes and writes one batch of transmissions while holding the transport lock,
-- then records the send time for inactivity tracking.
tSend :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> NonEmpty (Transmission BrokerMsg) -> IO ()
tSend th Client {sndActiveAt} ts = do
  withMVar th $ \h@THandle {params} ->
    void . tPut h $ L.map (Right . (Nothing,) . encodeTransmission params) ts
  sentAt <- getSystemTime
  atomically $ writeTVar sndActiveAt $! sentAt
-- | Idle-connection watchdog.
--
-- Wakes up every @checkInterval expCfg@ units (multiplied by 1000000 for
-- 'threadDelay'', presumably seconds).
-- The expiry check only runs while the client has no subscriptions, as reported by
-- @noSubscriptions@. The connection is closed when the later of the last receive and
-- last send times is older than the cutoff from 'expireBeforeEpoch'.
-- Otherwise the watchdog keeps waiting.
disconnectTransport :: Transport c => THandle v c 'TServer -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport"
  waitAndCheck
  where
    waitAndCheck = do
      threadDelay' $ checkInterval expCfg * 1000000
      idle <- noSubscriptions
      if idle then closeIfExpired else waitAndCheck
    closeIfExpired = do
      cutoff <- expireBeforeEpoch expCfg
      lastRcv <- readTVarIO rcvActiveAt
      lastSnd <- readTVarIO sndActiveAt
      if systemSeconds (max lastRcv lastSnd) < cutoff
        then closeConnection connection
        else waitAndCheck
-- | Result of authorizing a client command in 'verifyTransmission'.
-- 'VRVerified' carries the looked-up queue and its record when the command addresses an
-- existing queue. It is 'Nothing' for commands such as NEW, PING, RFWD and proxied commands.
-- 'VRFailed' is answered with @ERR AUTH@ by the receiving thread.
data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFailed
-- | Verifies queue command authorization. The objective is constant time across the three AUTH error scenarios:
-- - the queue and party key exist, and the provided authorization has a type matching the queue key, but it is made with a different key.
-- - the queue and party key exist, but the provided authorization has an incorrect type.
-- - the queue or party key does not exist.
-- In all cases, the time of the verification should depend only on the provided authorization type.
-- A dummy key is used to run verification in the last two cases, and failure is returned irrespective of the result.
--
-- On success, the queue looked up during verification (if the command addresses one) is
-- returned in 'VRVerified', so command processing receives it without another lookup.
verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> QueueId -> Cmd -> M (VerificationResult s)
verifyTransmission ms auth_ tAuth authorized queueId cmd =
  case cmd of
    -- NEW is authorized with the recipient key supplied in the request itself (no queue exists yet)
    Cmd SRecipient (NEW NewQueueReq {rcvAuthKey = k}) -> pure $ Nothing `verifiedWith` k
    -- other recipient commands are accepted if any of the queue's recipient keys verifies
    Cmd SRecipient _ -> verifyQueue (\q -> Just q `verifiedWithKeys` recipientKeys (snd q)) <$> get SRecipient
    Cmd SSender (SKEY k) -> verifySecure SSender k
    -- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LKEY command
    Cmd SSender SEND {} -> verifyQueue (\q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified (Just q) else VRFailed) <$> get SSender
    Cmd SSender PING -> pure $ VRVerified Nothing
    Cmd SSender RFWD {} -> pure $ VRVerified Nothing
    Cmd SSenderLink (LKEY k) -> verifySecure SSenderLink k
    -- no authorization is checked here for LGET; it is only allowed for contact queues (see isContactQueue)
    Cmd SSenderLink LGET -> verifyQueue (\q -> if isContactQueue (snd q) then VRVerified (Just q) else VRFailed) <$> get SSenderLink
    -- NSUB will not be accepted without authorization
    Cmd SNotifier NSUB -> verifyQueue (\q -> maybe dummyVerify (\n -> Just q `verifiedWith` notifierKey n) (notifier $ snd q)) <$> get SNotifier
    Cmd SProxiedClient _ -> pure $ VRVerified Nothing
  where
    -- checks the transmission authorization against the given key
    verify = verifyCmdAuthorization auth_ tAuth authorized
    -- verifies against a dummy key of the provided authorization type (to equalize timing) and always fails
    dummyVerify = verify (dummyAuthKey tAuth) `seq` VRFailed
    -- a failed queue lookup still runs a dummy verification, so a missing queue takes comparable time
    verifyQueue :: ((StoreQueue s, QueueRec) -> VerificationResult s) -> Either ErrorType (StoreQueue s, QueueRec) -> VerificationResult s
    verifyQueue = either (const dummyVerify)
    -- SKEY / LKEY: the supplied key must be allowed for this queue (see allowedKey) and must verify the command
    verifySecure :: DirectParty p => SParty p -> SndPublicAuthKey -> M (VerificationResult s)
    verifySecure p k = verifyQueue (\q -> if k `allowedKey` snd q then Just q `verifiedWith` k else dummyVerify) <$> get p
    verifiedWith :: Maybe (StoreQueue s, QueueRec) -> C.APublicAuthKey -> VerificationResult s
    verifiedWith q_ k = if verify k then VRVerified q_ else VRFailed
    verifiedWithKeys :: Maybe (StoreQueue s, QueueRec) -> NonEmpty C.APublicAuthKey -> VerificationResult s
    verifiedWithKeys q_ ks = if any verify ks then VRVerified q_ else VRFailed
    -- only messaging queues can be secured, and only if the sender key is not set yet or equals the supplied key
    allowedKey k = \case
      QueueRec {queueMode = Just QMMessaging, senderKey} -> maybe True (k ==) senderKey
      _ -> False
    -- looks up the queue by the ID, as seen by the given party
    get :: DirectParty p => SParty p -> M (Either ErrorType (StoreQueue s, QueueRec))
    get party = liftIO $ getQueueRec ms party queueId
-- | Whether a queue is a contact address queue.
-- Queues created before queue modes existed have no mode. These are treated as contact
-- queues when no sender key was set, for backward compatibility with pre-SKEY contact addresses.
isContactQueue :: QueueRec -> Bool
isContactQueue QueueRec {queueMode = Just QMContact} = True
isContactQueue QueueRec {queueMode = Just QMMessaging} = False
isContactQueue QueueRec {queueMode = Nothing, senderKey} = isNothing senderKey
-- | Checks a transmission authorization against a public key.
--
-- A missing authorization never verifies. A signature whose algorithm differs from the
-- key's is checked against a dummy key of the signature's algorithm, and an authenticator
-- used with a non-X25519 key is checked against the dummy X25519 key. Both cases then
-- fail regardless of the result, so timing depends only on the authorization type.
verifyCmdAuthorization :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> C.APublicAuthKey -> Bool
verifyCmdAuthorization auth_ tAuth authorized key = case tAuth of
  Nothing -> False
  Just ta -> verifyWith key ta
  where
    verifyWith :: C.APublicAuthKey -> TransmissionAuth -> Bool
    verifyWith (C.APublicAuthKey alg pk) = \case
      TASignature (C.ASignature alg' sig) -> case testEquality alg alg' of
        Just Refl -> C.verify' pk sig authorized
        Nothing -> C.verify' (dummySignKey alg') sig authorized `seq` False
      TAAuthenticator cbAuth -> case alg of
        C.SX25519 -> verifyCmdAuth auth_ pk cbAuth authorized
        _ -> verifyCmdAuth auth_ dummyKeyX25519 cbAuth authorized `seq` False
-- | Verifies an authenticator with the server's private key and the transmission nonce.
-- It fails when the connection has no server auth parameters.
verifyCmdAuth :: Maybe (THandleAuth 'TServer, C.CbNonce) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> Bool
verifyCmdAuth auth_ k authenticator authorized = case auth_ of
  Nothing -> False
  Just (THAuthServer {serverPrivKey}, nonce) -> C.cbVerify k serverPrivKey nonce authenticator authorized
-- | Runs a verification of the given authorization against the dummy key of its type.
-- The result is returned as is; callers use it to equalize timing.
dummyVerifyCmd :: Maybe (THandleAuth 'TServer, C.CbNonce) -> ByteString -> TransmissionAuth -> Bool
dummyVerifyCmd auth_ authorized tAuth = case tAuth of
  TAAuthenticator cbAuth -> verifyCmdAuth auth_ dummyKeyX25519 cbAuth authorized
  TASignature (C.ASignature alg sig) -> C.verify' (dummySignKey alg) sig authorized
-- These dummy keys are used with the `dummyVerify` function to mitigate timing attacks,
-- by having the same response time whether a queue exists or not, for all valid key/signature sizes.
--
-- | Dummy public key for the given signature algorithm.
dummySignKey :: C.SignatureAlgorithm a => C.SAlgorithm a -> C.PublicKey a
dummySignKey C.SEd25519 = dummyKeyEd25519
dummySignKey C.SEd448 = dummyKeyEd448
-- | Dummy key matching the type of the provided authorization.
-- It is the signature algorithm's key for signatures, and the X25519 key for
-- authenticators or when there is no authorization.
dummyAuthKey :: Maybe TransmissionAuth -> C.APublicAuthKey
dummyAuthKey (Just (TASignature (C.ASignature alg _))) = case alg of
  C.SEd25519 -> C.APublicAuthKey C.SEd25519 dummyKeyEd25519
  C.SEd448 -> C.APublicAuthKey C.SEd448 dummyKeyEd448
dummyAuthKey _ = C.APublicAuthKey C.SX25519 dummyKeyX25519
-- | Dummy Ed25519 public key. It is only used for verifications whose result is discarded
-- (see 'dummySignKey'). Built from a string literal via OverloadedStrings.
dummyKeyEd25519 :: C.PublicKey 'C.Ed25519
dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs="
-- | Dummy Ed448 public key, with the same purpose as 'dummyKeyEd25519'.
dummyKeyEd448 :: C.PublicKey 'C.Ed448
dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/XopbOSaq9qyLhrgJWKOLyNrQPNVvpMA"
-- | Dummy X25519 public key, used in place of a real key when verifying authenticators
-- (TAAuthenticator) that must fail.
dummyKeyX25519 :: C.PublicKey 'C.X25519
dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg="
-- | Fork a labelled thread that belongs to a client.
--
-- A weak reference to the thread is stored in 'endThreads' under a fresh slot, so that
-- 'clientDisconnected' can kill it. The thread removes its own slot when the action ends.
--
-- NOTE(review): the slot is inserted by the parent after 'forkIO' returns. If the action
-- completes before that insert, the delete runs first and a stale entry stays in
-- 'endThreads' until the client disconnects.
forkClient :: Client s -> String -> M () -> M ()
forkClient Client {endThreads, endThreadSeq} label action = do
  slot <- atomically $ stateTVar endThreadSeq $ \n -> (n, n + 1)
  let unregister = atomically $ modifyTVar' endThreads $ IM.delete slot
  tid <- forkIO $ labelMyThread label >> (action `finally` unregister)
  weakTid <- mkWeakThreadId tid
  atomically $ modifyTVar' endThreads $ IM.insert slot weakTid
client :: forall s. MsgStoreClass s => THandleParams SMPVersion 'TServer -> Server -> s -> Client s -> M ()
client
thParams'
Server {subscribers, ntfSubscribers}
ms
clnt@Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands"
let THandleParams {thVersion} = thParams'
forever $
atomically (readTBQueue rcvQ)
>>= mapM (processCommand thVersion)
>>= mapM_ reply . L.nonEmpty . catMaybes . L.toList
where
reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m ()
reply = atomically . writeTBQueue sndQ
processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (Maybe (Transmission BrokerMsg))
processProxiedCmd (corrId, EntityId sessId, command) = (corrId,EntityId sessId,) <$$> case command of
PRXY srv auth -> ifM allowProxy getRelay (pure $ Just $ ERR $ PROXY BASIC_AUTH)
where
allowProxy = do
ServerConfig {allowSMPProxy, newQueueBasicAuth} <- asks config
pure $ allowSMPProxy && maybe True ((== auth) . Just) newQueueBasicAuth
getRelay = do
ProxyAgent {smpAgent = a} <- asks proxyAgent
liftIO (getConnectedSMPServerClient a srv) >>= \case
Just r -> Just <$> proxyServerResponse a r
Nothing ->
forkProxiedCmd $
liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError))
>>= proxyServerResponse a
proxyServerResponse :: SMPClientAgent -> Either SMPClientError (OwnServer, SMPClient) -> M BrokerMsg
proxyServerResponse a smp_ = do
ServerStats {pRelays, pRelaysOwn} <- asks serverStats
let inc = mkIncProxyStats pRelays pRelaysOwn
case smp_ of
Right (own, smp) -> do
inc own pRequests
case proxyResp smp of
r@PKEY {} -> r <$ inc own pSuccesses
r -> r <$ inc own pErrorsCompat
Left e -> do
let own = isOwnServer a srv
inc own pRequests
inc own $ if temporaryClientError e then pErrorsConnect else pErrorsOther
logWarn $ "Error connecting: " <> decodeLatin1 (strEncode $ host srv) <> " " <> tshow e
pure . ERR $ smpProxyError e
where
proxyResp smp =
let THandleParams {sessionId = srvSessId, thVersion, thServerVRange, thAuth} = thParams smp
in case compatibleVRange thServerVRange proxiedSMPRelayVRange of
-- Cap the destination relay version range to prevent client version fingerprinting.
-- See comment for proxiedSMPRelayVersion.
Just (Compatible vr) | thVersion >= sendingProxySMPVersion -> case thAuth of
Just THAuthClient {serverCertKey} -> PKEY srvSessId vr serverCertKey
Nothing -> ERR $ transportErr TENoServerAuth
_ -> ERR $ transportErr TEVersion
PFWD fwdV pubKey encBlock -> do
ProxyAgent {smpAgent = a} <- asks proxyAgent
ServerStats {pMsgFwds, pMsgFwdsOwn} <- asks serverStats
let inc = mkIncProxyStats pMsgFwds pMsgFwdsOwn
liftIO (lookupSMPServerClient a sessId) >>= \case
Just (own, smp) -> do
inc own pRequests
if v >= sendingProxySMPVersion
then forkProxiedCmd $ do
liftIO (runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case
Right r -> PRES r <$ inc own pSuccesses
Left e -> ERR (smpProxyError e) <$ case e of
PCEProtocolError {} -> inc own pSuccesses
_ -> inc own pErrorsOther
else Just (ERR $ transportErr TEVersion) <$ inc own pErrorsCompat
where
THandleParams {thVersion = v} = thParams smp
Nothing -> inc False pRequests >> inc False pErrorsConnect $> Just (ERR $ PROXY NO_SESSION)
where
forkProxiedCmd :: M BrokerMsg -> M (Maybe BrokerMsg)
forkProxiedCmd cmdAction = do
bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $ do
-- commands MUST be processed under a reasonable timeout or the client would halt
cmdAction >>= \t -> reply [(corrId, EntityId sessId, t)]
pure Nothing
where
wait = do
ServerConfig {serverClientConcurrency} <- asks config
atomically $ do
used <- readTVar procThreads
when (used >= serverClientConcurrency) retry
writeTVar procThreads $! used + 1
signal = atomically $ modifyTVar' procThreads (\t -> t - 1)
transportErr :: TransportError -> ErrorType
transportErr = PROXY . BROKER . TRANSPORT
mkIncProxyStats :: MonadIO m => ProxyStats -> ProxyStats -> OwnServer -> (ProxyStats -> IORef Int) -> m ()
mkIncProxyStats ps psOwn own sel = do
incStat $ sel ps
when own $ incStat $ sel psOwn
processCommand :: VersionSMP -> (Maybe (StoreQueue s, QueueRec), Transmission Cmd) -> M (Maybe (Transmission BrokerMsg))
processCommand clntVersion (q_, (corrId, entId, cmd)) = case cmd of
Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command)
Cmd SSender command -> Just <$> case command of
SKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k
SEND flags msgBody -> withQueue_ False $ sendMessage flags msgBody
PING -> pure (corrId, NoEntity, PONG)
RFWD encBlock -> (corrId, NoEntity,) <$> processForwardedCommand encBlock
Cmd SSenderLink command -> Just <$> case command of
LKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k $>> getQueueLink_ q qr
LGET -> withQueue $ \q qr -> checkContact qr $ getQueueLink_ q qr
Cmd SNotifier NSUB -> Just <$> subscribeNotifications
Cmd SRecipient command ->
Just <$> case command of
NEW nqr@NewQueueReq {auth_} ->
ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
where
allowNew = do
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth
SUB -> withQueue subscribeQueue
GET -> withQueue getMessage
ACK msgId -> withQueue $ acknowledgeMsg msgId
KEY sKey -> withQueue $ \q _ -> either err (corrId,entId,) <$> secureQueue_ q sKey
RKEY rKeys -> withQueue $ \q qr -> checkMode QMContact qr $ OK <$$ liftIO (updateKeys (queueStore ms) q rKeys)
LSET lnkId d ->
withQueue $ \q qr -> checkContact qr $ liftIO $ case queueData qr of
Just (lnkId', _) | lnkId' /= lnkId -> pure $ Left AUTH
_ -> OK <$$ addQueueLinkData (queueStore ms) q lnkId d
LDEL ->
withQueue $ \q qr -> checkContact qr $ liftIO $ case queueData qr of
Just _ -> OK <$$ deleteQueueLinkData (queueStore ms) q
Nothing -> pure $ Right OK
NKEY nKey dhKey -> withQueue $ \q _ -> addQueueNotifier_ q nKey dhKey
NDEL -> withQueue $ \q _ -> deleteQueueNotifier_ q
OFF -> maybe (pure $ err INTERNAL) suspendQueue_ q_
DEL -> maybe (pure $ err INTERNAL) delQueueAndMsgs q_
QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr
where
createQueue :: NewQueueReq -> M (Transmission BrokerMsg)
createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData} = time "NEW" $ do
g <- asks random
idSize <- asks $ queueIdBytes . config
updatedAt <- Just <$> liftIO getSystemDate
(rcvPublicDhKey, privDhKey) <- atomically $ C.generateKeyPair g
-- TODO [notifications]
-- ntfKeys_ <- forM ntfCreds $ \(NewNtfCreds notifierKey dhKey) -> do
-- (ntfPubDhKey, ntfPrivDhKey) <- atomically $ C.generateKeyPair g
-- pure (notifierKey, C.dh' dhKey ntfPrivDhKey, ntfPubDhKey)
let randId = EntityId <$> atomically (C.randomBytes idSize g)
-- TODO [notifications] the remaining 24 bytes are reserver for notifier ID
sndId' = B.take 24 $ C.sha3_384 (bs corrId)
tryCreate 0 = pure $ ERR INTERNAL
tryCreate n = do
(sndId, clntIds, queueData) <- case queueReqData of
Just (QRMessaging (Just (sId, d))) -> (\linkId -> (sId, True, Just (linkId, d))) <$> randId
Just (QRContact (Just (linkId, (sId, d)))) -> pure (sId, True, Just (linkId, d))
_ -> (,False,Nothing) <$> randId
-- The condition that client-provided sender ID must match hash of correlation ID
-- prevents "ID oracle" attack, when creating queue with supplied ID can be used to check
-- if queue with this ID still exists.
if clntIds && unEntityId sndId /= sndId'
then pure $ ERR $ CMD PROHIBITED
else do
rcvId <- randId
-- TODO [notifications]
-- ntf <- forM ntfKeys_ $ \(notifierKey, rcvNtfDhSecret, rcvPubDhKey) -> do
-- notifierId <- randId
-- pure (NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}, ServerNtfCreds notifierId rcvPubDhKey)
let queueMode = queueReqMode <$> queueReqData
qr =
QueueRec
{ senderId = sndId,
recipientKeys = [rcvAuthKey],
rcvDhSecret = C.dh' rcvDhKey privDhKey,
senderKey = Nothing,
queueMode,
queueData,
-- TODO [notifications]
notifier = Nothing, -- fst <$> ntf,
status = EntityActive,
updatedAt
}
liftIO (addQueue ms rcvId qr) >>= \case
Left DUPLICATE_ -- TODO [short links] possibly, we somehow need to understand which IDs caused collision to retry if it's not client-supplied?
| clntIds -> pure $ ERR AUTH -- no retry on collision if sender ID is client-supplied
| otherwise -> tryCreate (n - 1)
Left e -> pure $ ERR e
Right q -> do
stats <- asks serverStats
incStat $ qCreated stats
incStat $ qCount stats
-- TODO [notifications]
-- when (isJust ntf) $ incStat $ ntfCreated stats
case subMode of
SMOnlyCreate -> pure ()
SMSubscribe -> void $ subscribeQueue q qr
pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId = fst <$> queueData} -- , serverNtfCreds = snd <$> ntf
(corrId,entId,) <$> tryCreate (3 :: Int)
-- this check allows to support contact queues created prior to SKEY,
-- using `queueMode == Just QMContact` would prevent it, as they have queueMode `Nothing`.
checkContact :: QueueRec -> M (Either ErrorType BrokerMsg) -> M (Transmission BrokerMsg)
checkContact qr a =
either err (corrId,entId,)
<$> if isContactQueue qr then a else pure $ Left AUTH
checkMode :: QueueMode -> QueueRec -> M (Either ErrorType BrokerMsg) -> M (Transmission BrokerMsg)
checkMode qm QueueRec {queueMode} a =
either err (corrId,entId,)
<$> if queueMode == Just qm then a else pure $ Left AUTH
secureQueue_ :: StoreQueue s -> SndPublicAuthKey -> M (Either ErrorType BrokerMsg)
secureQueue_ q sKey = do
liftIO (secureQueue (queueStore ms) q sKey)
$>> (asks serverStats >>= incStat . qSecured) $> Right OK
getQueueLink_ :: StoreQueue s -> QueueRec -> M (Either ErrorType BrokerMsg)
getQueueLink_ q qr = liftIO $ LNK (senderId qr) <$$> getQueueLinkData (queueStore ms) q entId
addQueueNotifier_ :: StoreQueue s -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M (Transmission BrokerMsg)
addQueueNotifier_ q notifierKey dhKey = time "NKEY" $ do
(rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random
let rcvNtfDhSecret = C.dh' dhKey privDhKey
(corrId,entId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret
where
addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M BrokerMsg
addNotifierRetry 0 _ _ = pure $ ERR INTERNAL
addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do
notifierId <- randomId =<< asks (queueIdBytes . config)
let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
liftIO (addQueueNotifier (queueStore ms) q ntfCreds) >>= \case
Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret
Left e -> pure $ ERR e
Right nId_ -> do
incStat . ntfCreated =<< asks serverStats
forM_ nId_ $ \nId -> atomically $ writeTQueue (subQ ntfSubscribers) (nId, clientId, False)
pure $ NID notifierId rcvPublicDhKey
deleteQueueNotifier_ :: StoreQueue s -> M (Transmission BrokerMsg)
deleteQueueNotifier_ q =
liftIO (deleteQueueNotifier (queueStore ms) q) >>= \case
Right (Just nId) -> do
-- Possibly, the same should be done if the queue is suspended, but currently we do not use it
stats <- asks serverStats
deleted <- asks ntfStore >>= liftIO . (`deleteNtfs` nId)
when (deleted > 0) $ liftIO $ atomicModifyIORef'_ (ntfCount stats) (subtract deleted)
atomically $ writeTQueue (subQ ntfSubscribers) (nId, clientId, False)
incStat $ ntfDeleted stats
pure ok
Right Nothing -> pure ok
Left e -> pure $ err e
suspendQueue_ :: (StoreQueue s, QueueRec) -> M (Transmission BrokerMsg)
suspendQueue_ (q, _) = liftIO $ either err (const ok) <$> suspendQueue (queueStore ms) q
subscribeQueue :: StoreQueue s -> QueueRec -> M (Transmission BrokerMsg)
subscribeQueue q qr =
liftIO (TM.lookupIO rId subscriptions) >>= \case
Nothing -> newSub >>= deliver True
Just s@Sub {subThread} -> do
stats <- asks serverStats
case subThread of
ProhibitSub -> do
-- cannot use SUB in the same connection where GET was used
incStat $ qSubProhibited stats
pure (corrId, rId, ERR $ CMD PROHIBITED)
_ -> do
incStat $ qSubDuplicate stats
atomically (tryTakeTMVar $ delivered s) >> deliver False s
where
rId = recipientId q
newSub :: M Sub
newSub = time "SUB newSub" . atomically $ do
writeTQueue (subQ subscribers) (rId, clientId, True)
sub <- newSubscription NoSub
TM.insert rId sub subscriptions
pure sub
deliver :: Bool -> Sub -> M (Transmission BrokerMsg)
deliver inc sub = do
stats <- asks serverStats
fmap (either (\e -> (corrId, rId, ERR e)) id) $ liftIO $ runExceptT $ do
msg_ <- tryPeekMsg ms q
liftIO $ when (inc && isJust msg_) $ incStat (qSub stats)
liftIO $ deliverMessage "SUB" qr rId sub msg_
-- clients that use GET are not added to server subscribers
-- Handle GET: respond with the first message of the queue (MSG), or OK when the
-- queue is empty. GET does not create a server subscription: the first GET
-- registers a ProhibitSub entry for this entity in the client's subscriptions,
-- so a later SUB in the same connection fails with CMD PROHIBITED.
-- If this client already has a SUB subscription to the entity, GET itself is
-- rejected with CMD PROHIBITED.
getMessage :: StoreQueue s -> QueueRec -> M (Transmission BrokerMsg)
getMessage q qr = time "GET" $ do
atomically (TM.lookup entId subscriptions) >>= \case
Nothing ->
atomically newSub >>= (`getMessage_` Nothing)
Just s@Sub {subThread} ->
case subThread of
ProhibitSub ->
-- a previous GET: take the ID of the message it returned (if any);
-- getMessage_ uses it only to count this GET as a duplicate in stats
atomically (tryTakeTMVar $ delivered s)
>>= getMessage_ s
-- cannot use GET in the same connection where there is an active subscription
_ -> do
stats <- asks serverStats
incStat $ msgGetProhibited stats
pure $ err $ CMD PROHIBITED
where
-- register a GET-only (ProhibitSub) entry for this entity in the client
newSub :: STM Sub
newSub = do
s <- newProhibitedSub
TM.insert entId s subscriptions
-- Here we don't account for this client as subscribed in the server
-- and don't notify other subscribed clients.
-- This is tracked as "subscription" in the client to prevent these
-- clients from being able to subscribe.
pure s
-- Peek the first message. When present, record it as delivered in the
-- subscription (so that a later ACK can match it) and return it encrypted as MSG;
-- when the queue is empty respond OK. Store errors are returned as ERR.
getMessage_ :: Sub -> Maybe MsgId -> M (Transmission BrokerMsg)
getMessage_ s delivered_ = do
stats <- asks serverStats
fmap (either err id) $ liftIO $ runExceptT $
tryPeekMsg ms q >>= \case
Just msg -> do
let encMsg = encryptMsg qr msg
-- a message ID left from the previous GET means it was not ACKed, so this GET is a duplicate
incStat $ (if isJust delivered_ then msgGetDuplicate else msgGet) stats
atomically $ setDelivered s msg $> (corrId, entId, MSG encMsg)
Nothing -> incStat (msgGetNoMsg stats) $> ok
-- Run an action on this command's queue; blocked queues are rejected with
-- ERR BLOCKED (this is withQueue_ with queueNotBlocked = True).
withQueue :: (StoreQueue s -> QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg)
withQueue action = withQueue_ True action
-- SEND passes queueNotBlocked False here to update time, but it fails anyway on blocked queues (see code for SEND).
-- Run an action on the queue resolved for this command (ERR INTERNAL if none).
-- When queueNotBlocked is True, a blocked queue is rejected with ERR BLOCKED.
-- Before running the action, the queue's updatedAt is refreshed when it differs
-- from the current getSystemDate value; a store error from that update is returned as ERR.
withQueue_ :: Bool -> (StoreQueue s -> QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg)
withQueue_ queueNotBlocked action = maybe (pure $ err INTERNAL) runOn q_
  where
    runOn (q, qr@QueueRec {status, updatedAt})
      | EntityBlocked info <- status, queueNotBlocked = pure $ err $ BLOCKED info
      | otherwise = do
          today <- liftIO getSystemDate
          if updatedAt == Just today
            then action q qr
            else liftIO (updateQueueTime (queueStore ms) q today) >>= either (pure . err) (action q)
-- Handle NSUB: register this client as the notification subscriber for the entity
-- (announcing it on the ntf subscribers' subQ) unless it is already subscribed;
-- count either a new or a duplicate subscription and respond OK.
subscribeNotifications :: M (Transmission BrokerMsg)
subscribeNotifications = do
  statCount <- time "NSUB" . atomically $ do
    alreadySubscribed <- TM.member entId ntfSubscriptions
    if alreadySubscribed
      then pure ntfSubDuplicate
      else addNtfSub $> ntfSub
  stats <- asks serverStats
  incStat $ statCount stats
  pure ok
  where
    addNtfSub = do
      writeTQueue (subQ ntfSubscribers) (entId, clientId, True)
      TM.insert entId () ntfSubscriptions
-- Handle ACK for msgId. The ACK is accepted only when msgId equals the message ID
-- recorded as delivered in this client's subscription to the entity (an empty
-- msgId matches any delivered message); otherwise the response is ERR NO_MSG.
-- For GET-style (ProhibitSub) subscriptions the message is deleted and OK is
-- returned; for SUB subscriptions the message is deleted and the next message
-- (if any) is returned via deliverMessage.
acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M (Transmission BrokerMsg)
acknowledgeMsg msgId q qr = time "ACK" $ do
liftIO (TM.lookupIO entId subscriptions) >>= \case
Nothing -> pure $ err NO_MSG
Just sub ->
atomically (getDelivered sub) >>= \case
Just st -> do
stats <- asks serverStats
fmap (either err id) $ liftIO $ runExceptT $ do
case st of
ProhibitSub -> do
deletedMsg_ <- tryDelMsg ms q msgId
liftIO $ mapM_ (updateStats stats True) deletedMsg_
pure ok
_ -> do
-- returns the deleted message (if any) and the next message to deliver (if any)
(deletedMsg_, msg_) <- tryDelPeekMsg ms q msgId
liftIO $ mapM_ (updateStats stats False) deletedMsg_
liftIO $ deliverMessage "ACK" qr entId sub msg_
_ -> pure $ err NO_MSG
where
-- Take the delivered message ID if it matches msgId (or msgId is empty) and
-- return the subscription kind; on mismatch put the ID back and return Nothing.
-- When nothing was delivered the result is also Nothing.
getDelivered :: Sub -> STM (Maybe ServerSub)
getDelivered Sub {delivered, subThread} = do
tryTakeTMVar delivered $>>= \msgId' ->
if msgId == msgId' || B.null msgId
then pure $ Just subThread
else putTMVar delivered msgId' $> Nothing
-- Count an acknowledged (deleted) message in receive statistics and decrement
-- msgCount; isGet is True for messages received via GET. Quota markers are not counted.
updateStats :: ServerStats -> Bool -> Message -> IO ()
updateStats stats isGet = \case
MessageQuota {} -> pure ()
Message {msgFlags} -> do
incStat $ msgRecv stats
if isGet
then incStat $ msgRecvGet stats
else pure () -- TODO skip notification delivery for delivered message
-- skipping delivery fails tests, it should be counted in msgNtfSkipped
-- forM_ (notifierId <$> notifier qr) $ \nId -> do
-- ns <- asks ntfStore
-- atomically $ TM.lookup nId ns >>=
-- mapM_ (\MsgNtf {ntfMsgId} -> when (msgId == msgId') $ TM.delete nId ns)
atomicModifyIORef'_ (msgCount stats) (subtract 1)
updatePeriodStats (activeQueues stats) entId
when (notification msgFlags) $ do
incStat $ msgRecvNtf stats
updatePeriodStats (activeQueuesNtf stats) entId
-- Handle SEND: reject bodies over the client version's size limit (LARGE_MSG),
-- disabled queues (AUTH) and blocked queues (BLOCKED). Otherwise expire old
-- messages, then append the message (QUOTA when the queue is full). If the queue
-- was empty, try to deliver the message to a subscribed recipient right away.
-- Enqueue a notification when the message is flagged for it, and update statistics.
sendMessage :: MsgFlags -> MsgBody -> StoreQueue s -> QueueRec -> M (Transmission BrokerMsg)
sendMessage msgFlags msgBody q qr
  | B.length msgBody > maxMessageLength clntVersion = do
      stats <- asks serverStats
      incStat $ msgSentLarge stats
      pure $ err LARGE_MSG
  | otherwise = do
      stats <- asks serverStats
      case status qr of
        EntityOff -> do
          incStat $ msgSentAuth stats
          pure $ err AUTH
        EntityBlocked info -> do
          incStat $ msgSentBlock stats
          pure $ err $ BLOCKED info
        EntityActive ->
          case C.maxLenBS msgBody of
            Left _ -> pure $ err LARGE_MSG
            Right body -> do
              ServerConfig {messageExpiration, msgIdBytes} <- asks config
              msgId <- randomId' msgIdBytes
              msg_ <- liftIO $ time "SEND" $ runExceptT $ do
                expireMessages messageExpiration stats
                msg <- liftIO $ mkMessage msgId body
                writeMsg ms q True msg
              case msg_ of
                Left e -> pure $ err e
                Right Nothing -> do
                  incStat $ msgSentQuota stats
                  pure $ err QUOTA
                Right (Just (msg, wasEmpty)) -> time "SEND ok" $ do
                  when wasEmpty $ liftIO $ tryDeliverMessage msg
                  when (notification msgFlags) $ do
                    mapM_ (`enqueueNotification` msg) (notifier qr)
                    incStat $ msgSentNtf stats
                    liftIO $ updatePeriodStats (activeQueuesNtf stats) (recipientId q)
                  incStat $ msgSent stats
                  incStat $ msgCount stats
                  liftIO $ updatePeriodStats (activeQueues stats) (recipientId q)
                  pure ok
  where
    -- stamp the new message with the current system time
    mkMessage :: MsgId -> C.MaxLenBS MaxMessageLen -> IO Message
    mkMessage msgId body = do
      msgTs <- getSystemTime
      pure $! Message msgId msgTs msgFlags body
    -- delete messages older than the configured expiration (if any) and count them
    expireMessages :: Maybe ExpirationConfig -> ServerStats -> ExceptT ErrorType IO ()
    expireMessages msgExp stats = do
      deleted <- maybe (pure 0) (deleteExpiredMsgs ms q <=< liftIO . expireBeforeEpoch) msgExp
      liftIO $ when (deleted > 0) $ atomicModifyIORef'_ (msgExpired stats) (+ deleted)
    -- The condition for delivery of the message is:
    -- - the queue was empty when the message was sent,
    -- - there is subscribed recipient,
    -- - no message was "delivered" that was not acknowledged.
    -- If the send queue of the subscribed client is not full, the message is put there in the same transaction.
    -- If the send queue is full, a thread is created where these checks are made again:
    -- - it is the same subscribed client (in case it was reconnected it would receive message via SUB command)
    -- - nothing was delivered to this subscription (to avoid race conditions with the recipient).
    tryDeliverMessage :: Message -> IO ()
    tryDeliverMessage msg =
      -- the subscribed client var is read outside of STM to avoid transaction cost
      -- in case no client is subscribed.
      getSubscribedClient rId (queueSubscribers subscribers)
        $>>= atomically . deliverToSub
        >>= mapM_ forkDeliver
      where
        rId = recipientId q
        deliverToSub rcv =
          -- reading client TVar in the same transaction,
          -- so that if subscription ends, it re-evaluates
          -- and delivery is cancelled -
          -- the new client will receive message in response to SUB.
          readTVar rcv
            $>>= \rc@(AClient _ _ Client {subscriptions = subs, sndQ = sndQ'}) -> TM.lookup rId subs
            $>>= \s@Sub {subThread, delivered} -> case subThread of
              ProhibitSub -> pure Nothing
              ServerSub st -> readTVar st >>= \case
                NoSub ->
                  -- Only check the delivered marker (read, not take): taking it here would
                  -- clear the record of an unacknowledged message without delivering anything,
                  -- so the recipient's ACK for it would fail with NO_MSG.
                  tryReadTMVar delivered >>= \case
                    Just _ -> pure Nothing -- if a message was already delivered, should not deliver more
                    Nothing ->
                      ifM
                        (isFullTBQueue sndQ')
                        (writeTVar st SubPending $> Just (rc, s, st))
                        (deliver sndQ' s $> Nothing)
                _ -> pure Nothing
        deliver sndQ' s = do
          let encMsg = encryptMsg qr msg
          writeTBQueue sndQ' [(CorrId "", rId, MSG encMsg)]
          void $ setDelivered s msg
        forkDeliver ((AClient _ _ rc@Client {sndQ = sndQ'}), s@Sub {delivered}, st) = do
          t <- mkWeakThreadId =<< forkIO deliverThread
          atomically $ modifyTVar' st $ \case
            -- this case is needed because deliverThread can exit before it
            SubPending -> SubThread t
            st' -> st'
          where
            deliverThread = do
              labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " deliver/SEND"
              -- lookup can be outside of STM transaction,
              -- as long as the check that it is the same client is inside.
              getSubscribedClient rId (queueSubscribers subscribers) >>= mapM_ deliverIfSame
            deliverIfSame rcv = time "deliver" . atomically $
              whenM (sameClient rc rcv) $
                -- read, not take, for the same reason as in deliverToSub
                tryReadTMVar delivered >>= \case
                  -- NOTE(review): in this branch st is left as SubThread; confirm it is reset elsewhere
                  Just _ -> pure () -- if a message was already delivered, should not deliver more
                  Nothing -> do
                    -- a separate thread is needed because it blocks when client sndQ is full.
                    deliver sndQ' s
                    writeTVar st NoSub
-- Store a notification for a message sent to a queue that has notifier
-- credentials, and increment the notification counter. Quota markers produce no notification.
enqueueNotification :: NtfCreds -> Message -> M ()
enqueueNotification creds = \case
  MessageQuota {} -> pure ()
  Message {msgId, msgTs} -> do
    let NtfCreds {notifierId = nId, rcvNtfDhSecret} = creds
    store <- asks ntfStore
    ntf <- mkMessageNotification msgId msgTs rcvNtfDhSecret
    liftIO $ storeNtf store nId ntf
    stats <- asks serverStats
    incStat $ ntfCount stats
-- Build a notification record for a message. The message ID and timestamp are
-- encrypted with the recipient's notification DH secret under a fresh random nonce;
-- if encryption fails, the encrypted metadata is left empty.
mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> M MsgNtf
mkMessageNotification msgId msgTs rcvNtfDhSecret = do
  g <- asks random
  ntfNonce <- atomically $ C.randomCbNonce g
  let encrypted = C.cbEncrypt rcvNtfDhSecret ntfNonce (smpEncode NMsgMeta {msgId, msgTs}) 128
      ntfEncMeta = fromRight "" encrypted
  pure MsgNtf {ntfMsgId = msgId, ntfTs = msgTs, ntfNonce, ntfEncMeta}
-- Handle a client transmission forwarded by a proxy.
-- The outer envelope is decrypted with the proxy session secret (nonce from this
-- command's corrId). The inner client transmission is decrypted with a DH secret
-- of the client's fwdKey and the server private key (nonce from fwdCorrId).
-- Exactly one transmission is accepted, and only the sender commands allowed by
-- rejectOrVerify are processed. The response is encrypted for the client and then
-- for the proxy, using the reversed request nonces, and returned as RRES.
-- Failures are returned as ERR.
processForwardedCommand :: EncFwdTransmission -> M BrokerMsg
processForwardedCommand (EncFwdTransmission s) = fmap (either ERR RRES) . runExceptT $ do
THAuthServer {serverPrivKey, sessSecret'} <- maybe (throwE $ transportErr TENoServerAuth) pure (thAuth thParams')
sessSecret <- maybe (throwE $ transportErr TENoServerAuth) pure sessSecret'
let proxyNonce = C.cbNonce $ bs corrId
s' <- liftEitherWith (const CRYPTO) $ C.cbDecryptNoPad sessSecret proxyNonce s
FwdTransmission {fwdCorrId, fwdVersion, fwdKey, fwdTransmission = EncTransmission et} <- liftEitherWith (const $ CMD SYNTAX) $ smpDecode s'
let clientSecret = C.dh' fwdKey serverPrivKey
clientNonce = C.cbNonce $ bs fwdCorrId
b <- liftEitherWith (const CRYPTO) $ C.cbDecrypt clientSecret clientNonce et
let clntTHParams = smpTHParamsSetVersion fwdVersion thParams'
-- only a single forwarded transmission is allowed, a batch is rejected with BLOCK
t' <- case tParse clntTHParams b of
t :| [] -> pure $ tDecodeParseValidate clntTHParams t
_ -> throwE BLOCK
-- the client's DH secret is used as the session secret to verify the forwarded command
let clntThAuth = Just $ THAuthServer {serverPrivKey, sessSecret' = Just clientSecret}
-- process forwarded command
r <-
lift (rejectOrVerify clntThAuth t') >>= \case
Left r -> pure r
-- rejectOrVerify filters allowed commands, no need to repeat it here.
-- INTERNAL is used because processCommand never returns Nothing for sender commands (could be extracted for better types).
Right t''@(_, (corrId', entId', _)) -> fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand fwdVersion t'')
-- encode response
r' <- case batchTransmissions (batch clntTHParams) (blockSize clntTHParams) [Right (Nothing, encodeTransmission clntTHParams r)] of
[] -> throwE INTERNAL -- at least 1 item is guaranteed from NonEmpty/Right
TBError _ _ : _ -> throwE BLOCK
TBTransmission b' _ : _ -> pure b'
TBTransmissions b' _ _ : _ -> pure b'
-- encrypt to client
r2 <- liftEitherWith (const BLOCK) $ EncResponse <$> C.cbEncrypt clientSecret (C.reverseNonce clientNonce) r' paddedProxiedTLength
-- encrypt to proxy
let fr = FwdResponse {fwdCorrId, fwdResponse = r2}
r3 = EncFwdResponse $ C.cbEncryptNoPad sessSecret (C.reverseNonce proxyNonce) (smpEncode fr)
stats <- asks serverStats
incStat $ pMsgFwdsRecv stats
pure r3
where
-- Turn a parse error into an ERR response, reject commands other than
-- SEND/SKEY (sender) and LKEY/LGET (sender link) with CMD PROHIBITED, and verify the
-- rest with the client's auth. A failed verification becomes ERR AUTH.
rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd))
rejectOrVerify clntThAuth (tAuth, authorized, (corrId', entId', cmdOrError)) =
case cmdOrError of
Left e -> pure $ Left (corrId', entId', ERR e)
Right cmd'
| allowed -> verified <$> verifyTransmission ms ((,C.cbNonce (bs corrId')) <$> clntThAuth) tAuth authorized entId' cmd'
| otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED)
where
allowed = case cmd' of
Cmd SSender SEND {} -> True
Cmd SSender (SKEY _) -> True
Cmd SSenderLink (LKEY _) -> True
Cmd SSenderLink LGET -> True
_ -> False
verified = \case
VRVerified q -> Right (q, (corrId', entId', cmd'))
VRFailed -> Left (corrId', entId', ERR AUTH)
-- Build the response to SUB/ACK for subscription s. For a SUB subscription with
-- an available message, record the message as delivered and return it encrypted
-- as MSG. For a GET-style (ProhibitSub) subscription, or when there is no
-- message, respond OK. The STM transaction is timed under "<name> deliver".
deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> IO (Transmission BrokerMsg)
deliverMessage name qr rId s@Sub {subThread} msg_ =
  time (name <> " deliver") . atomically $ case (subThread, msg_) of
    (ProhibitSub, _) -> pure okResp
    (_, Nothing) -> pure okResp
    (_, Just msg) -> setDelivered s msg $> (corrId, rId, MSG $ encryptMsg qr msg)
  where
    okResp = (corrId, rId, OK)
-- Time an action, attributing it to this command's entity ID (via timed).
time :: MonadIO m => T.Text -> m a -> m a
time name action = timed name entId action
-- Encrypt a stored message for the recipient with the queue's rcvDhSecret, using
-- a nonce derived from the message ID. A quota marker is sent as RcvMsgQuota
-- with its timestamp.
encryptMsg :: QueueRec -> Message -> RcvMessage
encryptMsg qr msg =
  let msgId' = messageId msg
      msgTs' = messageTs msg
      rcvBody = case msg of
        Message {msgFlags, msgBody} -> RcvMsgBody {msgTs = msgTs', msgFlags, msgBody}
        MessageQuota {} -> RcvMsgQuota msgTs'
      encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage
      encrypt body = RcvMessage msgId' . EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body
   in encrypt $ encodeRcvMsgBody rcvBody
-- Record msg's ID (forced before the STM call) as the delivered, not yet
-- acknowledged message of subscription s. Like tryPutTMVar, it returns False and
-- keeps the existing value when a delivered ID is already recorded.
setDelivered :: Sub -> Message -> STM Bool
setDelivered s msg =
  let !deliveredId = messageId msg
   in tryPutTMVar (delivered s) deliveredId
-- Delete the queue together with its messages. On success:
-- - enqueue (entId, clientId, False) to the subscribers' subQ and remove this
--   client's subscription to the entity;
-- - if the queue had notifier credentials, delete its stored notifications
--   (adjusting ntfCount) and enqueue (nId, clientId, False) to the ntf subscribers' subQ;
-- - update deletion statistics and respond OK.
-- A store error is returned as ERR.
delQueueAndMsgs :: (StoreQueue s, QueueRec) -> M (Transmission BrokerMsg)
delQueueAndMsgs (q, _) = do
liftIO (deleteQueue ms q) >>= \case
Right qr -> do
-- Possibly, the same should be done if the queue is suspended, but currently we do not use it
atomically $ do
writeTQueue (subQ subscribers) (entId, clientId, False)
-- queue is usually deleted by the same client that is currently subscribed,
-- we delete subscription here, so the client with no subscriptions can be disconnected.
TM.delete entId subscriptions
forM_ (notifierId <$> notifier qr) $ \nId -> do
-- queue is deleted by a different client from the one subscribed to notifications,
-- so we don't need to remove subscription from the current client.
stats <- asks serverStats
deleted <- asks ntfStore >>= liftIO . (`deleteNtfs` nId)
when (deleted > 0) $ liftIO $ atomicModifyIORef'_ (ntfCount stats) (subtract deleted)
atomically $ writeTQueue (subQ ntfSubscribers) (nId, clientId, False)
updateDeletedStats qr
pure ok
Left e -> pure $ err e
-- Describe the queue for an INFO response: whether a sender key and notifier are
-- set, this client's subscription to it (thread state and delivered message ID,
-- base64-encoded), the queue size, and info about the first message, if any.
-- Store errors are returned as ERR.
getQueueInfo :: StoreQueue s -> QueueRec -> M BrokerMsg
getQueueInfo q QueueRec {senderKey, notifier} =
  fmap (either ERR INFO) . liftIO . runExceptT $ do
    qiSub <- liftIO $ mapM mkQSub =<< TM.lookupIO entId subscriptions
    qiSize <- getQueueSize ms q
    qiMsg <- fmap toMsgInfo <$> tryPeekMsg ms q
    pure QueueInfo {qiSnd = isJust senderKey, qiNtf = isJust notifier, qiSub, qiSize, qiMsg}
  where
    mkQSub Sub {subThread, delivered} = do
      qSubThread <- case subThread of
        ProhibitSub -> pure QProhibitSub
        ServerSub t -> toQSubThread <$> readTVarIO t
      qDelivered <- atomically $ fmap (decodeLatin1 . encode) <$> tryReadTMVar delivered
      pure QSub {qSubThread, qDelivered}
    toQSubThread = \case
      NoSub -> QNoSub
      SubPending -> QSubPending
      SubThread _ -> QSubThread
-- OK response carrying this command's correlation and entity IDs.
ok :: Transmission BrokerMsg
ok = (corrId, entId, OK)
-- ERR response carrying this command's correlation and entity IDs.
err :: ErrorType -> Transmission BrokerMsg
err = (corrId, entId,) . ERR
-- Update statistics after a queue is deleted. Queues without a sender key count
-- as qDeletedNew, the others as qDeletedSecured. Every deletion also counts in
-- qDeletedAll and decrements qCount.
updateDeletedStats :: QueueRec -> M ()
updateDeletedStats q = do
  stats <- asks serverStats
  let bucket = if isJust (senderKey q) then qDeletedSecured else qDeletedNew
  incStat $ bucket stats
  incStat $ qDeletedAll stats
  liftIO $ atomicModifyIORef'_ (qCount stats) (subtract 1)
-- | Atomically increment a statistics counter by one.
incStat :: MonadIO m => IORef Int -> m ()
incStat counter = liftIO . atomicModifyIORef'_ counter $ (+ 1)
{-# INLINE incStat #-}
-- | Run an action and, when it takes longer than one second, log at debug level
-- its name, the base64-encoded queue ID and the elapsed time in nanoseconds.
timed :: MonadIO m => T.Text -> RecipientId -> m a -> m a
timed name (EntityId qId) a = do
  t <- liftIO getSystemTime
  r <- a
  t' <- liftIO getSystemTime
  let int = diff t t'
  when (int > sec) . logDebug $ T.unwords [name, tshow $ encode qId, tshow int]
  pure r
  where
    -- Elapsed nanoseconds from t to t'. systemNanoseconds is Word32, so both values
    -- are widened to Int64 before subtracting. Subtracting them as Word32 would wrap
    -- around whenever t' has a smaller nanosecond part than t (e.g. 1.9s -> 2.1s)
    -- and add ~4.29s to the result.
    diff :: SystemTime -> SystemTime -> Int64
    diff t t' =
      (systemSeconds t' - systemSeconds t) * sec
        + (fromIntegral (systemNanoseconds t') - fromIntegral (systemNanoseconds t))
    sec :: Int64
    sec = 1000_000000
-- | Generate n random bytes from the environment's random generator.
randomId' :: Int -> M ByteString
randomId' n = do
  g <- asks random
  atomically $ C.randomBytes n g
-- | Generate a random entity ID of n bytes.
randomId :: Int -> M EntityId
randomId n = EntityId <$> randomId' n
{-# INLINE randomId #-}
-- | For an in-memory message store with a configured store file, export the
-- messages to that file (drainMsgs is passed through to exportMessages).
-- Otherwise only log a note: an in-memory store without a file saves nothing, and
-- journal storage needs no export.
saveServerMessages :: Bool -> AMsgStore -> IO ()
saveServerMessages drainMsgs = \case
  AMS SQSMemory SMSMemory ms@STMMsgStore {storeConfig = STMStoreConfig {storePath}} ->
    maybe
      (logNote "undelivered messages are not saved")
      (\f -> exportMessages False ms f drainMsgs)
      storePath
  AMS _ SMSJournal _ -> logNote "closed journal message storage"
-- | Write the messages of every queue in the store to file f, one MLRv3 record
-- per line, and log the total. tty is passed to unsafeWithAllMsgQueues, and
-- drainMsgs to getQueueMessages_. Any exception is logged and the process exits
-- with failure.
exportMessages :: MsgStoreClass s => Bool -> s -> FilePath -> Bool -> IO ()
exportMessages tty ms f drainMsgs = do
  logNote $ "saving messages to file " <> T.pack f
  withFile f WriteMode $ \h -> do
    result <- tryAny $ unsafeWithAllMsgQueues tty True ms $ saveQueueMsgs h
    case result of
      Left e -> do
        logError $ "error exporting messages: " <> tshow e
        exitFailure
      Right (Sum total) -> logNote $ "messages saved: " <> tshow total
  where
    saveQueueMsgs h q = do
      msgs <-
        unsafeRunStore q "saveQueueMsgs" $
          getMsgQueue ms q False >>= getQueueMessages_ drainMsgs q
      BLD.hPutBuilder h $ encodeMessages (recipientId q) msgs
      pure . Sum $ length msgs
    encodeMessages rId = foldMap $ \msg -> BLD.byteString (strEncode $ MLRv3 rId msg) <> BLD.char8 '\n'
-- | Process stored messages according to the configured store.
-- In-memory store with a configured file: if the file exists, import messages
-- from it, expiring those older than the configured messageExpiration (if any).
-- Journal store: when expireMessagesOnStart is set, either expire old messages
-- (when messageExpiration is configured) or only count stored messages.
-- Otherwise a warning is logged. Returns Nothing when nothing was processed.
processServerMessages :: StartOptions -> M (Maybe MessageStats)
processServerMessages StartOptions {skipWarnings} = do
old_ <- asks (messageExpiration . config) $>>= (liftIO . fmap Just . expireBeforeEpoch)
expire <- asks $ expireMessagesOnStart . config
asks msgStore >>= liftIO . processMessages old_ expire
where
processMessages :: Maybe Int64 -> Bool -> AMsgStore -> IO (Maybe MessageStats)
processMessages old_ expire = \case
AMS SQSMemory SMSMemory ms@STMMsgStore {storeConfig = STMStoreConfig {storePath}} -> case storePath of
Just f -> ifM (doesFileExist f) (Just <$> importMessages False ms f old_ skipWarnings) (pure Nothing)
Nothing -> pure Nothing
AMS _ SMSJournal ms -> processJournalMessages old_ expire ms
processJournalMessages :: forall s. Maybe Int64 -> Bool -> JournalMsgStore s -> IO (Maybe MessageStats)
processJournalMessages old_ expire ms
| expire = Just <$> case old_ of
Just old -> do
logNote "expiring journal store messages..."
run $ processExpireQueue old
Nothing -> do
logNote "validating journal store messages..."
run processValidateQueue
| otherwise = logWarn "skipping message expiration" $> Nothing
where
-- NOTE(review): the exception is discarded without logging before exiting;
-- confirm unsafeWithAllMsgQueues reports the failure itself.
run a = unsafeWithAllMsgQueues False False ms a `catchAny` \_ -> exitFailure
-- delete messages older than `old` (epoch seconds) and count remaining/expired ones
processExpireQueue :: Int64 -> JournalQueue s -> IO MessageStats
processExpireQueue old q = unsafeRunStore q "processExpireQueue" $ do
mq <- getMsgQueue ms q False
expiredMsgsCount <- deleteExpireMsgs_ old q mq
storedMsgsCount <- getQueueSize_ mq
pure MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues = 1}
-- only count stored messages (no expiration configured)
processValidateQueue :: JournalQueue s -> IO MessageStats
processValidateQueue q = unsafeRunStore q "processValidateQueue" $ do
storedMsgsCount <- getQueueSize_ =<< getMsgQueue ms q False
pure newMessageStats {storedMsgsCount, storedQueues = 1}
-- | Import messages from file f (one MLRv3 record per line) into the store.
-- Messages with a timestamp older than old_ (epoch seconds), if given, are
-- counted as expired and not stored. Queues that contained a quota marker are
-- marked over quota after the import. The file is renamed to f.bak.
-- A parsing error on the last line, or a record for a missing queue, is a
-- warning when skipWarnings is set; otherwise it, like any other error, exits
-- the process with failure.
importMessages :: forall s. MsgStoreClass s => Bool -> s -> FilePath -> Maybe Int64 -> Bool -> IO MessageStats
importMessages tty ms f old_ skipWarnings = do
logNote $ "restoring messages from file " <> T.pack f
(_, (storedMsgsCount, expiredMsgsCount, overQuota)) <-
foldLogLines tty f restoreMsg (Nothing, (0, 0, M.empty))
renameFile f $ f <> ".bak"
mapM_ setOverQuota_ overQuota
logQueueStates ms
QueueCounts {queueCount} <- liftIO $ queueCounts @(StoreQueue s) $ queueStore ms
pure MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues = queueCount}
where
-- Fold step: the state caches the last queue (to avoid repeated lookups) and
-- accumulates (stored, expired, over-quota queues); eof is True for the last line.
restoreMsg :: (Maybe (RecipientId, StoreQueue s), (Int, Int, M.Map RecipientId (StoreQueue s))) -> Bool -> ByteString -> IO (Maybe (RecipientId, StoreQueue s), (Int, Int, M.Map RecipientId (StoreQueue s)))
restoreMsg (q_, counts@(!stored, !expired, !overQuota)) eof s = case strDecode s of
Right (MLRv3 rId msg) -> runExceptT (addToMsgQueue rId msg) >>= either (exitErr . tshow) pure
Left e
| eof -> warnOrExit (parsingErr e) $> (q_, counts)
| otherwise -> exitErr $ parsingErr e
where
exitErr e = do
when tty $ putStrLn ""
logError $ "error restoring messages: " <> e
liftIO exitFailure
parsingErr :: String -> Text
parsingErr e = "parsing error (" <> T.pack e <> "): " <> safeDecodeUtf8 (B.take 100 s)
addToMsgQueue rId msg = do
qOrErr <- case q_ of
-- to avoid lookup when restoring the next message to the same queue
Just (rId', q') | rId' == rId -> pure $ Right q'
_ -> liftIO $ getQueue ms SRecipient rId
case qOrErr of
Right q -> addToQueue_ q rId msg
Left AUTH -> liftIO $ do
when tty $ putStrLn ""
warnOrExit $ "queue " <> safeDecodeUtf8 (encode $ unEntityId rId) <> " does not exist"
pure (Nothing, counts)
Left e -> throwE e
addToQueue_ q rId msg =
(Just (rId, q),) <$> case msg of
Message {msgTs}
| maybe True (systemSeconds msgTs >=) old_ -> do
writeMsg ms q False msg >>= \case
Just _ -> pure (stored + 1, expired, overQuota)
Nothing -> liftIO $ do
when tty $ putStrLn ""
logError $ decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (messageId msg)
pure counts
| otherwise -> pure (stored, expired + 1, overQuota)
MessageQuota {} ->
-- queue was over quota at some point,
-- it will be set as over quota once fully imported
mergeQuotaMsgs >> writeMsg ms q False msg $> (stored, expired, M.insert rId q overQuota)
where
-- if the first message in queue head is "quota", remove it.
mergeQuotaMsgs =
withPeekMsgQueue ms q "mergeQuotaMsgs" $ maybe (pure ()) $ \case
(mq, MessageQuota {}) -> tryDeleteMsg_ q mq False
_ -> pure ()
warnOrExit e
| skipWarnings = logWarn e'
| otherwise = do
logWarn $ e' <> ", start with --skip-warnings option to ignore this error"
exitFailure
where
e' = "warning restoring messages: " <> e
-- | Log a one-line summary of message statistics, prefixed with name.
printMessageStats :: T.Text -> MessageStats -> IO ()
printMessageStats name MessageStats {storedMsgsCount = stored, expiredMsgsCount = expired, storedQueues = queues} =
  logNote $ T.concat [name, " stored: ", tshow stored, ", expired: ", tshow expired, ", queues: ", tshow queues]
-- | Save pending notifications to the configured notifications file, if any,
-- one NLRv1 record per line.
saveServerNtfs :: M ()
saveServerNtfs = do
  file_ <- asks (storeNtfsFile . config)
  forM_ file_ $ \f -> do
    logNote $ "saving notifications to file " <> T.pack f
    NtfStore ns <- asks ntfStore
    liftIO . withFile f WriteMode $ \h -> do
      queues <- readTVarIO ns
      forM_ (M.assocs queues) $ \(nId, v) -> do
        ntfs <- readTVarIO v
        -- reverse on save, to save notifications in order, will become reversed again when restoring.
        BLD.hPutBuilder h $ foldMap (encodeNtf nId) (reverse ntfs)
    logNote "notifications saved"
  where
    encodeNtf nId ntf = BLD.byteString (strEncode $ NLRv1 nId ntf) <> BLD.char8 '\n'
-- | Restore notifications from the configured notifications file, if it exists.
-- Notifications older than the configured notificationExpiration are counted as
-- expired and not stored. On success the file is renamed to .bak. Any line
-- that fails to parse logs an error and exits the process with failure.
restoreServerNtfs :: M MessageStats
restoreServerNtfs =
asks (storeNtfsFile . config) >>= \case
Just f -> ifM (doesFileExist f) (restoreNtfs f) (pure newMessageStats)
Nothing -> pure newMessageStats
where
restoreNtfs f = do
logNote $ "restoring notifications from file " <> T.pack f
ns <- asks ntfStore
old <- asks (notificationExpiration . config) >>= liftIO . expireBeforeEpoch
liftIO $
-- fold over lines accumulating (lines processed, stored, expired)
LB.readFile f >>= runExceptT . foldM (restoreNtf ns old) (0, 0, 0) . LB.lines >>= \case
Left e -> do
logError . T.pack $ "error restoring notifications: " <> e
liftIO exitFailure
Right (lineCount, storedMsgsCount, expiredMsgsCount) -> do
renameFile f $ f <> ".bak"
let NtfStore ns' = ns
storedQueues <- M.size <$> readTVarIO ns'
logNote $ "notifications restored, " <> tshow lineCount <> " lines processed"
pure MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues}
where
-- parse one NLRv1 line and store the notification unless it is older than `old`
restoreNtf :: NtfStore -> Int64 -> (Int, Int, Int) -> LB.ByteString -> ExceptT String IO (Int, Int, Int)
restoreNtf ns old (!lineCount, !stored, !expired) s' = do
NLRv1 nId ntf <- liftEither . first (ntfErr "parsing") $ strDecode s
liftIO $ addToNtfs nId ntf
where
s = LB.toStrict s'
addToNtfs nId ntf@MsgNtf {ntfTs}
| systemSeconds ntfTs < old = pure (lineCount + 1, stored, expired + 1)
| otherwise = storeNtf ns nId ntf $> (lineCount + 1, stored + 1, expired)
-- error message including up to the first 100 bytes of the offending line
ntfErr :: Show e => String -> e -> String
ntfErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s)
-- | Save a snapshot of server statistics to the configured backup file, if any.
saveServerStats :: M ()
saveServerStats = do
  file_ <- asks (serverStatsBackupFile . config)
  forM_ file_ $ \f -> do
    statsData <- liftIO . getServerStatsData =<< asks serverStats
    liftIO $ do
      logNote $ "saving server stats to file " <> T.pack f
      B.writeFile f $ strEncode statsData
      logNote "server stats saved"
-- | Restore server statistics from the configured backup file, if it exists.
-- Queue, message and notification counts are taken from the stores rather than
-- from the file (message count only when message stats are available). Expired
-- counters are incremented by the counts expired during restoration. Mismatches
-- between saved and actual counts are logged as warnings. The file is renamed
-- to .bak. A decoding error is logged as an error and exits the process, like the
-- other fatal restore/export failures.
restoreServerStats :: Maybe MessageStats -> MessageStats -> M ()
restoreServerStats msgStats_ ntfStats = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
  where
    restoreStats f = whenM (doesFileExist f) $ do
      logNote $ "restoring server stats from file " <> T.pack f
      liftIO (strDecode <$> B.readFile f) >>= \case
        Right d@ServerStatsData {_qCount = statsQCount, _msgCount = statsMsgCount, _ntfCount = statsNtfCount} -> do
          s <- asks serverStats
          AMS _ _ (st :: s) <- asks msgStore
          QueueCounts {queueCount = _qCount} <- liftIO $ queueCounts @(StoreQueue s) $ queueStore st
          let _msgCount = maybe statsMsgCount storedMsgsCount msgStats_
              _ntfCount = storedMsgsCount ntfStats
              _msgExpired' = _msgExpired d + maybe 0 expiredMsgsCount msgStats_
              _msgNtfExpired' = _msgNtfExpired d + expiredMsgsCount ntfStats
          liftIO $ setServerStats s d {_qCount, _msgCount, _ntfCount, _msgExpired = _msgExpired', _msgNtfExpired = _msgNtfExpired'}
          renameFile f $ f <> ".bak"
          logNote "server stats restored"
          compareCounts "Queue" statsQCount _qCount
          compareCounts "Message" statsMsgCount _msgCount
          compareCounts "Notification" statsNtfCount _ntfCount
        Left e -> do
          -- fatal: log at error level, consistent with the other restore/export failures
          logError $ "error restoring server stats: " <> T.pack e
          liftIO exitFailure
    compareCounts name statsCnt storeCnt =
      when (statsCnt /= storeCnt) $ logWarn $ name <> " count differs: stats: " <> tshow statsCnt <> ", store: " <> tshow storeCnt