mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-29 10:10:06 +00:00
2293 lines
124 KiB
Haskell
2293 lines
124 KiB
Haskell
{-# LANGUAGE BangPatterns #-}
|
|
{-# LANGUAGE CPP #-}
|
|
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE DuplicateRecordFields #-}
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE GADTs #-}
|
|
{-# LANGUAGE KindSignatures #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
{-# LANGUAGE NumericUnderscores #-}
|
|
{-# LANGUAGE OverloadedLists #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# LANGUAGE TupleSections #-}
|
|
{-# LANGUAGE TypeApplications #-}
|
|
|
|
-- |
|
|
-- Module : Simplex.Messaging.Server
|
|
-- Copyright : (c) simplex.chat
|
|
-- License : AGPL-3
|
|
--
|
|
-- Maintainer : chat@simplex.chat
|
|
-- Stability : experimental
|
|
-- Portability : non-portable
|
|
--
|
|
-- This module defines SMP protocol server with in-memory persistence
|
|
-- and optional append only log of SMP queue records.
|
|
--
|
|
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
|
module Simplex.Messaging.Server
|
|
( runSMPServer,
|
|
runSMPServerBlocking,
|
|
importMessages,
|
|
exportMessages,
|
|
printMessageStats,
|
|
disconnectTransport,
|
|
verifyCmdAuthorization,
|
|
dummyVerifyCmd,
|
|
randomId,
|
|
AttachHTTP,
|
|
MessageStats (..),
|
|
)
|
|
where
|
|
|
|
import Control.Concurrent.STM (throwSTM)
|
|
import Control.Logger.Simple
|
|
import Control.Monad
|
|
import Control.Monad.Except
|
|
import Control.Monad.IO.Unlift
|
|
import Control.Monad.Reader
|
|
import Control.Monad.Trans.Except
|
|
import Control.Monad.STM (retry)
|
|
import Crypto.Random (ChaChaDRG)
|
|
import Data.Bifunctor (first, second)
|
|
import Data.ByteString.Base64 (encode)
|
|
import qualified Data.ByteString.Builder as BLD
|
|
import Data.ByteString.Char8 (ByteString)
|
|
import qualified Data.ByteString.Char8 as B
|
|
import qualified Data.ByteString.Lazy.Char8 as LB
|
|
import Data.Constraint (Dict (..))
|
|
import Data.Dynamic (toDyn)
|
|
import Data.Either (fromRight, partitionEithers)
|
|
import Data.Functor (($>))
|
|
import Data.IORef
|
|
import Data.Int (Int64)
|
|
import qualified Data.IntMap.Strict as IM
|
|
import qualified Data.IntSet as IS
|
|
import Data.List (foldl', intercalate, mapAccumR)
|
|
import Data.List.NonEmpty (NonEmpty (..), (<|))
|
|
import qualified Data.List.NonEmpty as L
|
|
import Data.Map.Strict (Map)
|
|
import qualified Data.Map.Strict as M
|
|
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing)
|
|
import Data.Semigroup (Sum (..))
|
|
import qualified Data.Set as S
|
|
import Data.Text (Text)
|
|
import qualified Data.Text as T
|
|
import Data.Text.Encoding (decodeLatin1)
|
|
import qualified Data.Text.IO as T
|
|
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
|
|
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
|
import Data.Time.Format.ISO8601 (iso8601Show)
|
|
import Data.Type.Equality
|
|
import Data.Typeable (cast)
|
|
import qualified Data.X509 as X
|
|
import qualified Data.X509.Validation as XV
|
|
import GHC.Conc.Signal
|
|
import GHC.IORef (atomicSwapIORef)
|
|
import GHC.Stats (getRTSStats)
|
|
import GHC.TypeLits (KnownNat)
|
|
import Network.Socket (ServiceName, Socket, socketToHandle)
|
|
import qualified Network.TLS as TLS
|
|
import Numeric.Natural (Natural)
|
|
import Simplex.Messaging.Agent.Lock
|
|
import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPTransmission, smpProxyError, temporaryClientError)
|
|
import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient)
|
|
import qualified Simplex.Messaging.Crypto as C
|
|
import Simplex.Messaging.Encoding
|
|
import Simplex.Messaging.Encoding.String
|
|
import Simplex.Messaging.Protocol
|
|
import Simplex.Messaging.Server.Control
|
|
import Simplex.Messaging.Server.Env.STM as Env
|
|
import Simplex.Messaging.Server.Expiration
|
|
import Simplex.Messaging.Server.MsgStore
|
|
import Simplex.Messaging.Server.MsgStore.Journal (JournalMsgStore, JournalQueue)
|
|
import Simplex.Messaging.Server.MsgStore.STM
|
|
import Simplex.Messaging.Server.MsgStore.Types
|
|
import Simplex.Messaging.Server.NtfStore
|
|
import Simplex.Messaging.Server.Prometheus
|
|
import Simplex.Messaging.Server.QueueStore
|
|
import Simplex.Messaging.Server.QueueStore.QueueInfo
|
|
import Simplex.Messaging.Server.QueueStore.Types
|
|
import Simplex.Messaging.Server.Stats
|
|
import Simplex.Messaging.Server.StoreLog (foldLogLines)
|
|
import Simplex.Messaging.TMap (TMap)
|
|
import qualified Simplex.Messaging.TMap as TM
|
|
import Simplex.Messaging.Transport
|
|
import Simplex.Messaging.Transport.Buffer (trimCR)
|
|
import Simplex.Messaging.Transport.Server
|
|
import Simplex.Messaging.Util
|
|
import Simplex.Messaging.Version
|
|
import System.Environment (lookupEnv)
|
|
import System.Exit (exitFailure, exitSuccess)
|
|
import System.IO (hPrint, hPutStrLn, hSetNewlineMode, universalNewlineMode)
|
|
import System.Mem.Weak (deRefWeak)
|
|
import UnliftIO (timeout)
|
|
import UnliftIO.Concurrent
|
|
import UnliftIO.Directory (doesFileExist, renameFile)
|
|
import UnliftIO.Exception
|
|
import UnliftIO.IO
|
|
import UnliftIO.STM
|
|
#if MIN_VERSION_base(4,18,0)
|
|
import Data.List (sort)
|
|
import GHC.Conc (listThreads, threadStatus)
|
|
import GHC.Conc.Sync (threadLabel)
|
|
#endif
|
|
|
|
-- | Starts an SMP server with the passed configuration.
--
-- A fresh empty 'TMVar' is created for the start-up signal and handed to
-- 'runSMPServerBlocking'; it is not exposed to the caller.
--
-- See a full server here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-server/Main.hs
runSMPServer :: MsgStoreClass s => ServerConfig s -> Maybe AttachHTTP -> IO ()
runSMPServer cfg attachHTTP_ =
  newEmptyTMVarIO >>= \readyFlag -> runSMPServerBlocking readyFlag cfg attachHTTP_
|
|
|
|
-- | Starts an SMP server with the passed configuration, with signalling.
--
-- The passed 'TMVar' is used to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
--
-- The server environment is created from the configuration with @newEnv@ and
-- the server itself ('smpServer') runs as a reader over that environment.
runSMPServerBlocking :: MsgStoreClass s => TMVar Bool -> ServerConfig s -> Maybe AttachHTTP -> IO ()
runSMPServerBlocking started cfg attachHTTP_ = do
  env <- newEnv cfg
  smpServer started cfg attachHTTP_ `runReaderT` env
|
|
|
|
-- | Server monad: a reader over the server environment @Env s@ on top of 'IO'.
type M s a = ReaderT (Env s) IO a
|
|
-- | Callback that receives the socket and the TLS context of a connection.
-- In this file it is invoked from @runServer@ for TLS connections on which SNI was used,
-- on ports where HTTP is enabled, instead of running the SMP client handler.
type AttachHTTP = Socket -> TLS.Context -> IO ()
|
|
|
|
-- Actions used in serverThread to reduce STM transaction scope:
-- they are computed inside the `updateSubscribers` transaction and applied afterwards,
-- one at a time, by `endPreviousSubscriptions`.
|
|
data ClientSubAction
|
|
= CSAEndSub QueueId -- end single direct queue subscription (removes the queue from the client's subscriptions)
|
|
| CSAEndServiceSub -- end service subscription to one queue (decrements client and total service subscription counts by one)
|
|
| CSADecreaseSubs Int64 -- reduce service subscriptions when cancelling. Fixed number is used to correctly handle race conditions when service resubscribes
|
|
|
|
-- A subscription of another client that has to be ended: the client, the action to apply
-- to its subscription state, and the (entity ID, broker message) event that is added
-- to the server's pending events for that client.
type PrevClientSub s = (Client s, ClientSubAction, (EntityId, BrokerMsg))
|
|
|
|
-- | Main server routine.
--
-- Restores stored messages, notifications and statistics, and prints the restored counts.
-- In 'maintenance' mode the server is then stopped and the process exits successfully.
-- Otherwise all server threads (subscription bookkeeping, notification delivery,
-- pending events, proxy agent events, expiration, SIGINT handling, one listener per
-- configured transport, and the optional stats / metrics / control port threads)
-- are passed to @raceAny_@; @stopServer@ always runs when that call exits.
smpServer :: forall s. MsgStoreClass s => TMVar Bool -> ServerConfig s -> Maybe AttachHTTP -> M s ()
smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOptions} attachHTTP_ = do
  srv <- asks server
  agent <- asks proxyAgent
  msgStats_ <- processServerMessages startOptions
  ntfStats <- restoreServerNtfs
  liftIO $ do
    mapM_ (printMessageStats "messages") msgStats_
    printMessageStats "notifications" ntfStats
  restoreServerStats msgStats_ ntfStats
  when (maintenance startOptions) $ do
    liftIO $ putStrLn "Server started in 'maintenance' mode, exiting"
    stopServer srv
    liftIO exitSuccess
  let coreThreads =
        [ serverThread "server subscribers" srv subscribers subscriptions serviceSubsCount (Just cancelSub),
          serverThread "server ntfSubscribers" srv ntfSubscribers ntfSubscriptions ntfServiceSubsCount Nothing,
          deliverNtfsThread srv,
          sendPendingEvtsThread srv,
          receiveFromProxyAgent agent,
          expireNtfsThread cfg,
          sigIntHandlerThread
        ]
      listeners = map runServer transports
      -- each of these is an empty list when the corresponding feature is not configured
      optionalThreads =
        expireMessagesThread_ cfg
          <> serverStatsThread_ cfg
          <> prometheusMetricsThread_ cfg
          <> controlPortThread_ cfg
  raceAny_ (coreThreads <> listeners <> optionalThreads)
    `finally` stopServer srv
|
|
where
|
|
-- Runs the transport server for one configured port.
-- The socket state of the port is registered in the environment's `sockets` list.
-- When HTTP credentials and an HTTP handler are both available and HTTP is enabled for the port,
-- the server is started with combined credentials and ALPNs, and TLS connections that used SNI
-- are handed to the HTTP handler; all other connections are served as SMP clients.
runServer :: (ServiceName, ASrvTransport, AddHTTP) -> M s ()
runServer (tcpPort, ATransport t, addHTTP) = do
  smpCreds@(srvCert, srvKey) <- asks tlsServerCreds
  httpCreds_ <- asks httpServerCreds
  sockState <- liftIO newSocketState
  socketsVar <- asks sockets
  atomically $ modifyTVar' socketsVar ((tcpPort, sockState) :)
  -- the thread fails here if the server key cannot be converted to a signing key
  srvSignKey <- either fail pure $ C.x509ToPrivate' srvKey
  env <- ask
  let serveSMP h = runClient srvCert srvSignKey t h `runReaderT` env
  liftIO $ case (httpCreds_, attachHTTP_) of
    (Just httpCreds, Just attachHTTP) | addHTTP ->
      runTransportServerState_ sockState started tcpPort defaultSupportedParamsHTTPS combinedCreds tCfg {serverALPN = Just combinedALPNs} $ \sock (sniUsed, h) ->
        case cast h of
          Just (TLS {tlsContext} :: TLS 'TServer)
            | sniUsed -> do
                labelMyThread "https client"
                attachHTTP sock tlsContext
          _ -> serveSMP h
      where
        combinedCreds = TLSServerCredential {credential = smpCreds, sniCredential = Just httpCreds}
        combinedALPNs = alpnSupportedSMPHandshakes <> httpALPN
        httpALPN :: [ALPN]
        httpALPN = ["h2", "http/1.1"]
    _ ->
      runTransportServerState sockState started tcpPort defaultSupportedParams smpCreds tCfg serveSMP
|
|
|
|
-- Installs a SIGINT handler and blocks until the signal is received.
-- The handler only fills a TMVar; this thread returns (after logging) once the TMVar is filled.
sigIntHandlerThread :: M s ()
sigIntHandlerThread = do
  interrupted <- newEmptyTMVarIO
  let sigINT = 2 -- CONST_SIGINT value
      -- tryPutTMVar: repeated signals do not block the handler
      onSigInt _ptr = atomically . void $ tryPutTMVar interrupted ()
  void . liftIO $ setHandler sigINT (Just (onSigInt, toDyn ()))
  atomically $ readTMVar interrupted
  logNote "Received SIGINT, stopping server..."
|
|
|
|
-- Marks the server as inactive, then saves the server state and closes the proxy agent
-- while holding the server's saving lock.
stopServer :: Server s -> M s ()
stopServer s = do
  active <- asks serverActive
  atomically $ writeTVar active False
  logNote "Saving server state..."
  withLock' (savingLock s) "final" $ do
    saveServer True
    closeServer
  logNote "Server stopped"
|
|
|
|
-- Saves server messages (passing the `drainMsgs` flag to `saveServerMessages`) and closes
-- the message store, then saves notifications and server statistics.
saveServer :: Bool -> M s ()
saveServer drainMsgs = do
  store <- asks msgStore_
  liftIO $ do
    saveServerMessages drainMsgs store
    closeMsgStore $ fromMsgStore store
  saveServerNtfs
  saveServerStats
|
|
|
|
-- Closes the SMP client agent used by the proxy agent.
closeServer :: M s ()
closeServer = do
  agent <- asks $ smpAgent . proxyAgent
  liftIO $ closeSMPClientAgent agent
|
|
|
|
-- Subscription bookkeeping loop for one kind of subscriptions (it is started separately
-- for `subscribers` and for `ntfSubscribers`).
--
-- Each request read from `subQ` is processed in two steps:
-- 1. one STM transaction (`updateSubscribers`) updates the server-level subscriber maps and
--    returns the list of subscriptions of OTHER clients that have to be ended;
-- 2. `endPreviousSubscriptions` then applies each of them in its own small transaction.
--
-- Parameters:
-- label - thread label;
-- srvSubscribers - selects the subscribers state of the server;
-- clientSubs - selects the per-queue subscriptions of a client;
-- clientServiceSubs - selects the service subscriptions counter of a client;
-- unsub_ - optional action to run on a removed per-queue subscription.
serverThread ::
  forall sub. String ->
  Server s ->
  (Server s -> ServerSubscribers s) ->
  (Client s -> TMap QueueId sub) ->
  (Client s -> TVar Int64) ->
  Maybe (sub -> IO ()) ->
  M s ()
serverThread label srv srvSubscribers clientSubs clientServiceSubs unsub_ = do
  labelMyThread label
  liftIO . forever $ do
    -- The client is read outside of the `updateSubscribers` transaction to avoid
    -- transaction re-evaluation on each new connected client.
    -- In case the client disconnects during the transaction (its `connected` property is read),
    -- the transaction will still be re-evaluated, and the client won't be stored as subscribed.
    req@(_, clntId) <- atomically $ readTQueue subQ
    client_ <- getServerClient clntId srv
    ended <- atomically $ updateSubscribers client_ req
    endPreviousSubscriptions ended
  where
    ServerSubscribers {subQ, queueSubscribers, serviceSubscribers, totalServiceSubs, subClients, pendingEvents} = srvSubscribers srv

    updateSubscribers :: Maybe (Client s) -> (ClientSub, ClientId) -> STM [PrevClientSub s]
    updateSubscribers client_ (clntSub, clntId) = case client_ of
      Nothing -> whenDisconnected
      Just c@Client {connected} -> do
        isConnected <- readTVar connected
        if isConnected then whenConnected c else whenDisconnected
      where
        whenConnected c = case clntSub of
          CSClient qId prevServiceId serviceId_ -> do
            markSubscribed
            prevServiceEnds <- endPrevServiceSub prevServiceId serviceId_ qId
            case serviceId_ of
              Just serviceId -> do
                queueEnds <- endQueueSub qId END
                serviceEnds <- upsertSubscribedClient serviceId c serviceSubscribers >>= cancelServiceSubs serviceId
                pure $ queueEnds ++ serviceEnds ++ prevServiceEnds
              Nothing -> do
                queueEnds <- upsertSubscribedClient qId c queueSubscribers >>= prevSub qId END (CSAEndSub qId)
                pure $ queueEnds ++ prevServiceEnds
          CSDeleted qId serviceId -> do
            removeWhenNoSubs c
            queueEnds <- endQueueSub qId DELD
            serviceEnds <- endServiceSub serviceId qId DELD
            pure $ queueEnds ++ serviceEnds
          CSService serviceId -> do
            markSubscribed
            upsertSubscribedClient serviceId c serviceSubscribers >>= cancelServiceSubs serviceId
        -- do not insert client if it is already disconnected, but send END/DELD to any other client subscribed to this queue or service
        whenDisconnected = case clntSub of
          CSClient qId prevServiceId serviceId -> do
            queueEnds <- endQueueSub qId END
            serviceEnds <- endServiceSub serviceId qId END
            prevServiceEnds <- endPrevServiceSub prevServiceId serviceId qId
            pure $ queueEnds ++ serviceEnds ++ prevServiceEnds
          CSDeleted qId serviceId -> do
            queueEnds <- endQueueSub qId DELD
            serviceEnds <- endServiceSub serviceId qId DELD
            pure $ queueEnds ++ serviceEnds
          CSService serviceId ->
            lookupSubscribedClient serviceId serviceSubscribers >>= cancelServiceSubs serviceId
        -- add ID to server's subscribed clients
        markSubscribed = modifyTVar' subClients $ IS.insert clntId
        -- END for the previous service subscription, only when the service association has changed
        endPrevServiceSub prevServiceId newServiceId qId
          | prevServiceId == newServiceId = pure []
          | otherwise = endServiceSub prevServiceId qId END
        endQueueSub :: QueueId -> BrokerMsg -> STM [PrevClientSub s]
        endQueueSub qId msg =
          lookupDeleteSubscribedClient qId queueSubscribers >>= prevSub qId msg (CSAEndSub qId)
        endServiceSub :: Maybe ServiceId -> QueueId -> BrokerMsg -> STM [PrevClientSub s]
        endServiceSub serviceId_ qId msg = case serviceId_ of
          Nothing -> pure []
          Just serviceId ->
            lookupSubscribedClient serviceId serviceSubscribers >>= prevSub qId msg CSAEndServiceSub
        prevSub :: QueueId -> BrokerMsg -> ClientSubAction -> Maybe (Client s) -> STM [PrevClientSub s]
        prevSub qId msg action =
          checkAnotherClient $ \c -> pure [(c, action, (qId, msg))]
        cancelServiceSubs :: ServiceId -> Maybe (Client s) -> STM [PrevClientSub s]
        cancelServiceSubs serviceId =
          checkAnotherClient $ \c -> do
            -- the counter is reset here; the number read is passed on as a fixed amount
            cancelled <- swapTVar (clientServiceSubs c) 0
            pure [(c, CSADecreaseSubs cancelled, (serviceId, ENDS cancelled))]
        -- only a different client that is still connected gets its subscription ended
        checkAnotherClient :: (Client s -> STM [PrevClientSub s]) -> Maybe (Client s) -> STM [PrevClientSub s]
        checkAnotherClient mkSub other_ = case other_ of
          Just c@Client {clientId, connected}
            | clntId /= clientId -> do
                isConnected <- readTVar connected
                if isConnected then mkSub c else pure []
          _ -> pure []

    endPreviousSubscriptions :: [PrevClientSub s] -> IO ()
    endPreviousSubscriptions = mapM_ endOne
      where
        endOne :: PrevClientSub s -> IO ()
        endOne (c, subAction, evt) = do
          -- the event is added to the pending events of the client
          atomically $ modifyTVar' pendingEvents $ IM.alter (Just . maybe [evt] (evt <|)) (clientId c)
          case subAction of
            CSAEndSub qId -> do
              endedSub_ <- atomically $ endSub c qId
              case (unsub_, endedSub_) of
                (Just unsub, Just endedSub) -> unsub endedSub
                _ -> pure ()
            CSAEndServiceSub -> atomically $ do
              modifyTVar' (clientServiceSubs c) decrement
              modifyTVar' totalServiceSubs decrement
            -- TODO [certs rcv] for SMP subscriptions CSADecreaseSubs should also remove all delivery threads of the passed client
            CSADecreaseSubs cancelled ->
              atomically $ modifyTVar' totalServiceSubs $ \total -> max 0 (total - cancelled)
        -- counters never go below zero
        decrement n = max 0 (n - 1)
        endSub :: Client s -> QueueId -> STM (Maybe sub)
        endSub c qId = do
          sub_ <- TM.lookupDelete qId (clientSubs c)
          removeWhenNoSubs c
          pure sub_

    -- remove client from server's subscribed clients
    removeWhenNoSubs c = do
      hasClientSubs <- not . null <$> readTVar (clientSubs c)
      hasServiceSubs <- (0 /=) <$> readTVar (clientServiceSubs c)
      unless (hasClientSubs || hasServiceSubs) $ modifyTVar' subClients $ IS.delete (clientId c)
|
|
|
|
-- Periodically delivers stored notifications to subscribed clients.
--
-- Every `ntfDeliveryInterval` (from the server config) the contents of the notification store are read,
-- notifications are grouped by the notification service of their queue (`getQueueNtfServices`),
-- and each group is flushed either to the client subscribed as that service or, for queues
-- without a service, to every client in `subClients` that has matching notification subscriptions.
-- Notifiers reported as deleted are removed from the store.
deliverNtfsThread :: Server s -> M s ()
deliverNtfsThread srv@Server {ntfSubscribers = ServerSubscribers {subClients, serviceSubscribers}} = do
  ntfInt <- asks $ ntfDeliveryInterval . config
  ms <- asks msgStore
  ns' <- asks ntfStore
  stats <- asks serverStats
  liftIO $ forever $ do
    threadDelay ntfInt
    runDeliverNtfs ms ns' stats
  where
    runDeliverNtfs :: s -> NtfStore -> ServerStats -> IO ()
    runDeliverNtfs ms (NtfStore ns) stats = do
      ntfs <- M.assocs <$> readTVarIO ns
      unless (null ntfs) $
        getQueueNtfServices @(StoreQueue s) (queueStore ms) ntfs >>= \case
          Left e -> logError $ "NOTIFICATIONS: getQueueNtfServices error " <> tshow e
          Right (sNtfs, deleted) -> do
            forM_ sNtfs $ \(serviceId_, ntfs') -> case serviceId_ of
              Just sId -> getSubscribedClient sId serviceSubscribers >>= mapM_ (deliverServiceNtfs ntfs')
              Nothing -> do -- legacy code that does almost the same as before for non-service subscribers
                cIds <- IS.toList <$> readTVarIO subClients
                forM_ cIds $ \cId -> getServerClient cId srv >>= mapM_ (deliverQueueNtfs ntfs')
            atomically $ modifyTVar' ns (`M.withoutKeys` S.fromList (map fst deleted))
      where
        -- delivers to one client only the notifications of the notifiers it is subscribed to
        deliverQueueNtfs ntfs' c@Client {ntfSubscriptions} =
          whenM (currentClient readTVarIO c) $ do
            subs <- readTVarIO ntfSubscriptions
            unless (M.null subs) $ do
              let ntfs'' = filter (\(nId, _) -> M.member nId subs) ntfs'
              tryAny (atomically $ flushSubscribedNtfs ntfs'' c) >>= updateNtfStats c
        -- delivers all notifications of the service to the client currently subscribed as this service
        deliverServiceNtfs ntfs' cv = readTVarIO cv >>= mapM_ deliver
          where
            deliver c = tryAny (atomically $ withSubscribed $ flushSubscribedNtfs ntfs') >>= updateNtfStats c
            -- the subscribed client is re-read inside the transaction; the transaction is aborted if there is none
            withSubscribed :: (Client s -> STM Int) -> STM Int
            withSubscribed a = readTVar cv >>= maybe (throwSTM $ userError "service unsubscribed") a
        -- Moves the pending notifications of the passed notifiers to the client's sending queue
        -- and returns their number. The whole transaction is aborted with an exception
        -- (so the notifications stay in the store) if the client is not current or its queue is full.
        flushSubscribedNtfs :: [(NotifierId, TVar [MsgNtf])] -> Client s' -> STM Int
        flushSubscribedNtfs ntfs c@Client {sndQ} = do
          ts_ <- foldM addNtfs [] ntfs
          forM_ (L.nonEmpty ts_) $ \ts -> do
            let cancelNtfs s = throwSTM $ userError $ s <> ", " <> show (length ts_) <> " ntfs kept"
            unlessM (currentClient readTVar c) $ cancelNtfs "not current client"
            whenM (isFullTBQueue sndQ) $ cancelNtfs "sending queue full"
            writeTBQueue sndQ ts
          pure $ length ts_
        -- client is connected and is in the set of subscribed clients
        currentClient :: Monad m => (forall a. TVar a -> m a) -> Client s' -> m Bool
        currentClient rd Client {clientId, connected} = (&&) <$> rd connected <*> (IS.member clientId <$> rd subClients)
        addNtfs :: [Transmission BrokerMsg] -> (NotifierId, TVar [MsgNtf]) -> STM [Transmission BrokerMsg]
        addNtfs acc (nId, v) =
          readTVar v >>= \case
            [] -> pure acc
            ntfs -> do
              writeTVar v []
              pure $ foldl' (\acc' ntf -> nmsg nId ntf : acc') acc ntfs -- reverses, to order by time
        nmsg nId MsgNtf {ntfNonce, ntfEncMeta} = (CorrId "", nId, NMSG ntfNonce ntfEncMeta)
        updateNtfStats :: Client s' -> Either SomeException Int -> IO ()
        updateNtfStats Client {clientId} = \case
          Right 0 -> pure ()
          Right len -> do
            atomicModifyIORef'_ (ntfCount stats) (subtract len)
            atomicModifyIORef'_ (msgNtfs stats) (+ len)
            -- up to 80 NMSG in the batch: ceiling division, so that exact multiples of 80
            -- are not counted as one extra batch
            atomicModifyIORef'_ (msgNtfsB stats) (+ ((len + maxBatchNtfs - 1) `div` maxBatchNtfs))
          Left e -> logNote $ "NOTIFICATIONS: cancelled for client #" <> tshow clientId <> ", reason: " <> tshow e
          where
            maxBatchNtfs = 80 :: Int
|
|
|
|
-- Periodically sends the pending events accumulated for clients.
--
-- Every `pendingENDInterval` (from the server config) the pending events of both SMP and
-- notification subscribers are taken (the map is swapped with an empty one) and written to
-- the sending queues of the clients that are still connected.
sendPendingEvtsThread :: Server s -> M s ()
sendPendingEvtsThread srv@Server {subscribers, ntfSubscribers} = do
  endInt <- asks $ pendingENDInterval . config
  stats <- asks serverStats
  liftIO $ forever $ do
    threadDelay endInt
    sendPending subscribers stats
    sendPending ntfSubscribers stats
  where
    sendPending ServerSubscribers {pendingEvents} stats = do
      pending <- atomically $ swapTVar pendingEvents IM.empty
      unless (null pending) $ forM_ (IM.assocs pending) $ \(cId, evts) ->
        getServerClient cId srv >>= mapM_ (enqueueEvts evts)
      where
        enqueueEvts evts c@Client {connected, sndQ} =
          whenM (readTVarIO connected) $ do
            sent <- atomically $ tryWriteTBQueue sndQ ts
            if sent
              then updateEndStats
              else -- if queue is full it can block
                forkClient c ("sendPendingEvtsThread.queueEvts") $
                  atomically (writeTBQueue sndQ ts) >> updateEndStats
          where
            ts = L.map (\(entId, evt) -> (CorrId "", entId, evt)) evts
            -- this accounts for both END and DELD events
            updateEndStats = do
              let len = L.length evts
              when (len > 0) $ do
                atomicModifyIORef'_ (qSubEnd stats) (+ len)
                -- up to 255 ENDs or DELDs in the batch: ceiling division, so that exact
                -- multiples of 255 are not counted as one extra batch
                atomicModifyIORef'_ (qSubEndB stats) (+ ((len + maxBatchEvts - 1) `div` maxBatchEvts))
            maxBatchEvts = 255 :: Int
|
|
|
|
-- Reads the events of the proxy's SMP client agent from its queue forever and logs them.
-- Only the "connected" event is logged as info; all other events are logged as errors.
receiveFromProxyAgent :: ProxyAgent -> M s ()
receiveFromProxyAgent ProxyAgent {smpAgent = SMPClientAgent {agentQ}} =
  forever $ atomically (readTBQueue agentQ) >>= logAgentEvent
  where
    logAgentEvent evt = case evt of
      CAConnected srv _service_ ->
        logInfo $ "SMP server connected " <> showServer' srv
      CADisconnected srv qIds ->
        logError $ "SMP server disconnected " <> showServer' srv <> " / subscriptions: " <> tshow (length qIds)
      -- the errors below should never happen - messaging proxy does not make any subscriptions
      CASubscribed srv serviceId qIds ->
        let asService = maybe "" (const "as service ") serviceId
         in logError $ "SMP server subscribed queues " <> asService <> showServer' srv <> " / subscriptions: " <> tshow (length qIds)
      CASubError srv errs ->
        logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (length errs)
      CAServiceDisconnected {} -> logError "CAServiceDisconnected"
      CAServiceSubscribed {} -> logError "CAServiceSubscribed"
      CAServiceSubError {} -> logError "CAServiceSubError"
      CAServiceUnavailable {} -> logError "CAServiceUnavailable"
    -- the string-encoded host of the server, as text
    showServer' srv = decodeLatin1 $ strEncode $ host srv
|
|
|
|
-- The message expiration thread as a list: empty when message expiration is not configured.
expireMessagesThread_ :: ServerConfig s -> [M s ()]
expireMessagesThread_ ServerConfig {messageExpiration} = case messageExpiration of
  Just msgExp -> [expireMessagesThread msgExp]
  Nothing -> []
|
|
|
|
-- Expires old messages forever, once per `checkInterval` (multiplied by 1000000 for the delay).
-- On each run old deleted queues are compacted first, then messages older than `ttl` are expired;
-- the stored and expired message counters are updated from the result.
expireMessagesThread :: ExpirationConfig -> M s ()
expireMessagesThread ExpirationConfig {checkInterval, ttl} = do
  store <- asks msgStore
  stats <- asks serverStats
  labelMyThread "expireMessagesThread"
  liftIO . forever $ expire store stats (checkInterval * 1000000)
  where
    expire :: s -> ServerStats -> Int64 -> IO ()
    expire ms stats interval = do
      threadDelay' interval
      logNote "Started expiring messages..."
      removed <- compactQueues @(StoreQueue s) $ queueStore ms
      when (removed > 0) $
        logNote $ "Removed " <> tshow removed <> " old deleted queues from the database."
      now <- systemSeconds <$> getSystemTime
      -- an exception while expiring messages is logged and does not stop the thread
      result <- tryAny $ expireOldMessages False ms now ttl
      case result of
        Left e -> logError $ "STORE: withAllMsgQueues, error expiring messages, " <> tshow e
        Right msgStats@MessageStats {storedMsgsCount = stored, expiredMsgsCount = expired} -> do
          atomicWriteIORef (msgCount stats) stored
          atomicModifyIORef'_ (msgExpired stats) (+ expired)
          printMessageStats "STORE: messages" msgStats
|
|
|
|
-- Deletes expired notifications forever, once per `checkInterval` of the notification
-- expiration config, and moves the number deleted from `ntfCount` to `msgNtfExpired`.
expireNtfsThread :: ServerConfig s -> M s ()
expireNtfsThread ServerConfig {notificationExpiration = expCfg} = do
  store <- asks ntfStore
  stats <- asks serverStats
  labelMyThread "expireNtfsThread"
  let interval = checkInterval expCfg * 1000000
  liftIO . forever $ do
    threadDelay' interval
    expired <- deleteExpiredNtfs store =<< expireBeforeEpoch expCfg
    when (expired > 0) $ do
      atomicModifyIORef'_ (msgNtfExpired stats) (+ expired)
      atomicModifyIORef'_ (ntfCount stats) (subtract expired)
|
|
|
|
-- The statistics logging thread as a list: empty when the stats logging interval is not configured.
serverStatsThread_ :: ServerConfig s -> [M s ()]
serverStatsThread_ ServerConfig {logStatsInterval, logStatsStartTime, serverStatsLogFile} =
  case logStatsInterval of
    Just interval -> [logServerStats logStatsStartTime interval serverStatsLogFile]
    Nothing -> []
|
|
|
|
-- Appends one comma-separated line of server statistics to the stats log file per interval.
--
-- Parameters:
-- startAt - time of day of the first record, in seconds from the start of the day (UTC);
--   if this time has already passed today, the first record is written at this time tomorrow;
-- logInterval - interval between records, in seconds;
-- statsFilePath - the file that records are appended to.
--
-- Most counters are reset to 0 when they are read for a record (`atomicSwapIORef`);
-- `qCount`, `msgCount` and `ntfCount` are only read.
-- Entity counts are read from the queue store for every record, so each line
-- reports current values rather than the values at the time the thread started.
logServerStats :: Int64 -> Int64 -> FilePath -> M s ()
logServerStats startAt logInterval statsFilePath = do
  labelMyThread "logServerStats"
  -- utctDayTime is converted from picoseconds to seconds
  initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
  liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
  liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
  ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedAllB, qDeletedNew, qDeletedSecured, qSub, qSubAllB, qSubAuth, qSubDuplicate, qSubProhibited, qSubEnd, qSubEndB, ntfCreated, ntfDeleted, ntfDeletedB, ntfSub, ntfSubB, ntfSubAuth, ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgRecvGet, msgGet, msgGetNoMsg, msgGetAuth, msgGetDuplicate, msgGetProhibited, msgExpired, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, ntfCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv, rcvServices, ntfServices}
    <- asks serverStats
  st <- asks msgStore
  let interval = 1000000 * logInterval
  forever $ do
    -- read before the file is opened and before any counter is reset,
    -- so that a failure here does not lose the counters of this interval
    EntityCounts {queueCount, notifierCount, rcvServiceCount, ntfServiceCount, rcvServiceQueuesCount, ntfServiceQueuesCount} <-
      liftIO $ getEntityCounts @(StoreQueue s) $ queueStore st
    withFile statsFilePath AppendMode $ \h -> liftIO $ do
      hSetBuffering h LineBuffering
      ts <- getCurrentTime
      fromTime' <- atomicSwapIORef fromTime ts
      qCreated' <- atomicSwapIORef qCreated 0
      qSecured' <- atomicSwapIORef qSecured 0
      qDeletedAll' <- atomicSwapIORef qDeletedAll 0
      qDeletedAllB' <- atomicSwapIORef qDeletedAllB 0
      qDeletedNew' <- atomicSwapIORef qDeletedNew 0
      qDeletedSecured' <- atomicSwapIORef qDeletedSecured 0
      qSub' <- atomicSwapIORef qSub 0
      qSubAllB' <- atomicSwapIORef qSubAllB 0
      qSubAuth' <- atomicSwapIORef qSubAuth 0
      qSubDuplicate' <- atomicSwapIORef qSubDuplicate 0
      qSubProhibited' <- atomicSwapIORef qSubProhibited 0
      qSubEnd' <- atomicSwapIORef qSubEnd 0
      qSubEndB' <- atomicSwapIORef qSubEndB 0
      ntfCreated' <- atomicSwapIORef ntfCreated 0
      ntfDeleted' <- atomicSwapIORef ntfDeleted 0
      ntfDeletedB' <- atomicSwapIORef ntfDeletedB 0
      ntfSub' <- atomicSwapIORef ntfSub 0
      ntfSubB' <- atomicSwapIORef ntfSubB 0
      ntfSubAuth' <- atomicSwapIORef ntfSubAuth 0
      ntfSubDuplicate' <- atomicSwapIORef ntfSubDuplicate 0
      msgSent' <- atomicSwapIORef msgSent 0
      msgSentAuth' <- atomicSwapIORef msgSentAuth 0
      msgSentQuota' <- atomicSwapIORef msgSentQuota 0
      msgSentLarge' <- atomicSwapIORef msgSentLarge 0
      msgRecv' <- atomicSwapIORef msgRecv 0
      msgRecvGet' <- atomicSwapIORef msgRecvGet 0
      msgGet' <- atomicSwapIORef msgGet 0
      msgGetNoMsg' <- atomicSwapIORef msgGetNoMsg 0
      msgGetAuth' <- atomicSwapIORef msgGetAuth 0
      msgGetDuplicate' <- atomicSwapIORef msgGetDuplicate 0
      msgGetProhibited' <- atomicSwapIORef msgGetProhibited 0
      msgExpired' <- atomicSwapIORef msgExpired 0
      ps <- periodStatCounts activeQueues ts
      msgSentNtf' <- atomicSwapIORef msgSentNtf 0
      msgRecvNtf' <- atomicSwapIORef msgRecvNtf 0
      psNtf <- periodStatCounts activeQueuesNtf ts
      msgNtfs' <- atomicSwapIORef (msgNtfs ss) 0
      msgNtfsB' <- atomicSwapIORef (msgNtfsB ss) 0
      msgNtfNoSub' <- atomicSwapIORef (msgNtfNoSub ss) 0
      msgNtfLost' <- atomicSwapIORef (msgNtfLost ss) 0
      msgNtfExpired' <- atomicSwapIORef (msgNtfExpired ss) 0
      pRelays' <- getResetProxyStatsData pRelays
      pRelaysOwn' <- getResetProxyStatsData pRelaysOwn
      pMsgFwds' <- getResetProxyStatsData pMsgFwds
      pMsgFwdsOwn' <- getResetProxyStatsData pMsgFwdsOwn
      pMsgFwdsRecv' <- atomicSwapIORef pMsgFwdsRecv 0
      rcvServices' <- getServiceStatsData rcvServices
      ntfServices' <- getServiceStatsData ntfServices
      qCount' <- readIORef qCount
      msgCount' <- readIORef msgCount
      ntfCount' <- readIORef ntfCount
      -- the order of the columns below is the file format - do not reorder
      T.hPutStrLn h $
        T.intercalate
          ","
          ( [ T.pack $ iso8601Show $ utctDay fromTime',
              tshow qCreated',
              tshow qSecured',
              tshow qDeletedAll',
              tshow msgSent',
              tshow msgRecv',
              dayCount ps,
              weekCount ps,
              monthCount ps,
              tshow msgSentNtf',
              tshow msgRecvNtf',
              dayCount psNtf,
              weekCount psNtf,
              monthCount psNtf,
              tshow qCount',
              tshow msgCount',
              tshow msgExpired',
              tshow qDeletedNew',
              tshow qDeletedSecured'
            ]
              <> showProxyStats pRelays'
              <> showProxyStats pRelaysOwn'
              <> showProxyStats pMsgFwds'
              <> showProxyStats pMsgFwdsOwn'
              <> [ tshow pMsgFwdsRecv',
                   tshow qSub',
                   tshow qSubAuth',
                   tshow qSubDuplicate',
                   tshow qSubProhibited',
                   tshow msgSentAuth',
                   tshow msgSentQuota',
                   tshow msgSentLarge',
                   tshow msgNtfs',
                   tshow msgNtfNoSub',
                   tshow msgNtfLost',
                   "0", -- qSubNoMsg' is removed for performance.
                   -- Use qSubAllB for the approximate number of all subscriptions.
                   -- Average observed batch size is 25-30 subscriptions.
                   tshow msgRecvGet',
                   tshow msgGet',
                   tshow msgGetNoMsg',
                   tshow msgGetAuth',
                   tshow msgGetDuplicate',
                   tshow msgGetProhibited',
                   "0", -- dayCount psSub; psSub is removed to reduce memory usage
                   "0", -- weekCount psSub
                   "0", -- monthCount psSub
                   tshow queueCount,
                   tshow ntfCreated',
                   tshow ntfDeleted',
                   tshow ntfSub',
                   tshow ntfSubAuth',
                   tshow ntfSubDuplicate',
                   tshow notifierCount,
                   tshow qDeletedAllB',
                   tshow qSubAllB',
                   tshow qSubEnd',
                   tshow qSubEndB',
                   tshow ntfDeletedB',
                   tshow ntfSubB',
                   tshow msgNtfsB',
                   tshow msgNtfExpired',
                   tshow ntfCount',
                   tshow rcvServiceCount,
                   tshow ntfServiceCount,
                   tshow rcvServiceQueuesCount,
                   tshow ntfServiceQueuesCount
                 ]
              <> showServiceStats rcvServices'
              <> showServiceStats ntfServices'
          )
    liftIO $ threadDelay' interval
  where
    showProxyStats ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} =
      map tshow [_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther]
    showServiceStats ServiceStatsData {_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd} =
      map tshow [_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd]
|
|
|
|
-- The Prometheus metrics thread as a list: empty when the metrics interval is not configured.
prometheusMetricsThread_ :: ServerConfig s -> [M s ()]
prometheusMetricsThread_ ServerConfig {prometheusInterval, prometheusMetricsFile} =
  case prometheusInterval of
    Just interval -> [savePrometheusMetrics interval prometheusMetricsFile]
    Nothing -> []
|
|
|
|
-- Writes the server metrics in Prometheus format to `metricsFile` forever,
-- once per `saveInterval` seconds; the file is overwritten on each run.
-- The RTS options reported with the metrics are read once from the environment variable
-- named by `rtsOptionsEnv` (a "set <name>" hint is used when the variable is absent).
savePrometheusMetrics :: Int -> FilePath -> M s ()
savePrometheusMetrics saveInterval metricsFile = do
  labelMyThread "savePrometheusMetrics"
  liftIO . putStrLn $ "Prometheus metrics saved every " <> show saveInterval <> " seconds to " <> metricsFile
  store <- asks msgStore
  stats <- asks serverStats
  env <- ask
  rtsOptsVar_ <- liftIO $ lookupEnv (T.unpack rtsOptionsEnv)
  let rtsOpts = maybe ("set " <> rtsOptionsEnv) T.pack rtsOptsVar_
      interval = 1000000 * saveInterval
  liftIO . forever $ do
    threadDelay interval
    ts <- getCurrentTime
    serverMetrics <- getServerMetrics store stats rtsOpts
    realTimeMetrics <- getRealTimeMetrics env
    T.writeFile metricsFile $ prometheusMetrics serverMetrics realTimeMetrics ts
|
|
|
|
-- Collects the metrics derived from server statistics: a snapshot of the statistics,
-- period counts of active queues (all and with notifications), the current entity counts
-- from the queue store, and the passed RTS options text.
getServerMetrics :: s -> ServerStats -> Text -> IO ServerMetrics
getServerMetrics st ss rtsOptions = do
  statsData <- getServerStatsData ss
  entityCounts <- getEntityCounts @(StoreQueue s) $ queueStore st
  pure
    ServerMetrics
      { statsData,
        activeQueueCounts = periodStatDataCounts $ _activeQueues statsData,
        activeNtfCounts = periodStatDataCounts $ _activeQueuesNtf statsData,
        entityCounts,
        rtsOptions
      }
|
|
|
|
-- Collects the metrics of the current server state: socket statistics, the number of threads
-- (reported as 0 when compiled with base older than 4.18), the number of clients,
-- counts and time buckets of subscriptions with a delivered message,
-- SMP and notification subscriber counts, and loaded queue counts.
getRealTimeMetrics :: Env s -> IO RealTimeMetrics
getRealTimeMetrics Env {sockets, msgStore_ = ms, server = srv@Server {subscribers, ntfSubscribers}} = do
  socketStats <- readTVarIO sockets >>= mapM (traverse getSocketStats)
#if MIN_VERSION_base(4,18,0)
  threadsCount <- length <$> listThreads
#else
  let threadsCount = 0
#endif
  clientsCount <- IM.size <$> getServerClients srv
  now <- getSystemSeconds
  (deliveredSubs, deliveredTimes) <- getDeliveredMetrics now
  smpSubs <- getSubscribersMetrics subscribers
  ntfSubs <- getSubscribersMetrics ntfSubscribers
  loadedCounts <- loadedQueueCounts $ fromMsgStore ms
  pure RealTimeMetrics {socketStats, threadsCount, clientsCount, deliveredSubs, deliveredTimes, smpSubs, ntfSubs, loadedCounts}
  where
    -- sizes of the server-level subscriber collections
    getSubscribersMetrics ServerSubscribers {queueSubscribers, serviceSubscribers, subClients} = do
      subsCount <- M.size <$> getSubscribedClients queueSubscribers
      subClientsCount <- IS.size <$> readTVarIO subClients
      subServicesCount <- M.size <$> getSubscribedClients serviceSubscribers
      pure RTSubscriberMetrics {subsCount, subClientsCount, subServicesCount}
    -- folds over all clients, counting subscriptions that have a delivered message
    getDeliveredMetrics ts' = do
      clients <- getServerClients srv
      foldM countClnt (RTSubscriberMetrics 0 0 0, emptyTimeBuckets) clients
      where
        -- a client is counted only when it has at least one subscription with a delivered message
        countClnt acc@(metrics, times) Client {subscriptions} = do
          subs <- readTVarIO subscriptions
          (cnt, times') <- foldM countSubs (0, times) subs
          pure $
            if cnt > 0
              then (metrics {subsCount = subsCount metrics + cnt, subClientsCount = subClientsCount metrics + 1}, times')
              else acc
        countSubs acc@(!cnt, times) Sub {delivered} =
          maybe acc (\(_, ts) -> (cnt + 1, updateTimeBuckets ts ts' times))
            <$> atomically (tryReadTMVar delivered)
|
|
|
|
-- Runs the SMP handshake for an accepted transport connection, bounded by `smpHandshakeTimeout`.
-- On success the resulting transport handle is passed to `runClientTransport`;
-- a timeout or a handshake error ends the client without any further action.
runClient :: Transport c => X.CertificateChain -> C.APrivateSignKey -> TProxy c 'TServer -> c 'TServer -> M s ()
runClient srvCert srvSignKey tp h = do
  ms <- asks msgStore
  g <- asks random
  idSize <- asks $ queueIdBytes . config
  kh <- asks serverIdentity
  -- the session key pair is generated with the same random generator that is passed to `getClientService`
  ks <- atomically $ C.generateKeyPair g
  ServerConfig {smpServerVRange, smpHandshakeTimeout} <- asks config
  labelMyThread $ "smp handshake for " <> transportName tp
  handshake <-
    liftIO . timeout smpHandshakeTimeout . runExceptT $
      smpServerHandshake srvCert srvSignKey h ks kh smpServerVRange (getClientService ms g idSize)
  case handshake of
    Just (Right th) -> runClientTransport th
    _ -> pure ()
|
|
|
|
-- Builds a candidate service record (fresh random ID of `idSize` bytes, current system date)
-- for the presented role, certificate chain and fingerprint, and passes it to `getCreateService`,
-- returning the service ID it yields. Any store error is reported as `TEHandshake BAD_SERVICE`.
getClientService :: s -> TVar ChaChaDRG -> Int -> SMPServiceRole -> X.CertificateChain -> XV.Fingerprint -> ExceptT TransportError IO ServiceId
getClientService store drg idSize serviceRole serviceCert serviceCertHash = do
  serviceId <- EntityId <$> atomically (C.randomBytes idSize drg)
  serviceCreatedAt <- liftIO getSystemDate
  let newService = ServiceRec {serviceId, serviceRole, serviceCert, serviceCertHash, serviceCreatedAt}
  withExceptT (\_ -> TEHandshake BAD_SERVICE) . ExceptT $
    getCreateService @(StoreQueue s) (queueStore store) newService
|
|
|
|
-- Optional control port action: a single `runCPServer` action when `controlPort` is configured, none otherwise.
controlPortThread_ :: ServerConfig s -> [M s ()]
controlPortThread_ ServerConfig {controlPort} =
  case controlPort of
    Just port -> [runCPServer port]
    Nothing -> []
|
|
|
|
-- Control port server: listens on the local TCP `port` and serves every connection with `runCPClient`,
-- a line-based command loop.
-- Each connection starts with the role CPRNone. `CPAuth` sets the role to admin or user when the supplied
-- value equals the configured credential, and resets it to CPRNone otherwise.
-- Most commands are gated by `withUserRole` / `withAdminRole`, which reply "AUTH" when the role is insufficient.
runCPServer :: ServiceName -> M s ()
runCPServer port = do
  srv <- asks server
  cpStarted <- newEmptyTMVarIO
  u <- askUnliftIO
  liftIO $ do
    labelMyThread "control port server"
    runLocalTCPServer cpStarted port (runCPClient u srv)
  where
    runCPClient :: UnliftIO (ReaderT (Env s) IO) -> Server s -> Socket -> IO ()
    runCPClient u srv sock = do
      labelMyThread "control port client"
      h <- socketToHandle sock ReadWriteMode
      hSetBuffering h LineBuffering
      hSetNewlineMode h universalNewlineMode
      hPutStrLn h "SMP server control port\n'help' for supported commands"
      role <- newTVarIO CPRNone
      cpLoop h role
      where
        -- reads one command per line until `quit`; lines that fail to decode are reported and skipped
        cpLoop h role = do
          line <- trimCR <$> B.hGetLine h
          case strDecode line of
            Left err -> hPutStrLn h ("error: " <> err) >> cpLoop h role
            Right CPQuit -> hClose h
            Right cmd -> do
              logCmd line cmd
              processCP h role cmd
              cpLoop h role
        -- auth, help, quit and skip commands are not logged
        logCmd line cmd = when (shouldLog cmd) $ logWarn $ "ControlPort: " <> tshow line
        shouldLog = \case
          CPAuth _ -> False
          CPHelp -> False
          CPQuit -> False
          CPSkip -> False
          _ -> True
        processCP h role cmd =
          case cmd of
            CPAuth auth -> atomically $ writeTVar role $! roleFor cfg
              where
                -- admin credential is checked first, so a value matching both grants admin
                roleFor ServerConfig {controlPortUserAuth = userAuth, controlPortAdminAuth = adminAuth}
                  | Just auth == adminAuth = CPRAdmin
                  | Just auth == userAuth = CPRUser
                  | otherwise = CPRNone
            CPSuspend -> withAdminRole $ hPutStrLn h "suspend not implemented"
            CPResume -> withAdminRole $ hPutStrLn h "resume not implemented"
            CPClients -> withAdminRole $ do
              clients <- getServerClients srv
              hPutStrLn h "clientId,sessionId,connected,createdAt,rcvActiveAt,sndActiveAt,age,subscriptions"
              forM_ (IM.toList clients) $ \(cid, Client {clientTHParams = THandleParams {sessionId}, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions}) -> do
                connected' <- bshow <$> readTVarIO connected
                rcvActiveAt' <- strEncode <$> readTVarIO rcvActiveAt
                sndActiveAt' <- strEncode <$> readTVarIO sndActiveAt
                now <- getSystemTime
                let age = systemSeconds now - systemSeconds createdAt
                subscriptions' <- bshow . M.size <$> readTVarIO subscriptions
                hPutStrLn h $ B.unpack (B.intercalate "," [bshow cid, encode sessionId, connected', strEncode createdAt, rcvActiveAt', sndActiveAt', bshow age, subscriptions'])
            CPStats -> withUserRole $ do
              ss <- unliftIO u $ asks serverStats
              st <- unliftIO u $ asks msgStore
              EntityCounts {queueCount, notifierCount, rcvServiceCount, ntfServiceCount, rcvServiceQueuesCount, ntfServiceQueuesCount} <-
                getEntityCounts @(StoreQueue s) (queueStore st)
              let getStat :: (ServerStats -> IORef a) -> IO a
                  getStat var = readIORef (var ss)
                  putStat :: Show a => String -> (ServerStats -> IORef a) -> IO ()
                  putStat label var = do
                    v <- getStat var
                    hPutStrLn h $ label <> ": " <> show v
                  putProxyStat :: String -> (ServerStats -> ProxyStats) -> IO ()
                  putProxyStat label var = do
                    ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- getProxyStatsData $ var ss
                    hPutStrLn h $ label <> ": requests=" <> show _pRequests <> ", successes=" <> show _pSuccesses <> ", errorsConnect=" <> show _pErrorsConnect <> ", errorsCompat=" <> show _pErrorsCompat <> ", errorsOther=" <> show _pErrorsOther
              putStat "fromTime" fromTime
              putStat "qCreated" qCreated
              putStat "qSecured" qSecured
              putStat "qDeletedAll" qDeletedAll
              putStat "qDeletedAllB" qDeletedAllB
              putStat "qDeletedNew" qDeletedNew
              putStat "qDeletedSecured" qDeletedSecured
              activeDay <- getStat (day . activeQueues)
              hPutStrLn h $ "daily active queues: " <> show (IS.size activeDay)
              -- removed to reduce memory usage
              -- getStat (day . subscribedQueues) >>= \v -> hPutStrLn h $ "daily subscribed queues: " <> show (S.size v)
              putStat "qSub" qSub
              putStat "qSubAllB" qSubAllB
              putStat "qSubEnd" qSubEnd
              putStat "qSubEndB" qSubEndB
              subs <- (,,) <$> getStat qSubAuth <*> getStat qSubDuplicate <*> getStat qSubProhibited
              hPutStrLn h $ "other SUB events (auth, duplicate, prohibited): " <> show subs
              putStat "msgSent" msgSent
              putStat "msgRecv" msgRecv
              putStat "msgRecvGet" msgRecvGet
              putStat "msgGet" msgGet
              putStat "msgGetNoMsg" msgGetNoMsg
              gets <- (,,) <$> getStat msgGetAuth <*> getStat msgGetDuplicate <*> getStat msgGetProhibited
              hPutStrLn h $ "other GET events (auth, duplicate, prohibited): " <> show gets
              putStat "msgSentNtf" msgSentNtf
              putStat "msgRecvNtf" msgRecvNtf
              putStat "msgNtfs" msgNtfs
              putStat "msgNtfsB" msgNtfsB
              putStat "msgNtfExpired" msgNtfExpired
              putStat "qCount" qCount
              hPutStrLn h $ "qCount 2: " <> show queueCount
              hPutStrLn h $ "notifiers: " <> show notifierCount
              putStat "msgCount" msgCount
              putStat "ntfCount" ntfCount
              -- the notification count computed from the store is only shown to admin
              currentRole <- readTVarIO role
              case currentRole of
                CPRAdmin -> do
                  NtfStore ns <- unliftIO u $ asks ntfStore
                  ntfQueues <- readTVarIO ns
                  ntfCount2 <- foldM (\(!n) q -> (n +) . length <$> readTVarIO q) 0 ntfQueues
                  hPutStrLn h $ "ntfCount 2: " <> show ntfCount2
                _ -> pure ()
              putProxyStat "pRelays" pRelays
              putProxyStat "pRelaysOwn" pRelaysOwn
              putProxyStat "pMsgFwds" pMsgFwds
              putProxyStat "pMsgFwdsOwn" pMsgFwdsOwn
              putStat "pMsgFwdsRecv" pMsgFwdsRecv
              hPutStrLn h $ "rcvServiceCount: " <> show rcvServiceCount
              hPutStrLn h $ "ntfServiceCount: " <> show ntfServiceCount
              hPutStrLn h $ "rcvServiceQueuesCount: " <> show rcvServiceQueuesCount
              hPutStrLn h $ "ntfServiceQueuesCount: " <> show ntfServiceQueuesCount
            -- NOTE(review): unlike the other informational commands this one is answered without any role check - confirm it is intended
            CPStatsRTS -> getRTSStats >>= hPrint h
            CPThreads -> withAdminRole $ do
#if MIN_VERSION_base(4,18,0)
              threads <- liftIO listThreads
              hPutStrLn h $ "Threads: " <> show (length threads)
              forM_ (sort threads) $ \tid -> do
                label <- threadLabel tid
                status <- threadStatus tid
                hPutStrLn h $ show tid <> " (" <> show status <> ") " <> fromMaybe "" label
#else
              hPutStrLn h "Not available on GHC 8.10"
#endif
            CPSockets -> withUserRole $ unliftIO u (asks sockets) >>= readTVarIO >>= mapM_ putSockets
              where
                putSockets (tcpPort, socketsState) = do
                  stats <- getSocketStats socketsState
                  hPutStrLn h $ "Sockets for port " <> tcpPort <> ":"
                  hPutStrLn h $ "accepted: " <> show (socketsAccepted stats)
                  hPutStrLn h $ "closed: " <> show (socketsClosed stats)
                  hPutStrLn h $ "active: " <> show (socketsActive stats)
                  hPutStrLn h $ "leaked: " <> show (socketsLeaked stats)
            CPSocketThreads -> withAdminRole $ do
#if MIN_VERSION_base(4,18,0)
                unliftIO u (asks sockets) >>= readTVarIO >>= mapM_ putSocketThreads
              where
                -- one CSV line per active socket thread; threads whose weak reference is gone are reported as "gone"
                putSocketThreads (tcpPort, (_, _, active')) = do
                  active <- readTVarIO active'
                  forM_ (IM.toList active) $ \(sid, tid') -> do
                    tid_ <- deRefWeak tid'
                    case tid_ of
                      Nothing -> hPutStrLn h $ intercalate "," [tcpPort, show sid, "", "gone", ""]
                      Just tid -> do
                        label <- threadLabel tid
                        status <- threadStatus tid
                        hPutStrLn h $ intercalate "," [tcpPort, show sid, show tid, show status, fromMaybe "" label]
#else
                hPutStrLn h "Not available on GHC 8.10"
#endif
            CPServerInfo ->
              readTVarIO role >>= \case
                CPRNone -> unauthorized
                r -> do
#if MIN_VERSION_base(4,18,0)
                  threads <- liftIO listThreads
                  hPutStrLn h $ "Threads: " <> show (length threads)
#else
                  hPutStrLn h "Threads: not available on GHC 8.10"
#endif
                  let Server {subscribers, ntfSubscribers} = srv
                  activeClients <- getServerClients srv
                  hPutStrLn h $ "Clients: " <> show (IM.size activeClients)
                  -- per-client queue lengths and subscription breakdowns are only shown to admin
                  when (r == CPRAdmin) $ do
                    clQs <- clientTBQueueLengths' activeClients
                    hPutStrLn h $ "Client queues (rcvQ, sndQ, msgQ): " <> show clQs
                    (smpSubCnt, smpSubCntByGroup, smpClCnt, smpClQs) <- countClientSubs subscriptions (Just countSMPSubs) activeClients
                    hPutStrLn h $ "SMP subscriptions (via clients): " <> show smpSubCnt
                    hPutStrLn h $ "SMP subscriptions (by group: NoSub, SubPending, SubThread, ProhibitSub): " <> show smpSubCntByGroup
                    hPutStrLn h $ "SMP subscribed clients (via clients): " <> show smpClCnt
                    hPutStrLn h $ "SMP subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show smpClQs
                    (ntfSubCnt, _, ntfClCnt, ntfClQs) <- countClientSubs ntfSubscriptions Nothing activeClients
                    hPutStrLn h $ "Ntf subscriptions (via clients): " <> show ntfSubCnt
                    hPutStrLn h $ "Ntf subscribed clients (via clients): " <> show ntfClCnt
                    hPutStrLn h $ "Ntf subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show ntfClQs
                  putSubscribersInfo "SMP" subscribers False
                  putSubscribersInfo "Ntf" ntfSubscribers True
              where
                putSubscribersInfo :: String -> ServerSubscribers s -> Bool -> IO ()
                putSubscribersInfo protoName ServerSubscribers {queueSubscribers, subClients} showIds = do
                  activeSubs <- getSubscribedClients queueSubscribers
                  hPutStrLn h $ protoName <> " subscriptions: " <> show (M.size activeSubs)
                  -- TODO [certs] service subscriptions
                  clnts <- countSubClients activeSubs
                  hPutStrLn h $ protoName <> " subscribed clients: " <> show (IS.size clnts) <> (if showIds then " " <> show (IS.toList clnts) else "")
                  clnts' <- readTVarIO subClients
                  hPutStrLn h $ protoName <> " subscribed clients count 2: " <> show (IS.size clnts') <> (if showIds then " " <> show clnts' else "")
                  where
                    -- set of client IDs currently referenced from the queue-subscriber map
                    countSubClients :: Map QueueId (TVar (Maybe (Client s))) -> IO IS.IntSet
                    countSubClients = foldM (\ !ids c -> maybe ids ((`IS.insert` ids) . clientId) <$> readTVarIO c) IS.empty
                -- result: (subscription count, optional per-state counts, clients with subscriptions, summed queue lengths of those clients)
                countClientSubs :: (Client s -> TMap QueueId a) -> Maybe (Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap (Client s) -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
                countClientSubs subSel countSubs_ = foldM addSubs (0, (0, 0, 0, 0), 0, (0, 0, 0))
                  where
                    addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Client s -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural))
                    addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) cl = do
                      subs <- readTVarIO $ subSel cl
                      cnts' <- case countSubs_ of
                        Nothing -> pure cnts
                        Just countSubs -> do
                          (c1', c2', c3', c4') <- countSubs subs
                          pure (c1 + c1', c2 + c2', c3 + c3', c4 + c4')
                      let cnt = M.size subs
                          hasSubs = cnt /= 0
                          clCnt' = if hasSubs then clCnt + 1 else clCnt
                      qs' <- if hasSubs then addQueueLengths qs cl else pure qs
                      pure (subCnt + cnt, cnts', clCnt', qs')
                clientTBQueueLengths' :: Foldable t => t (Client s) -> IO (Natural, Natural, Natural)
                clientTBQueueLengths' = foldM addQueueLengths (0, 0, 0)
                addQueueLengths (!rl, !sl, !ml) cl = do
                  (rl', sl', ml') <- queueLengths cl
                  pure (rl + rl', sl + sl', ml + ml')
                -- each queue length is read in its own transaction
                queueLengths Client {rcvQ, sndQ, msgQ} = do
                  rl <- atomically $ lengthTBQueue rcvQ
                  sl <- atomically $ lengthTBQueue sndQ
                  ml <- atomically $ lengthTBQueue msgQ
                  pure (rl, sl, ml)
                -- counts subscriptions by state: (NoSub, SubPending, SubThread, ProhibitSub)
                countSMPSubs :: Map QueueId Sub -> IO (Int, Int, Int, Int)
                countSMPSubs = foldM countSubs (0, 0, 0, 0)
                  where
                    countSubs (c1, c2, c3, c4) Sub {subThread} = case subThread of
                      ProhibitSub -> pure (c1, c2, c3, c4 + 1)
                      ServerSub t -> do
                        st <- readTVarIO t
                        pure $ case st of
                          NoSub -> (c1 + 1, c2, c3, c4)
                          SubPending -> (c1, c2 + 1, c3, c4)
                          SubThread _ -> (c1, c2, c3 + 1, c4)
            CPDelete sId -> withUserRole $ unliftIO u $ do
              st <- asks msgStore
              r <- liftIO . runExceptT $ ExceptT (getQueue st SSender sId) >>= ExceptT . deleteQueueSize st
              case r of
                Left e -> liftIO $ hPutStrLn h $ "error: " <> show e
                Right (qr, numDeleted) -> do
                  updateDeletedStats qr
                  liftIO $ hPutStrLn h $ "ok, " <> show numDeleted <> " messages deleted"
            CPStatus sId -> withUserRole $ unliftIO u $ do
              st <- asks msgStore
              q_ <- liftIO $ getQueueRec st SSender sId
              liftIO . hPutStrLn h $ case q_ of
                Left e -> "error: " <> show e
                Right (_, QueueRec {queueMode, status, updatedAt}) ->
                  "status: " <> show status <> ", updatedAt: " <> show updatedAt <> ", queueMode: " <> show queueMode
            CPBlock sId info -> withUserRole $ unliftIO u $ do
              st <- asks msgStore
              r <- liftIO . runExceptT $ ExceptT (getQueue st SSender sId) >>= \q -> ExceptT (blockQueue (queueStore st) q info)
              case r of
                Left e -> liftIO $ hPutStrLn h $ "error: " <> show e
                Right () -> do
                  incStat . qBlocked =<< asks serverStats
                  liftIO $ hPutStrLn h "ok"
            CPUnblock sId -> withUserRole $ unliftIO u $ do
              st <- asks msgStore
              r <- liftIO . runExceptT $ ExceptT (getQueue st SSender sId) >>= ExceptT . unblockQueue (queueStore st)
              liftIO . hPutStrLn h $ case r of
                Left e -> "error: " <> show e
                Right () -> "ok"
            CPSave -> withAdminRole $ withLock' (savingLock srv) "control" $ do
              hPutStrLn h "saving server state..."
              unliftIO u $ saveServer False
              hPutStrLn h "server state saved!"
            CPHelp -> hPutStrLn h "commands: stats, stats-rts, clients, sockets, socket-threads, threads, server-info, delete, save, help, quit"
            CPQuit -> pure ()
            CPSkip -> pure ()
          where
            -- runs the action for user and admin roles
            withUserRole action =
              readTVarIO role >>= \case
                CPRAdmin -> action
                CPRUser -> action
                _ -> unauthorized
            -- runs the action for admin role only
            withAdminRole action =
              readTVarIO role >>= \case
                CPRAdmin -> action
                _ -> unauthorized
            -- common reply to a command that the current role may not run
            unauthorized = do
              logError "Unauthorized control port command"
              hPutStrLn h "AUTH"
|
|
|
|
-- Creates the client record for a transport session that completed the handshake and runs its threads
-- (send, sendMsg, command processing, receive and, when expiration is configured, the inactivity check)
-- until any of them ends; `clientDisconnected` always runs afterwards.
runClientTransport :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> M s ()
runClientTransport h@THandle {params = thParams@THandleParams {sessionId}} = do
  queueSize <- asks $ tbqSize . config
  createdAt <- liftIO getSystemTime
  clientSeqVar <- asks clientSeq
  -- take the current sequence value as this client's ID and increment the sequence
  clientId <- atomically $ stateTVar clientSeqVar (\next -> (next, next + 1))
  c <- liftIO $ newClient clientId queueSize thParams createdAt
  runClientThreads c `finally` clientDisconnected c
  where
    runClientThreads :: Client s -> M s ()
    runClientThreads c = do
      srv <- asks server
      ms <- asks msgStore
      -- threads are only started when the client was registered with the server
      whenM (liftIO $ insertServerClient c srv) $ do
        expCfg <- asks $ inactiveClientExpiration . config
        th <- newMVar h -- put TH under a fair lock to interleave messages and command responses
        labelMyThread . B.unpack $ "client $" <> encode sessionId
        let threads = [liftIO $ send th c, liftIO $ sendMsg th c, client srv ms c, receive h ms c]
        raceAny_ $ threads <> disconnectThread_ c srv expCfg
    disconnectThread_ :: Client s -> Server s -> Maybe ExpirationConfig -> [M s ()]
    disconnectThread_ c@Client {rcvActiveAt, sndActiveAt} srv = \case
      Just expCfg -> [liftIO $ disconnectTransport h rcvActiveAt sndActiveAt expCfg (noSubscriptions c srv)]
      Nothing -> []
    -- True when the client ID is in neither the SMP nor the notification subscribed-clients set
    noSubscriptions Client {clientId} srv =
      not <$> anyM [hasSubs (subscribers srv), hasSubs (ntfSubscribers srv)]
      where
        hasSubs ServerSubscribers {subClients} = IS.member clientId <$> readTVarIO subClients
|
|
|
|
-- Cleans up after a client session ended: marks the client disconnected, empties its subscription maps,
-- kills subscription threads, removes the client from server-wide subscriber state (only while the
-- server is active), and kills the threads registered in `endThreads`.
clientDisconnected :: forall s. Client s -> M s ()
clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, serviceSubsCount, ntfServiceSubsCount, connected, clientTHParams = THandleParams {sessionId, thAuth}, endThreads} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disc"
  -- these can be in separate transactions,
  -- because the client already disconnected and they won't change
  atomically $ writeTVar connected False
  subs <- atomically $ swapTVar subscriptions M.empty
  ntfSubs <- atomically $ swapTVar ntfSubscriptions M.empty
  liftIO $ forM_ subs cancelSub
  active <- asks serverActive >>= readTVarIO
  when active $ do
    srv@Server {subscribers, ntfSubscribers} <- asks server
    liftIO $ do
      deleteServerClient clientId srv
      updateSubscribers subs subscribers
      updateSubscribers ntfSubs ntfSubscribers
      -- for service sessions, also remove the service subscriber and its subscription count
      case peerClientService =<< thAuth of
        Just THClientService {serviceId, serviceRole}
          | serviceRole == SRMessaging -> updateServiceSubs serviceId serviceSubsCount subscribers
          | serviceRole == SRNotifier -> updateServiceSubs serviceId ntfServiceSubsCount ntfSubscribers
        _ -> pure ()
  tIds <- atomically $ swapTVar endThreads IM.empty
  liftIO . forM_ tIds $ \weakTid -> deRefWeak weakTid >>= mapM_ killThread
  where
    -- removes this client from every queue it was subscribed to and from the subscribed-clients set
    updateSubscribers :: Map QueueId a -> ServerSubscribers s -> IO ()
    updateSubscribers subs ServerSubscribers {queueSubscribers, subClients} = do
      forM_ (M.keys subs) $ \qId -> deleteSubcribedClient qId c queueSubscribers
      atomically $ modifyTVar' subClients (IS.delete clientId)
    -- removes this client as the service subscriber and subtracts its count from the server total
    updateServiceSubs :: ServiceId -> TVar Int64 -> ServerSubscribers s -> IO ()
    updateServiceSubs serviceId subsCount ServerSubscribers {totalServiceSubs, serviceSubscribers} = do
      deleteSubcribedClient serviceId c serviceSubscribers
      n <- readTVarIO subsCount
      atomically $ modifyTVar' totalServiceSubs (subtract n)
|
|
|
|
-- Kills the thread of a subscription, if it has one that is still referenced; does nothing otherwise.
cancelSub :: Sub -> IO ()
cancelSub Sub {subThread} = case subThread of
  ProhibitSub -> pure ()
  ServerSub stateVar -> do
    subState <- readTVarIO stateVar
    case subState of
      SubThread weakTid -> deRefWeak weakTid >>= mapM_ killThread
      _ -> pure ()
|
|
|
|
-- Outcome of verifying one received transmission in `receive`:
-- Left is the error response to send back to the client, Right is the transmission accepted for processing.
type VerifiedTransmissionOrError s = Either (Transmission BrokerMsg) (VerifiedTransmission s)
|
|
|
|
-- Receiving thread of a client session: reads batches of transmissions from the transport, verifies them
-- and writes verified commands to `rcvQ` and error responses to `sndQ`.
-- Throws "server stopped" once the server is no longer active.
receive :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> s -> Client s -> M s ()
receive h@THandle {params = THandleParams {thAuth, sessionId}} ms Client {rcvQ, sndQ, rcvActiveAt} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
  sa <- asks serverActive
  stats <- asks serverStats
  liftIO . forever $ do
    batch <- tGetServer h
    active <- readTVarIO sa
    unless active . throwIO $ userError "server stopped"
    now <- getSystemTime
    atomically $ writeTVar rcvActiveAt $! now
    let (parseErrs, signed) = partitionEithers $ L.toList batch
        errs = map (second ERR) parseErrs
    case signed of
      [] -> write sndQ errs
      (_, _, (_, _, Cmd p cmd)) : rest -> do
        let service = peerClientService =<< thAuth
        results <- case batchParty p of
          -- several commands of the same batchable party: load all queue records with one store call
          Just Dict | not (null rest) && all (sameParty p) signed -> do
            updateBatchStats stats cmd -- even if nothing is verified
            let queueId (_, _, (_, qId, _)) = qId
            qs <- getQueueRecs ms p $ map queueId signed
            zipWithM (\t -> verified stats t . verifyLoadedQueue service thAuth t) signed qs
          _ -> mapM (\t -> verified stats t =<< verifyTransmission ms service thAuth t) signed
        let (errs', cmds) = partitionEithers results
        write rcvQ cmds
        write sndQ $ errs ++ errs'
  where
    sameParty :: SParty p -> SignedTransmission Cmd -> Bool
    sameParty p (_, _, (_, _, Cmd p' _)) = isJust $ testEquality p p'
    -- batch counters are chosen by the first command of the batch
    updateBatchStats :: ServerStats -> Command p -> IO ()
    updateBatchStats stats = \case
      SUB -> incStat $ qSubAllB stats
      DEL -> incStat $ qDeletedAllB stats
      NDEL -> incStat $ ntfDeletedB stats
      NSUB -> incStat $ ntfSubB stats
      _ -> pure ()
    -- converts a verification result to a command or an error response, counting AUTH failures per command
    verified :: ServerStats -> SignedTransmission Cmd -> VerificationResult s -> IO (VerifiedTransmissionOrError s)
    verified stats (_, _, t@(corrId, entId, Cmd _ command)) = \case
      VRVerified q -> pure $ Right (q, t)
      VRFailed e -> do
        when (e == AUTH) incAuthStat
        pure $ Left (corrId, entId, ERR e)
      where
        incAuthStat = case command of
          SEND {} -> incStat $ msgSentAuth stats
          SUB -> incStat $ qSubAuth stats
          NSUB -> incStat $ ntfSubAuth stats
          GET -> incStat $ msgGetAuth stats
          _ -> pure ()
    -- nothing is written for an empty list
    write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
|
|
|
|
-- Sending thread of a client session: takes batches of responses from `sndQ` and sends them to the transport.
-- In batches of more than two transmissions MSG responses are replaced with OK and the messages
-- themselves are moved to `msgQ`, to be sent by `sendMsg`.
send :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO ()
send th c@Client {sndQ, msgQ, clientTHParams = THandleParams {sessionId}} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send"
  forever $ do
    ts <- atomically $ readTBQueue sndQ
    sendTransmissions ts
  where
    sendTransmissions :: NonEmpty (Transmission BrokerMsg) -> IO ()
    sendTransmissions ts =
      if L.length ts <= 2
        then tSend th c ts
        else do
          let (msgs_, ts') = mapAccumR splitMessages [] ts
          -- If the request had batched subscriptions and L.length ts > 2
          -- this will reply OK to all SUBs in the first batched transmission,
          -- to reduce client timeouts.
          tSend th c ts'
          -- After that all messages will be sent in separate transmissions,
          -- without any client response timeouts, and allowing them to interleave
          -- with other requests responses.
          forM_ (L.nonEmpty msgs_) $ atomically . writeTBQueue msgQ
    splitMessages :: [Transmission BrokerMsg] -> Transmission BrokerMsg -> ([Transmission BrokerMsg], Transmission BrokerMsg)
    splitMessages msgs t@(corrId, entId, cmd) = case cmd of
      -- replace MSG response with OK, accumulating MSG (with empty correlation ID) in a separate list.
      MSG {} -> ((CorrId "", entId, cmd) : msgs, (corrId, entId, OK))
      _ -> (msgs, t)
|
|
|
|
-- Message-sending thread of a client session: takes batches from `msgQ`
-- and sends every transmission of a batch separately.
sendMsg :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO ()
sendMsg th c@Client {msgQ, clientTHParams = THandleParams {sessionId}} = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " sendMsg"
  forever $ do
    msgs <- atomically $ readTBQueue msgQ
    forM_ msgs $ \t -> tSend th c [t]
|
|
|
|
-- Encodes and writes a batch of transmissions while holding the transport handle lock,
-- then records the current time as the client's last send activity.
tSend :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> NonEmpty (Transmission BrokerMsg) -> IO ()
tSend th Client {sndActiveAt} ts = do
  withMVar th $ \h@THandle {params} ->
    -- the result of tPut is ignored
    void $ tPut h ((\t -> Right (Nothing, encodeTransmission params t)) <$> ts)
  now <- getSystemTime
  atomically $ writeTVar sndActiveAt $! now
|
|
|
|
-- Inactivity check for a client connection. After every `checkInterval` (scaled by 1000000 for the delay)
-- it evaluates `noSubscriptions`; only when it returns True the later of the last receive/send activity
-- times is compared with the expiration threshold, and the connection is closed if it is older.
-- The loop ends only when the connection is closed here.
disconnectTransport :: Transport c => THandle v c 'TServer -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do
  labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport"
  loop
  where
    loop = do
      threadDelay' $ checkInterval expCfg * 1000000
      idle <- noSubscriptions
      if idle then checkExpired else loop
    checkExpired = do
      threshold <- expireBeforeEpoch expCfg
      lastActive <- max <$> readTVarIO rcvActiveAt <*> readTVarIO sndActiveAt
      if systemSeconds lastActive < threshold
        then closeConnection connection
        else loop
|
|
|
|
-- Result of command authorization: success carries the loaded queue with its record
-- (Nothing when verification succeeded without a queue), failure carries the error to return to the client.
data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFailed ErrorType
|
|
|
|
-- This function verifies queue command authorization, with the objective of taking constant time in the three AUTH error scenarios:
-- - the queue and party key exist, and the provided authorization has the type matching the queue key, but it is made with a different key.
-- - the queue and party key exist, but the provided authorization has an incorrect type.
-- - the queue or party key do not exist.
-- In all cases, the time of the verification should depend only on the provided authorization type,
-- a dummy key is used to run verification in the last two cases, and failure is returned irrespective of the result.
--
-- The queue record is loaded only for the parties for which `queueParty` returns evidence.
verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> IO (VerificationResult s)
verifyTransmission ms service thAuth t@(_, _, (_, queueId, Cmd p _)) =
  case queueParty p of
    Nothing -> pure $ verifyQueueTransmission service thAuth t Nothing
    Just Dict -> do
      loaded <- getQueueRec ms p queueId
      pure $ verifyLoadedQueue service thAuth t loaded
|
|
|
|
-- Continues verification once the queue lookup has completed. When the lookup failed with AUTH,
-- a dummy verification is still evaluated before failing, as described in the comment to
-- `verifyTransmission`; other lookup errors are returned as is.
verifyLoadedQueue :: Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Either ErrorType (StoreQueue s, QueueRec) -> VerificationResult s
verifyLoadedQueue service thAuth t@(tAuth, authorized, (corrId, _, _)) loaded =
  case loaded of
    Right q -> verifyQueueTransmission service thAuth t (Just q)
    Left AUTH -> dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
    Left e -> VRFailed e
|
|
|
|
-- Decides whether a signed transmission is authorized, given the client service of the session (if any),
-- the session authorization parameters and the already loaded queue (for commands that need one).
-- Checks are applied in this order: service role, service signature, then the per-command rule in `vc`.
verifyQueueTransmission :: forall s. Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Maybe (StoreQueue s, QueueRec) -> VerificationResult s
verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, _, command@(Cmd p cmd))) q_
  | not checkRole = VRFailed $ CMD PROHIBITED
  | not verifyServiceSig = VRFailed SERVICE
  | otherwise = vc p cmd
  where
    vc :: SParty p -> Command p -> VerificationResult s -- this pattern match works with ghc8.10.7, flat case sees it as non-exhaustive.
    vc SCreator (NEW NewQueueReq {rcvAuthKey = k}) = verifiedWith k
    vc SRecipient SUB = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
    vc SRecipient _ = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
    vc SRecipientService SUBS = verifyServiceCmd
    vc SSender (SKEY k) = verifySecure k
    -- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command
    vc SSender SEND {} = verifyQueue $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified q_ else VRFailed AUTH
    vc SIdleClient PING = VRVerified Nothing
    vc SSenderLink (LKEY k) = verifySecure k
    vc SSenderLink LGET = verifyQueue $ \q -> if isContactQueue (snd q) then VRVerified q_ else VRFailed AUTH
    vc SNotifier NSUB = verifyQueue $ \q -> maybe dummyVerify (\n -> verifiedWith $ notifierKey n) (notifier $ snd q)
    vc SNotifierService NSUBS = verifyServiceCmd
    vc SProxiedClient _ = VRVerified Nothing
    vc SProxyService (RFWD _) = VRVerified Nothing
    -- in a service session, a command of a party that has a client role must match the role of the service
    checkRole = case (service, partyClientRole p) of
      (Just THClientService {serviceRole}, Just role) -> serviceRole == role
      _ -> True
    verify = verifyCmdAuthorization thAuth tAuth authorized' corrId
    -- service commands are authorized with an Ed25519 signature made by the service key, without a second signature
    verifyServiceCmd :: VerificationResult s
    verifyServiceCmd = case (service, tAuth) of
      (Just THClientService {serviceKey = k}, Just (TASignature (C.ASignature C.SEd25519 s), Nothing))
        | C.verify' k s authorized -> VRVerified Nothing
      _ -> VRFailed SERVICE
    -- this function verifies service signature for commands that use it in service sessions
    verifyServiceSig
      | useServiceAuth command = case (service, serviceSig) of
          (Just THClientService {serviceKey = k}, Just s) -> C.verify' k s authorized
          (Nothing, Nothing) -> True
          _ -> False
      | otherwise = isNothing serviceSig
    serviceSig = snd =<< tAuth
    -- when the service signature is present, the entity authorization covers the service certificate fingerprint too
    authorized' = case (service, serviceSig) of
      (Just THClientService {serviceCertHash = XV.Fingerprint fp}, Just _) -> fp <> authorized
      _ -> authorized
    dummyVerify :: VerificationResult s
    dummyVerify = dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
    -- That a specific command requires queue signature verification is determined by `queueParty`,
    -- it should be coordinated with the case in this function (`verifyQueueTransmission`)
    verifyQueue :: ((StoreQueue s, QueueRec) -> VerificationResult s) -> VerificationResult s
    verifyQueue v = case q_ of
      Just q -> v q
      Nothing -> VRFailed INTERNAL
    verifySecure :: SndPublicAuthKey -> VerificationResult s
    verifySecure k = verifyQueue $ \q -> if allowedKey k (snd q) then verifiedWith k else dummyVerify
    verifiedWith :: C.APublicAuthKey -> VerificationResult s
    verifiedWith k
      | verify k = VRVerified q_
      | otherwise = VRFailed AUTH
    verifiedWithKeys :: NonEmpty C.APublicAuthKey -> VerificationResult s
    verifiedWithKeys ks
      | any verify ks = VRVerified q_
      | otherwise = VRFailed AUTH
    -- a key is allowed for a messaging queue that is not secured yet or is secured with the same key
    allowedKey k QueueRec {queueMode = Just QMMessaging, senderKey} = maybe True (k ==) senderKey
    allowedKey _ _ = False
|
|
|
|
-- Whether the queue is a contact address queue. Queues without a recorded mode count as contact queues
-- only while they have no sender key.
isContactQueue :: QueueRec -> Bool
isContactQueue QueueRec {queueMode, senderKey} = case queueMode of
  Nothing -> isNothing senderKey -- for backward compatibility with pre-SKEY contact addresses
  Just QMContact -> True
  Just QMMessaging -> False
|
|
|
|
-- Whether the queue has a sender key and is not a contact queue.
isSecuredMsgQueue :: QueueRec -> Bool
isSecuredMsgQueue QueueRec {queueMode, senderKey}
  | Just QMContact <- queueMode = False
  | otherwise = isJust senderKey
|
|
|
|
-- Verifies the authorization of a transmission with the given public key.
-- A missing authorization never verifies. When the authorization type does not match the key type,
-- verification is still evaluated with a dummy key and the result is discarded.
-- Random correlation ID is used as a nonce in case crypto_box authenticator is used to authorize transmission
verifyCmdAuthorization :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> C.APublicAuthKey -> Bool
verifyCmdAuthorization thAuth tAuth authorized corrId key = case tAuth of
  Nothing -> False
  Just auth -> verify key auth
  where
    verify :: C.APublicAuthKey -> TAuthorizations -> Bool
    verify (C.APublicAuthKey a k) = \case
      (TASignature (C.ASignature sigAlg sig), _) -> case testEquality a sigAlg of
        Just Refl -> C.verify' k sig authorized
        _ -> C.verify' (dummySignKey sigAlg) sig authorized `seq` False
      (TAAuthenticator authenticator, _) -> case a of
        C.SX25519 -> verifyCmdAuth thAuth k authenticator authorized corrId
        _ -> verifyCmdAuth thAuth dummyKeyX25519 authenticator authorized corrId `seq` False
|
|
|
|
-- Verifies a crypto_box authenticator using the server session private key and the correlation ID as nonce.
-- Always fails when the session has no server authorization parameters.
verifyCmdAuth :: Maybe (THandleAuth 'TServer) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> CorrId -> Bool
verifyCmdAuth thAuth k authenticator authorized (CorrId corrId) = case thAuth of
  Nothing -> False
  Just THAuthServer {serverPrivKey = pk} ->
    let nonce = C.cbNonce corrId
     in C.cbVerify k pk nonce authenticator authorized
|
|
|
|
-- Runs the verification of the provided authorization with a dummy key of the matching type.
-- Callers in this module only force the result and then return AUTH failure.
dummyVerifyCmd :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> Bool
dummyVerifyCmd thAuth tAuth authorized corrId = case tAuth of
  Nothing -> False
  Just (TASignature (C.ASignature a sig), _) -> C.verify' (dummySignKey a) sig authorized
  Just (TAAuthenticator authenticator, _) -> verifyCmdAuth thAuth dummyKeyX25519 authenticator authorized corrId
|
|
|
|
-- These dummy keys are used with `dummyVerify` function to mitigate timing attacks
-- by having the same time of the response whether a queue exists or not, for all valid key/signature sizes
dummySignKey :: C.SignatureAlgorithm a => C.SAlgorithm a -> C.PublicKey a
dummySignKey alg = case alg of
  C.SEd25519 -> dummyKeyEd25519
  C.SEd448 -> dummyKeyEd448

-- dummy Ed25519 public key
dummyKeyEd25519 :: C.PublicKey 'C.Ed25519
dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs="

-- dummy Ed448 public key
dummyKeyEd448 :: C.PublicKey 'C.Ed448
dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/XopbOSaq9qyLhrgJWKOLyNrQPNVvpMA"

-- dummy X25519 public key, used when verifying crypto_box authenticators
dummyKeyX25519 :: C.PublicKey 'C.X25519
dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg="
|
|
|
|
-- Forks a labelled thread owned by the client. A weak reference to it is stored in `endThreads`
-- under a fresh sequence number, and the thread removes this entry when the action ends.
forkClient :: MonadUnliftIO m => Client s -> String -> m () -> m ()
forkClient Client {endThreads, endThreadSeq} label action = do
  tId <- atomically $ stateTVar endThreadSeq (\next -> (next, next + 1))
  t <- forkIO $ do
    labelMyThread label
    action `finally` atomically (modifyTVar' endThreads (IM.delete tId))
  weakTid <- mkWeakThreadId t
  atomically $ modifyTVar' endThreads (IM.insert tId weakTid)
|
|
|
|
client :: forall s. MsgStoreClass s => Server s -> s -> Client s -> M s ()
|
|
client
|
|
-- TODO [certs rcv] rcv subscriptions
|
|
Server {subscribers, ntfSubscribers}
|
|
ms
|
|
clnt@Client {clientId, subscriptions, ntfSubscriptions, ntfServiceSubscribed, serviceSubsCount = _todo', ntfServiceSubsCount, rcvQ, sndQ, clientTHParams = thParams'@THandleParams {sessionId}, procThreads} = do
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands"
|
|
let THandleParams {thVersion} = thParams'
|
|
service = peerClientService =<< thAuth thParams'
|
|
forever $
|
|
atomically (readTBQueue rcvQ)
|
|
>>= mapM (processCommand service thVersion)
|
|
>>= mapM_ reply . L.nonEmpty . catMaybes . L.toList
|
|
where
|
|
-- Puts a batch of responses on this client's send queue (sndQ).
reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m ()
reply responses = atomically $ writeTBQueue sndQ responses
|
|
-- Processes commands sent by a client to this server acting as a sending proxy:
--   PRXY - connect to (or reuse a connection with) the destination relay and return its session key,
--   PFWD - forward an encrypted transmission to the relay session identified by the entity ID.
-- Results that are produced asynchronously (via forkProxiedCmd) are sent from the forked
-- thread, in which case this function returns Nothing.
processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M s (Maybe (Transmission BrokerMsg))
processProxiedCmd (corrId, EntityId sessId, command) = (corrId,EntityId sessId,) <$$> case command of
  PRXY srv auth -> ifM allowProxy getRelay (pure $ Just $ ERR $ PROXY BASIC_AUTH)
    where
      -- proxying must be enabled, and basic auth (when configured) must match
      allowProxy = do
        ServerConfig {allowSMPProxy, newQueueBasicAuth} <- asks config
        pure $ case newQueueBasicAuth of
          Nothing -> allowSMPProxy
          Just requiredAuth -> allowSMPProxy && Just requiredAuth == auth
      getRelay = do
        ProxyAgent {smpAgent = agent} <- asks proxyAgent
        connected_ <- liftIO $ getConnectedSMPServerClient agent srv
        case connected_ of
          Just connected -> Just <$> proxyServerResponse agent connected
          Nothing ->
            -- connecting to the relay can take time, so it is done in a separate thread
            forkProxiedCmd $ do
              connected <- liftIO $ runExceptT (getSMPServerClient'' agent srv) `catch` (pure . Left . PCEIOError)
              proxyServerResponse agent connected
      proxyServerResponse :: SMPClientAgent 'Sender -> Either SMPClientError (OwnServer, SMPClient) -> M s BrokerMsg
      proxyServerResponse agent smp_ = do
        ServerStats {pRelays, pRelaysOwn} <- asks serverStats
        let incRelay = mkIncProxyStats pRelays pRelaysOwn
        case smp_ of
          Left e -> do
            let own = isOwnServer agent srv
            incRelay own pRequests
            incRelay own $ if temporaryClientError e then pErrorsConnect else pErrorsOther
            logWarn $ "Error connecting: " <> decodeLatin1 (strEncode $ host srv) <> " " <> tshow e
            pure . ERR $ smpProxyError e
          Right (own, smp) -> do
            incRelay own pRequests
            let resp = proxyResp smp
            incRelay own $ case resp of
              PKEY {} -> pSuccesses
              _ -> pErrorsCompat
            pure resp
        where
          proxyResp smp =
            let THandleParams {sessionId = srvSessId, thVersion, thServerVRange, thAuth} = thParams smp
             in case compatibleVRange thServerVRange proxiedSMPRelayVRange of
                  -- Cap the destination relay version range to prevent client version fingerprinting.
                  -- See comment for proxiedSMPRelayVersion.
                  Just (Compatible vr) | thVersion >= sendingProxySMPVersion -> case thAuth of
                    Just THAuthClient {peerServerCertKey} -> PKEY srvSessId vr peerServerCertKey
                    Nothing -> ERR $ transportErr TENoServerAuth
                  _ -> ERR $ transportErr TEVersion
  PFWD fwdV pubKey encBlock -> do
    ProxyAgent {smpAgent = agent} <- asks proxyAgent
    ServerStats {pMsgFwds, pMsgFwdsOwn} <- asks serverStats
    let incFwd = mkIncProxyStats pMsgFwds pMsgFwdsOwn
    relay_ <- liftIO $ lookupSMPServerClient agent sessId
    case relay_ of
      Nothing -> do
        incFwd False pRequests
        incFwd False pErrorsConnect
        pure $ Just (ERR $ PROXY NO_SESSION)
      Just (own, smp) -> do
        incFwd own pRequests
        let THandleParams {thVersion = relayVersion} = thParams smp
        if relayVersion >= sendingProxySMPVersion
          then forkProxiedCmd $ do
            forwarded <- liftIO $ runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)
            case forwarded of
              Right r -> PRES r <$ incFwd own pSuccesses
              Left e -> do
                -- a protocol error returned by the relay still counts as a successful forward
                incFwd own $ case e of
                  PCEProtocolError {} -> pSuccesses
                  _ -> pErrorsOther
                pure $ ERR (smpProxyError e)
          else Just (ERR $ transportErr TEVersion) <$ incFwd own pErrorsCompat
|
|
where
|
|
-- Runs a proxied command in a separate client thread and sends its result to the client
-- from that thread; always returns Nothing, as there is no synchronous response.
--
-- The number of concurrently running proxied commands of one client is limited by
-- serverClientConcurrency: a slot is taken (blocking while none is free) before the
-- thread is forked, and it is released by the forked thread when the command completes
-- or fails. Releasing the slot around forkClient itself would free it as soon as the
-- thread is started, so the limit would never apply to the commands still running.
forkProxiedCmd :: M s BrokerMsg -> M s (Maybe BrokerMsg)
forkProxiedCmd cmdAction = do
  wait
  forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $
    -- commands MUST be processed under a reasonable timeout or the client would halt
    processAndReply `finally` signal
  pure Nothing
  where
    processAndReply = cmdAction >>= \t -> reply [(corrId, EntityId sessId, t)]
    -- takes a processing slot, retrying the transaction until one is available
    wait = do
      ServerConfig {serverClientConcurrency} <- asks config
      atomically $ do
        used <- readTVar procThreads
        when (used >= serverClientConcurrency) retry
        writeTVar procThreads $! used + 1
    -- returns the processing slot
    signal = atomically $ modifyTVar' procThreads (\t -> t - 1)
|
|
-- Wraps a transport error into the proxy error reported for the destination server.
transportErr :: TransportError -> ErrorType
transportErr e = PROXY (BROKER (TRANSPORT e))
|
|
-- Increments the selected counter in the general proxy stats and,
-- when the destination is an own server, also in the own-servers proxy stats.
mkIncProxyStats :: MonadIO m => ProxyStats -> ProxyStats -> OwnServer -> (ProxyStats -> IORef Int) -> m ()
mkIncProxyStats ps psOwn own sel =
  incStat (sel ps) >> when own (incStat (sel psOwn))
|
|
processCommand :: Maybe THPeerClientService -> VersionSMP -> VerifiedTransmission s -> M s (Maybe (Transmission BrokerMsg))
|
|
processCommand service clntVersion (q_, (corrId, entId, cmd)) = case cmd of
|
|
Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command)
|
|
Cmd SSender command -> Just <$> case command of
|
|
SKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k
|
|
SEND flags msgBody -> withQueue_ False $ sendMessage flags msgBody
|
|
Cmd SIdleClient PING -> pure $ Just (corrId, NoEntity, PONG)
|
|
Cmd SProxyService (RFWD encBlock) -> Just . (corrId,NoEntity,) <$> processForwardedCommand encBlock
|
|
Cmd SSenderLink command -> Just <$> case command of
|
|
LKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k $>> getQueueLink_ q qr
|
|
LGET -> withQueue $ \q qr -> checkContact qr $ getQueueLink_ q qr
|
|
Cmd SNotifier NSUB -> Just . (corrId,entId,) <$> case q_ of
|
|
Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds
|
|
_ -> pure $ ERR INTERNAL
|
|
Cmd SNotifierService NSUBS -> Just . (corrId,entId,) <$> case service of
|
|
Just s -> subscribeServiceNotifications s
|
|
Nothing -> pure $ ERR INTERNAL
|
|
Cmd SCreator (NEW nqr@NewQueueReq {auth_}) ->
|
|
Just <$> ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
|
|
where
|
|
allowNew = do
|
|
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
|
|
pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth
|
|
Cmd SRecipient command ->
|
|
Just <$> case command of
|
|
SUB -> withQueue subscribeQueue
|
|
GET -> withQueue getMessage
|
|
ACK msgId -> withQueue $ acknowledgeMsg msgId
|
|
KEY sKey -> withQueue $ \q _ -> either err (corrId,entId,) <$> secureQueue_ q sKey
|
|
RKEY rKeys -> withQueue $ \q qr -> checkMode QMContact qr $ OK <$$ liftIO (updateKeys (queueStore ms) q rKeys)
|
|
LSET lnkId d ->
|
|
withQueue $ \q qr -> case queueData qr of
|
|
_ | isSecuredMsgQueue qr -> pure $ err AUTH
|
|
Just (lnkId', _) | lnkId' /= lnkId -> pure $ err AUTH -- can't change link ID
|
|
_ -> liftIO $ either err (const ok) <$> addQueueLinkData (queueStore ms) q lnkId d
|
|
LDEL ->
|
|
withQueue $ \q qr -> case queueData qr of
|
|
Just _ -> liftIO $ either err (const ok) <$> deleteQueueLinkData (queueStore ms) q
|
|
Nothing -> pure ok
|
|
NKEY nKey dhKey -> withQueue $ \q _ -> addQueueNotifier_ q nKey dhKey
|
|
NDEL -> withQueue $ \q _ -> deleteQueueNotifier_ q
|
|
OFF -> maybe (pure $ err INTERNAL) suspendQueue_ q_
|
|
DEL -> maybe (pure $ err INTERNAL) delQueueAndMsgs q_
|
|
QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr
|
|
Cmd SRecipientService SUBS -> pure $ Just $ err (CMD PROHIBITED) -- "TODO [certs rcv]"
|
|
where
|
|
-- Creates a new queue (NEW command) and responds with IDS, or with an error:
--   CMD PROHIBITED - a service client requested creation without subscription,
--                    or a client-supplied sender ID does not match the hash of the correlation ID,
--   AUTH           - ID collision when the sender ID was supplied by the client,
--   INTERNAL       - random IDs collided in all attempts.
-- With SMSubscribe the client is subscribed to the created queue.
createQueue :: NewQueueReq -> M s (Transmission BrokerMsg)
createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData}
  | isJust service && subMode == SMOnlyCreate = pure (corrId, entId, ERR $ CMD PROHIBITED)
  | otherwise = time "NEW" $ do
      drg <- asks random
      idSize <- asks $ queueIdBytes . config
      updatedAt <- Just <$> liftIO getSystemDate
      (rcvPublicDhKey, privDhKey) <- atomically $ C.generateKeyPair drg
      -- TODO [notifications]
      -- ntfKeys_ <- forM ntfCreds $ \(NewNtfCreds notifierKey dhKey) -> do
      --   (ntfPubDhKey, ntfPrivDhKey) <- atomically $ C.generateKeyPair g
      --   pure (notifierKey, C.dh' dhKey ntfPrivDhKey, ntfPubDhKey)
      let newId = EntityId <$> atomically (C.randomBytes idSize drg)
          -- TODO [notifications] the remaining 24 bytes are reserved for notifier ID
          expectedSndId = B.take 24 $ C.sha3_384 (bs corrId)
          queueMode = queueReqMode <$> queueReqData
          rcvServiceId = (\THClientService {serviceId} -> serviceId) <$> service
          mkQueueRec sndId queueData =
            QueueRec
              { senderId = sndId,
                recipientKeys = [rcvAuthKey],
                rcvDhSecret = C.dh' rcvDhKey privDhKey,
                senderKey = Nothing,
                queueMode,
                queueData,
                -- TODO [notifications]
                notifier = Nothing, -- fst <$> ntf,
                status = EntityActive,
                updatedAt,
                rcvServiceId
              }
          attempt 0 = pure $ ERR INTERNAL
          attempt attemptsLeft = do
            (sndId, clntIds, queueData) <- case queueReqData of
              Just (QRMessaging (Just (sId, d))) -> do
                linkId <- newId
                pure (sId, True, Just (linkId, d))
              Just (QRContact (Just (linkId, (sId, d)))) -> pure (sId, True, Just (linkId, d))
              _ -> do
                sId <- newId
                pure (sId, False, Nothing)
            -- The condition that client-provided sender ID must match hash of correlation ID
            -- prevents "ID oracle" attack, when creating queue with supplied ID can be used to check
            -- if queue with this ID still exists.
            if clntIds && unEntityId sndId /= expectedSndId
              then pure $ ERR $ CMD PROHIBITED
              else do
                rcvId <- newId
                -- TODO [notifications]
                -- ntf <- forM ntfKeys_ $ \(notifierKey, rcvNtfDhSecret, rcvPubDhKey) -> do
                --   notifierId <- randId
                --   pure (NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}, ServerNtfCreds notifierId rcvPubDhKey)
                let qr = mkQueueRec sndId queueData
                added <- liftIO $ addQueue ms rcvId qr
                case added of
                  Left DUPLICATE_ -- TODO [short links] possibly, we somehow need to understand which IDs caused collision to retry if it's not client-supplied?
                    | clntIds -> pure $ ERR AUTH -- no retry on collision if sender ID is client-supplied
                    | otherwise -> attempt (attemptsLeft - 1)
                  Left e -> pure $ ERR e
                  Right _q -> do
                    stats <- asks serverStats
                    incStat $ qCreated stats
                    incStat $ qCount stats
                    -- TODO [notifications]
                    -- when (isJust ntf) $ incStat $ ntfCreated stats
                    case subMode of
                      SMOnlyCreate -> pure ()
                      SMSubscribe -> void $ subscribeNewQueue rcvId qr -- no need to check if message is available, it's a new queue
                    pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId = fst <$> queueData, serviceId = rcvServiceId} -- , serverNtfCreds = snd <$> ntf
      resp <- attempt (3 :: Int)
      pure (corrId, entId, resp)
|
|
|
|
-- Runs the action only for contact queues, otherwise responds with AUTH error.
-- This check allows to support contact queues created prior to SKEY,
-- using `queueMode == Just QMContact` would prevent it, as they have queueMode `Nothing`.
checkContact :: QueueRec -> M s (Either ErrorType BrokerMsg) -> M s (Transmission BrokerMsg)
checkContact qr action
  | isContactQueue qr = either err (corrId,entId,) <$> action
  | otherwise = pure $ err AUTH
|
|
|
|
-- Runs the action only if the queue has exactly the required mode,
-- otherwise responds with AUTH error.
checkMode :: QueueMode -> QueueRec -> M s (Either ErrorType BrokerMsg) -> M s (Transmission BrokerMsg)
checkMode qm QueueRec {queueMode} action
  | queueMode == Just qm = either err (corrId,entId,) <$> action
  | otherwise = pure $ err AUTH
|
|
|
|
-- Secures the queue with the sender key in the queue store;
-- on success counts the secured queue in server stats and returns OK.
secureQueue_ :: StoreQueue s -> SndPublicAuthKey -> M s (Either ErrorType BrokerMsg)
secureQueue_ q sKey = liftIO (secureQueue (queueStore ms) q sKey) $>> countSecured
  where
    countSecured = do
      stats <- asks serverStats
      incStat $ qSecured stats
      pure $ Right OK
|
|
|
|
-- Reads the link data of the queue (by the link ID passed as entity ID)
-- and returns it in LNK response together with the sender ID of the queue.
getQueueLink_ :: StoreQueue s -> QueueRec -> M s (Either ErrorType BrokerMsg)
getQueueLink_ q QueueRec {senderId} =
  liftIO $ LNK senderId <$$> getQueueLinkData (queueStore ms) q entId
|
|
|
|
-- Adds notifier credentials to the queue (NKEY command) and responds with NID.
-- A new DH key pair is generated for the notifications secret, and a random notifier ID
-- is generated with up to 3 attempts in case of ID collision (INTERNAL error after that).
-- If the queue already had notifier credentials, notification subscribers are informed
-- that the previous notifier ID was deleted.
addQueueNotifier_ :: StoreQueue s -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M s (Transmission BrokerMsg)
addQueueNotifier_ q notifierKey dhKey = time "NKEY" $ do
  drg <- asks random
  (rcvPublicDhKey, privDhKey) <- atomically $ C.generateKeyPair drg
  resp <- addNotifierRetry 3 rcvPublicDhKey (C.dh' dhKey privDhKey)
  pure (corrId, entId, resp)
  where
    addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M s BrokerMsg
    addNotifierRetry attemptsLeft rcvPublicDhKey rcvNtfDhSecret
      | attemptsLeft == 0 = pure $ ERR INTERNAL
      | otherwise = do
          idSize <- asks (queueIdBytes . config)
          notifierId <- randomId idSize
          let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret, ntfServiceId = Nothing}
          added <- liftIO $ addQueueNotifier (queueStore ms) q ntfCreds
          case added of
            Left DUPLICATE_ -> addNotifierRetry (attemptsLeft - 1) rcvPublicDhKey rcvNtfDhSecret
            Left e -> pure $ ERR e
            Right prevCreds_ -> do
              stats <- asks serverStats
              incStat $ ntfCreated stats
              forM_ prevCreds_ $ \NtfCreds {notifierId = nId, ntfServiceId} ->
                atomically $ writeTQueue (subQ ntfSubscribers) (CSDeleted nId ntfServiceId, clientId)
              pure $ NID notifierId rcvPublicDhKey
|
|
|
|
-- Removes notifier credentials from the queue (NDEL command).
-- When the queue had a notifier, its stored notifications are deleted, the notification
-- counter is reduced, and notification subscribers are informed about the deletion.
deleteQueueNotifier_ :: StoreQueue s -> M s (Transmission BrokerMsg)
deleteQueueNotifier_ q = do
  removed <- liftIO $ deleteQueueNotifier (queueStore ms) q
  case removed of
    Left e -> pure $ err e
    Right Nothing -> pure ok
    Right (Just NtfCreds {notifierId = nId, ntfServiceId}) -> do
      -- Possibly, the same should be done if the queue is suspended, but currently we do not use it
      stats <- asks serverStats
      ns <- asks ntfStore
      deleted <- liftIO $ deleteNtfs ns nId
      when (deleted > 0) $ liftIO $ atomicModifyIORef'_ (ntfCount stats) (subtract deleted)
      atomically $ writeTQueue (subQ ntfSubscribers) (CSDeleted nId ntfServiceId, clientId)
      incStat $ ntfDeleted stats
      pure ok
|
|
|
|
-- Suspends the queue in the queue store (OFF command), responding with OK or the store error.
suspendQueue_ :: (StoreQueue s, QueueRec) -> M s (Transmission BrokerMsg)
suspendQueue_ (q, _) = do
  suspended <- liftIO $ suspendQueue (queueStore ms) q
  pure $ case suspended of
    Left e -> err e
    Right _ -> ok
|
|
|
|
-- Subscribes the client to the queue (SUB command) and delivers the first available message.
-- A repeated SUB in the same connection resets the "delivered" state and delivers the message again;
-- SUB is prohibited in the connection where GET was used for this queue.
-- TODO [certs rcv] if serviceId is passed, associate with the service and respond with SOK
subscribeQueue :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
subscribeQueue q qr = do
  existing <- liftIO $ TM.lookupIO rId subscriptions
  case existing of
    Nothing -> do
      sub <- subscribeNewQueue rId qr
      deliver True sub
    Just Sub {subThread = ProhibitSub} -> do
      -- cannot use SUB in the same connection where GET was used
      stats <- asks serverStats
      incStat $ qSubProhibited stats
      pure (corrId, rId, ERR $ CMD PROHIBITED)
    Just sub@Sub {delivered} -> do
      stats <- asks serverStats
      incStat $ qSubDuplicate stats
      _ <- atomically $ tryTakeTMVar delivered
      deliver False sub
  where
    rId = recipientId q
    -- the flag defines whether the subscription is counted in stats (when there is a message)
    deliver :: Bool -> Sub -> M s (Transmission BrokerMsg)
    deliver countSub sub = do
      stats <- asks serverStats
      delivery <- liftIO $ runExceptT $ do
        msg_ <- tryPeekMsg ms q
        liftIO $ when (countSub && isJust msg_) $ incStat (qSub stats)
        liftIO $ deliverMessage "SUB" qr rId sub msg_
      pure $ case delivery of
        Left e -> (corrId, rId, ERR e)
        Right resp -> resp
|
|
|
|
-- Registers a new subscription of this client to the queue in one transaction:
-- notifies server subscribers queue and adds the subscription to the client subscriptions.
subscribeNewQueue :: RecipientId -> QueueRec -> M s Sub
subscribeNewQueue rId QueueRec {rcvServiceId} = time "SUB newSub" $ atomically registerSub
  where
    registerSub = do
      writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId)
      newSub <- newSubscription NoSub
      newSub <$ TM.insert rId newSub subscriptions
|
|
|
|
-- Returns the first message of the queue without subscribing to it (GET command).
-- Clients that use GET are not added to server subscribers.
getMessage :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
getMessage q qr = time "GET" $ do
  existing <- atomically $ TM.lookup entId subscriptions
  case existing of
    Nothing -> do
      sub <- atomically newSub
      getMessage_ sub Nothing
    Just sub@Sub {subThread = ProhibitSub, delivered} -> do
      prevDelivered <- atomically $ tryTakeTMVar delivered
      getMessage_ sub prevDelivered
    -- cannot use GET in the same connection where there is an active subscription
    Just _ -> do
      stats <- asks serverStats
      incStat $ msgGetProhibited stats
      pure $ err $ CMD PROHIBITED
  where
    newSub :: STM Sub
    newSub = do
      sub <- newProhibitedSub
      -- Here we don't account for this client as subscribed in the server
      -- and don't notify other subscribed clients.
      -- This is tracked as "subscription" in the client to prevent these
      -- clients from being able to subscribe.
      sub <$ TM.insert entId sub subscriptions
    -- the second parameter is the previously delivered message (if any), used only for stats
    getMessage_ :: Sub -> Maybe (MsgId, RoundedSystemTime) -> M s (Transmission BrokerMsg)
    getMessage_ sub delivered_ = do
      stats <- asks serverStats
      result <- liftIO $ runExceptT $ do
        msg_ <- tryPeekMsg ms q
        case msg_ of
          Nothing -> ok <$ incStat (msgGetNoMsg stats)
          Just msg -> do
            let encMsg = encryptMsg qr msg
            incStat $ if isJust delivered_ then msgGetDuplicate stats else msgGet stats
            ts <- liftIO getSystemSeconds
            atomically $ (corrId, entId, MSG encMsg) <$ setDelivered sub msg ts
      pure $ either err id result
|
|
|
|
-- Runs the action on the queue of the transmission, rejecting blocked queues.
withQueue :: (StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)) -> M s (Transmission BrokerMsg)
withQueue action = withQueue_ True action
|
|
|
|
-- Runs the action on the queue of the transmission (INTERNAL error if there is none),
-- updating the queue time in the store when the stored date differs from the current one.
-- With queueNotBlocked = True blocked queues are rejected with BLOCKED error.
-- SEND passes queueNotBlocked False here to update time, but it fails anyway on blocked queues (see code for SEND).
withQueue_ :: Bool -> (StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)) -> M s (Transmission BrokerMsg)
withQueue_ queueNotBlocked action = maybe (pure $ err INTERNAL) checkQueue q_
  where
    checkQueue (q, qr@QueueRec {status, updatedAt})
      | EntityBlocked info <- status, queueNotBlocked = pure $ err $ BLOCKED info
      | otherwise = do
          today <- liftIO getSystemDate
          if updatedAt == Just today
            then action q qr
            else do
              updated <- liftIO $ updateQueueTime (queueStore ms) q today
              case updated of
                Left e -> pure $ err e
                Right qr' -> action q qr'
|
|
|
|
-- Subscribes the client to notifications of the queue (NSUB command), responding with SOK.
-- For service clients the queue is associated with the service (the association is stored
-- unless it already exists) and the subscription is counted as the service subscription;
-- for other clients any existing service association is removed and the subscription
-- is added to the client notification subscriptions.
subscribeNotifications :: StoreQueue s -> NtfCreds -> M s BrokerMsg
subscribeNotifications q NtfCreds {ntfServiceId} = do
  stats <- asks serverStats
  let incNtfSrvStat sel = incStat $ sel $ ntfServices stats
  case service of
    Just THClientService {serviceId}
      | ntfServiceId == Just serviceId -> do
          -- duplicate queue-service association - can only happen in case of response error/timeout
          hadSub <- atomically $ do
            subscribed <- hasServiceSub
            unless subscribed newServiceQueueSub
            pure subscribed
          unless hadSub $ do
            incNtfSrvStat srvSubCount
            incNtfSrvStat srvSubQueues
          -- NOTE(review): counted for every duplicate association here; confirm against the
          -- original layout that this counter is not meant to be inside `unless` above.
          incNtfSrvStat srvAssocDuplicate
          pure $ SOK $ Just serviceId
      | otherwise -> do
          -- new or updated queue-service association
          stored <- liftIO $ setQueueService (queueStore ms) q SNotifierService (Just serviceId)
          case stored of
            Left e -> pure $ ERR e
            Right () -> do
              hadSub <- atomically $ do
                subscribed <- hasServiceSub
                newServiceQueueSub
                pure subscribed
              unless hadSub $ incNtfSrvStat srvSubCount
              incNtfSrvStat srvSubQueues
              incNtfSrvStat $ maybe srvAssocNew (const srvAssocUpdated) ntfServiceId
              pure $ SOK $ Just serviceId
      where
        hasServiceSub = (0 /=) <$> readTVar ntfServiceSubsCount
        -- This function is used when queue is associated with the service.
        newServiceQueueSub = do
          writeTQueue (subQ ntfSubscribers) (CSClient entId ntfServiceId (Just serviceId), clientId)
          modifyTVar' ntfServiceSubsCount (+ 1) -- service count
          modifyTVar' (totalServiceSubs ntfSubscribers) (+ 1) -- server count for all services
    Nothing -> case ntfServiceId of
      Just _ -> do
        stored <- liftIO $ setQueueService (queueStore ms) q SNotifierService Nothing
        case stored of
          Left e -> pure $ ERR e
          Right () -> do
            -- hasSubscription should never be True in this branch, because queue was associated with service.
            -- So unless storage and session states diverge, this check is redundant.
            hadSub <- atomically $ hasSubscription >>= newSub
            incNtfSrvStat srvAssocRemoved
            sok hadSub
      Nothing -> do
        hadSub <- atomically $ do
          subscribed <- hasSubscription
          if subscribed then pure True else newSub False
        sok hadSub
      where
        hasSubscription = TM.member entId ntfSubscriptions
        -- notifies subscribers queue and records the client subscription if it did not exist
        newSub hadSub = do
          writeTQueue (subQ ntfSubscribers) (CSClient entId ntfServiceId Nothing, clientId)
          unless hadSub $ TM.insert entId () ntfSubscriptions
          pure hadSub
        sok hadSub = do
          incStat $ if hadSub then ntfSubDuplicate stats else ntfSub stats
          pure $ SOK Nothing
|
|
|
|
-- Subscribes the service client to notifications of all queues associated with the service
-- (NSUBS command), responding with SOKS and the number of queues.
-- If the service is already subscribed in this session, the current count is returned;
-- otherwise the count is read from the queue store and replaces the session count,
-- with the total count of service subscriptions adjusted by the difference.
subscribeServiceNotifications :: THPeerClientService -> M s BrokerMsg
subscribeServiceNotifications THClientService {serviceId} =
  ifM (readTVarIO ntfServiceSubscribed) currentCount subscribeService
  where
    currentCount = SOKS <$> readTVarIO ntfServiceSubsCount
    subscribeService = do
      stored <- liftIO $ getNtfServiceQueueCount @(StoreQueue s) (queueStore ms) serviceId
      case stored of
        Left e -> pure $ ERR e
        Right !count' -> do
          atomically $ do
            writeTVar ntfServiceSubscribed True
            prevCount <- swapTVar ntfServiceSubsCount count'
            modifyTVar' (totalServiceSubs ntfSubscribers) (+ (count' - prevCount)) -- server count for all services
          atomically $ writeTQueue (subQ ntfSubscribers) (CSService serviceId, clientId)
          pure $ SOKS count'
|
|
|
|
-- Acknowledges the delivered message (ACK command): the message is deleted from the queue
-- if its ID matches the delivered one (or if the ID in the command is empty).
-- For subscribed clients the next message, if any, is delivered in the response;
-- for clients that used GET the response is OK.
-- NO_MSG error is returned when there is no subscription or no matching delivered message.
acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
acknowledgeMsg msgId q qr = time "ACK" $ do
  sub_ <- liftIO $ TM.lookupIO entId subscriptions
  case sub_ of
    Nothing -> pure $ err NO_MSG
    Just sub -> do
      delivered_ <- atomically $ getDelivered sub
      case delivered_ of
        Nothing -> pure $ err NO_MSG
        Just (st, ts) -> do
          stats <- asks serverStats
          result <- liftIO $ runExceptT $ case st of
            ProhibitSub -> do
              deletedMsg_ <- tryDelMsg ms q msgId
              liftIO $ mapM_ (updateStats stats True ts) deletedMsg_
              pure ok
            _ -> do
              (deletedMsg_, msg_) <- tryDelPeekMsg ms q msgId
              liftIO $ mapM_ (updateStats stats False ts) deletedMsg_
              liftIO $ deliverMessage "ACK" qr entId sub msg_
          pure $ either err id result
  where
    -- takes the delivered message state; it is put back if the message ID does not match
    getDelivered :: Sub -> STM (Maybe (ServerSub, RoundedSystemTime))
    getDelivered Sub {delivered, subThread} =
      tryTakeTMVar delivered $>>= \v@(msgId', ts) ->
        if msgId /= msgId' && not (B.null msgId)
          then Nothing <$ putTMVar delivered v
          else pure $ Just (subThread, ts)
    updateStats :: ServerStats -> Bool -> RoundedSystemTime -> Message -> IO ()
    updateStats stats isGet deliveryTime = \case
      MessageQuota {} -> pure ()
      Message {msgFlags} -> do
        incStat $ msgRecv stats
        when isGet $ incStat $ msgRecvGet stats
        -- TODO skip notification delivery for delivered message (when it was not received via GET)
        -- skipping delivery fails tests, it should be counted in msgNtfSkipped
        -- forM_ (notifierId <$> notifier qr) $ \nId -> do
        --   ns <- asks ntfStore
        --   atomically $ TM.lookup nId ns >>=
        --     mapM_ (\MsgNtf {ntfMsgId} -> when (msgId == msgId') $ TM.delete nId ns)
        atomicModifyIORef'_ (msgCount stats) (subtract 1)
        updatePeriodStats (activeQueues stats) entId
        when (notification msgFlags) $ do
          incStat $ msgRecvNtf stats
          updatePeriodStats (activeQueuesNtf stats) entId
        -- NOTE(review): acknowledgement time is recorded for all messages here;
        -- confirm against the original layout that it is not limited to notification messages.
        currTime <- getSystemSeconds
        atomicModifyIORef'_ (msgRecvAckTimes stats) $ updateTimeBuckets deliveryTime currTime
|
|
|
|
-- Stores the sent message in the queue (SEND command) and responds with OK, or with an error:
--   LARGE_MSG - the message body exceeds the maximum size,
--   AUTH      - the queue is suspended,
--   BLOCKED   - the queue is blocked,
--   QUOTA     - the queue quota is exceeded.
-- Before storing the message, expired messages are removed from the queue and
-- the link data of a secured messaging queue is deleted.
-- After storing it, the message is delivered to the subscribed recipient if the queue was empty,
-- and a notification is enqueued if it was requested in message flags.
sendMessage :: MsgFlags -> MsgBody -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
sendMessage msgFlags msgBody q qr = do
  stats <- asks serverStats
  let rejectWith counter e = incStat (counter stats) >> pure (err e)
  if B.length msgBody > maxMessageLength clntVersion
    then rejectWith msgSentLarge LARGE_MSG
    else case status qr of
      EntityOff -> rejectWith msgSentAuth AUTH
      EntityBlocked info -> rejectWith msgSentBlock (BLOCKED info)
      EntityActive ->
        case C.maxLenBS msgBody of
          Left _ -> pure $ err LARGE_MSG
          Right body -> do
            when (isJust (queueData qr) && isSecuredMsgQueue qr) $
              void $ liftIO $ deleteQueueLinkData (queueStore ms) q
            ServerConfig {messageExpiration, msgIdBytes} <- asks config
            msgId <- randomId' msgIdBytes
            written <- liftIO $ time "SEND" $ runExceptT $ do
              expireMessages messageExpiration stats
              msg <- liftIO $ mkMessage msgId body
              writeMsg ms q True msg
            case written of
              Left e -> pure $ err e
              Right Nothing -> rejectWith msgSentQuota QUOTA
              Right (Just (msg, wasEmpty)) -> time "SEND ok" $ do
                when wasEmpty $ liftIO $ tryDeliverMessage msg
                when (notification msgFlags) $ do
                  forM_ (notifier qr) $ \ntfCreds -> enqueueNotification ntfCreds msg
                  incStat $ msgSentNtf stats
                  liftIO $ updatePeriodStats (activeQueuesNtf stats) (recipientId q)
                incStat $ msgSent stats
                incStat $ msgCount stats
                liftIO $ updatePeriodStats (activeQueues stats) (recipientId q)
                pure ok
|
|
where
|
|
-- Creates the message with the current system time as its timestamp.
mkMessage :: MsgId -> C.MaxLenBS MaxMessageLen -> IO Message
mkMessage msgId body = do
  now <- getSystemTime
  pure $! Message msgId now msgFlags body
|
|
|
|
-- Deletes expired messages from the queue when message expiration is configured,
-- adding the number of deleted messages to the expired messages counter.
expireMessages :: Maybe ExpirationConfig -> ServerStats -> ExceptT ErrorType IO ()
expireMessages msgExp stats = do
  deleted <- case msgExp of
    Nothing -> pure 0
    Just expCfg -> do
      old <- liftIO $ expireBeforeEpoch expCfg
      deleteExpiredMsgs ms q old
  liftIO $ when (deleted > 0) $ atomicModifyIORef'_ (msgExpired stats) (+ deleted)
|
|
|
|
-- The condition for delivery of the message is:
-- - the queue was empty when the message was sent,
-- - there is subscribed recipient,
-- - no message was "delivered" that was not acknowledged.
-- If the send queue of the subscribed client is not full the message is put there in the same transaction.
-- If the send queue is full, then the thread is created where these checks are made:
-- - it is the same subscribed client (in case it was reconnected it would receive message via SUB command)
-- - nothing was delivered to this subscription (to avoid race conditions with the recipient).
tryDeliverMessage :: Message -> IO ()
tryDeliverMessage msg =
  -- the subscribed client var is read outside of STM to avoid transaction cost
  -- in case no client is subscribed.
  getSubscribedClient rId (queueSubscribers subscribers)
    $>>= deliverToSub
    >>= mapM_ forkDeliver
  where
    rId = recipientId q
    deliverToSub rcv = do
      ts <- getSystemSeconds
      atomically $
        -- reading client TVar in the same transaction,
        -- so that if subscription ends, it re-evaluates
        -- and delivery is cancelled -
        -- the new client will receive message in response to SUB.
        readTVar rcv
          $>>= \rc@Client {subscriptions = subs} -> TM.lookup rId subs
            $>>= deliverOrDefer ts rc
    -- delivers the message in the current transaction, or returns the data
    -- needed to deliver it from a separate thread when the client send queue is full
    deliverOrDefer ts rc@Client {sndQ = sndQ'} s@Sub {subThread, delivered} = case subThread of
      ProhibitSub -> pure Nothing
      ServerSub st -> readTVar st >>= \case
        NoSub -> do
          -- NOTE(review): tryTakeTMVar also empties `delivered` when it was set - confirm this is intended
          alreadyDelivered <- isJust <$> tryTakeTMVar delivered
          if alreadyDelivered
            then pure Nothing -- if a message was already delivered, should not deliver more
            else do
              sndQFull <- isFullTBQueue sndQ'
              if sndQFull
                then Just (rc, s, st) <$ writeTVar st SubPending
                else Nothing <$ deliver sndQ' s ts
        _ -> pure Nothing
    deliver sndQ' s ts = do
      let encMsg = encryptMsg qr msg
      writeTBQueue sndQ' [(CorrId "", rId, MSG encMsg)]
      void $ setDelivered s msg ts
    forkDeliver (rc@Client {sndQ = sndQ'}, s@Sub {delivered}, st) = do
      weakThread <- forkIO deliverThread >>= mkWeakThreadId
      atomically $ modifyTVar' st $ \current -> case current of
        -- this case is needed because deliverThread can exit before it
        SubPending -> SubThread weakThread
        other -> other
      where
        deliverThread = do
          labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " deliver/SEND"
          -- lookup can be outside of STM transaction,
          -- as long as the check that it is the same client is inside.
          getSubscribedClient rId (queueSubscribers subscribers) >>= mapM_ deliverIfSame
        deliverIfSame rcv = time "deliver" $ do
          ts <- getSystemSeconds
          atomically $ whenM (sameClient rc rcv) $ do
            alreadyDelivered <- isJust <$> tryTakeTMVar delivered
            -- if a message was already delivered, should not deliver more
            unless alreadyDelivered $ do
              -- a separate thread is needed because it blocks when client sndQ is full.
              deliver sndQ' s ts
              writeTVar st NoSub
|
|
|
|
-- Stores the notification for the sent message in the notification store
-- and increments the notification counter; nothing is stored for quota messages.
enqueueNotification :: NtfCreds -> Message -> M s ()
enqueueNotification ntfCreds = \case
  MessageQuota {} -> pure ()
  Message {msgId, msgTs} -> do
    let NtfCreds {notifierId = nId, rcvNtfDhSecret} = ntfCreds
    ns <- asks ntfStore
    ntf <- mkMessageNotification msgId msgTs rcvNtfDhSecret
    liftIO $ storeNtf ns nId ntf
    stats <- asks serverStats
    incStat $ ntfCount stats
|
|
|
|
-- Creates the notification for the message: message ID and timestamp are encrypted
-- with the notification DH secret and a random nonce.
-- If encryption fails, the encrypted metadata in the notification is empty.
mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> M s MsgNtf
mkMessageNotification msgId msgTs rcvNtfDhSecret = do
  drg <- asks random
  ntfNonce <- atomically $ C.randomCbNonce drg
  let encodedMeta = smpEncode NMsgMeta {msgId, msgTs}
      ntfEncMeta = either (const "") id $ C.cbEncrypt rcvNtfDhSecret ntfNonce encodedMeta 128
  pure MsgNtf {ntfMsgId = msgId, ntfTs = msgTs, ntfNonce, ntfEncMeta}
|
|
|
|
-- Processes the command forwarded by the proxy (RFWD command), responding with RRES or ERR:
--   1. decrypts the block with the session secret shared with the proxy,
--   2. decrypts the client transmission with the secret computed from the client key,
--   3. verifies and processes the single forwarded command (only sender commands are allowed),
--   4. encrypts the response to the client and then to the proxy.
processForwardedCommand :: EncFwdTransmission -> M s BrokerMsg
processForwardedCommand (EncFwdTransmission encFwd) = fmap (either ERR RRES) . runExceptT $ do
  (serverPrivKey, sessSecret) <- case thAuth thParams' of
    Just THAuthServer {serverPrivKey = privKey, sessSecret' = Just secret} -> pure (privKey, secret)
    _ -> throwE $ transportErr TENoServerAuth
  let proxyNonce = C.cbNonce $ bs corrId
  fwdBytes <- liftEitherWith (const CRYPTO) $ C.cbDecryptNoPad sessSecret proxyNonce encFwd
  FwdTransmission {fwdCorrId, fwdVersion, fwdKey, fwdTransmission = EncTransmission et} <- liftEitherWith (const $ CMD SYNTAX) $ smpDecode fwdBytes
  let clientSecret = C.dh' fwdKey serverPrivKey
      clientNonce = C.cbNonce $ bs fwdCorrId
  block <- liftEitherWith (const CRYPTO) $ C.cbDecrypt clientSecret clientNonce et
  let clntTHParams = smpTHParamsSetVersion fwdVersion thParams'
  -- only allowing single forwarded transactions
  signed <- case tParse clntTHParams block of
    t :| [] -> pure $ tDecodeServer clntTHParams t
    _ -> throwE BLOCK
  let clntThAuth = Just $ THAuthServer {serverPrivKey, peerClientService = Nothing, sessSecret' = Just clientSecret}
  -- process forwarded command
  verified <- lift $ rejectOrVerify clntThAuth signed
  resp <- case verified of
    Left rejected -> pure rejected
    -- rejectOrVerify filters allowed commands, no need to repeat it here.
    -- INTERNAL is used because processCommand never returns Nothing for sender commands (could be extracted for better types).
    Right vt@(_, (corrId', entId', _)) -> fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand Nothing fwdVersion vt)
  -- encode response
  encoded <- case batchTransmissions clntTHParams [Right (Nothing, encodeTransmission clntTHParams resp)] of
    [] -> throwE INTERNAL -- at least 1 item is guaranteed from NonEmpty/Right
    TBError _ _ : _ -> throwE BLOCK
    TBTransmission b' _ : _ -> pure b'
    TBTransmissions b' _ _ : _ -> pure b'
  -- encrypt to client
  encResp <- liftEitherWith (const BLOCK) $ EncResponse <$> C.cbEncrypt clientSecret (C.reverseNonce clientNonce) encoded paddedProxiedTLength
  -- encrypt to proxy
  let fwdResp = FwdResponse {fwdCorrId, fwdResponse = encResp}
      encFwdResp = EncFwdResponse $ C.cbEncryptNoPad sessSecret (C.reverseNonce proxyNonce) (smpEncode fwdResp)
  stats <- asks serverStats
  incStat $ pMsgFwdsRecv stats
  pure encFwdResp
  where
    -- rejects commands that cannot be forwarded and verifies authorization of the allowed ones
    rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmissionOrError ErrorType Cmd -> M s (VerifiedTransmissionOrError s)
    rejectOrVerify clntThAuth = \case
      Left (corrId', entId', e) -> pure $ Left (corrId', entId', ERR e)
      Right signed@(_, _, t@(corrId', entId', cmd'))
        | allowed cmd' -> do
            result <- liftIO $ verifyTransmission ms Nothing clntThAuth signed
            pure $ case result of
              VRVerified q -> Right (q, t)
              VRFailed e -> Left (corrId', entId', ERR e)
        | otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED)
    allowed = \case
      Cmd SSender SEND {} -> True
      Cmd SSender (SKEY _) -> True
      Cmd SSenderLink (LKEY _) -> True
      Cmd SSenderLink LGET -> True
      _ -> False
|
|
|
|
-- Prepares the response with the message for the subscribed client, marking it as delivered.
-- The response is OK when there is no message, or when the client used GET for this queue.
deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> IO (Transmission BrokerMsg)
deliverMessage name qr rId sub@Sub {subThread} msg_ =
  time (name <> " deliver") $ case (subThread, msg_) of
    (ProhibitSub, _) -> pure noMsg
    (_, Nothing) -> pure noMsg
    (_, Just msg) -> do
      ts <- getSystemSeconds
      let encMsg = encryptMsg qr msg
      _ <- atomically $ setDelivered sub msg ts
      pure (corrId, rId, MSG encMsg)
  where
    noMsg = (corrId, rId, OK)
|
|
|
|
-- Runs the action via `timed`, with the given name and the entity ID of the transmission.
time :: MonadIO m => T.Text -> m a -> m a
time name action = timed name entId action
|
|
|
|
-- Encodes and encrypts the message (or the quota marker) for the recipient,
-- using the DH secret of the queue and the message ID as the nonce.
encryptMsg :: QueueRec -> Message -> RcvMessage
encryptMsg qr msg = case msg of
  Message {msgFlags, msgBody} -> seal $ encodeRcvMsgBody RcvMsgBody {msgTs = sentAt, msgFlags, msgBody}
  MessageQuota {} -> seal $ encodeRcvMsgBody (RcvMsgQuota sentAt)
  where
    seal :: KnownNat i => C.MaxLenBS i -> RcvMessage
    seal body =
      let encBody = C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce mId) body
       in RcvMessage mId (EncRcvMsgBody encBody)
    mId = messageId msg
    sentAt = messageTs msg
|
|
|
|
-- Records the message ID and delivery time as "delivered" in the subscription.
-- Returns False if another message is already recorded as delivered.
setDelivered :: Sub -> Message -> RoundedSystemTime -> STM Bool
setDelivered Sub {delivered} msg !ts =
  let !deliveredId = messageId msg
   in tryPutTMVar delivered (deliveredId, ts)
|
|
|
|
-- Deletes the queue from the store and, on success, cleans up related state:
-- the recipient subscription, stored notifications and notification
-- subscription (when the queue has a notifier), and deletion statistics.
-- Responds with OK, or with the store error if the deletion failed.
delQueueAndMsgs :: (StoreQueue s, QueueRec) -> M s (Transmission BrokerMsg)
delQueueAndMsgs (q, QueueRec {rcvServiceId}) =
  liftIO (deleteQueue ms q) >>= either (pure . err) queueDeleted
  where
    queueDeleted qr = do
      -- Possibly, the same should be done if the queue is suspended, but currently we do not use it
      atomically $ do
        writeTQueue (subQ subscribers) (CSDeleted entId rcvServiceId, clientId)
        -- queue is usually deleted by the same client that is currently subscribed,
        -- we delete subscription here, so the client with no subscriptions can be disconnected.
        TM.delete entId subscriptions
      mapM_ notifierDeleted $ notifier qr
      updateDeletedStats qr
      pure ok
    -- queue is deleted by a different client from the one subscribed to notifications,
    -- so we don't need to remove subscription from the current client.
    notifierDeleted NtfCreds {notifierId = nId, ntfServiceId} = do
      stats <- asks serverStats
      ntfs <- asks ntfStore
      deleted <- liftIO $ deleteNtfs ntfs nId
      when (deleted > 0) $ liftIO $ atomicModifyIORef'_ (ntfCount stats) (subtract deleted)
      atomically $ writeTQueue (subQ ntfSubscribers) (CSDeleted nId ntfServiceId, clientId)
|
|
|
|
-- Collects diagnostic information about the queue: whether it has a sender
-- key and a notifier, the state of the current client's subscription (if
-- any), the queue size and information about the first message.
-- Responds with INFO, or with ERR if reading the store failed.
getQueueInfo :: StoreQueue s -> QueueRec -> M s BrokerMsg
getQueueInfo q QueueRec {senderKey, notifier} =
  either ERR INFO <$> liftIO (runExceptT collect)
  where
    collect = do
      qiSub <- liftIO $ mapM describeSub =<< TM.lookupIO entId subscriptions
      qiSize <- getQueueSize ms q
      qiMsg <- fmap toMsgInfo <$> tryPeekMsg ms q
      pure QueueInfo {qiSnd = isJust senderKey, qiNtf = isJust notifier, qiSub, qiSize, qiMsg}
    describeSub Sub {subThread, delivered} = do
      qSubThread <- case subThread of
        ProhibitSub -> pure QProhibitSub
        ServerSub t -> threadState <$> readTVarIO t
      -- base64-encoded ID of the message currently marked as delivered, if any
      qDelivered <- atomically $ fmap (decodeLatin1 . encode . fst) <$> tryReadTMVar delivered
      pure QSub {qSubThread, qDelivered}
    threadState = \case
      NoSub -> QNoSub
      SubPending -> QSubPending
      SubThread _ -> QSubThread
|
|
|
|
-- OK response for the transmission being processed
-- (same correlation ID and entity ID).
ok :: Transmission BrokerMsg
ok = (corrId, entId,) OK
|
|
|
|
-- ERR response with the given error for the transmission being processed
-- (same correlation ID and entity ID).
err :: ErrorType -> Transmission BrokerMsg
err = (corrId, entId,) . ERR
|
|
|
|
-- Updates server statistics after a queue was deleted: increments the
-- "new" or "secured" deletion counter (depending on whether the queue had
-- a sender key), increments the total deletion counter and decrements the
-- queue count.
updateDeletedStats :: QueueRec -> M s ()
updateDeletedStats q = do
  stats <- asks serverStats
  incStat $ case senderKey q of
    Nothing -> qDeletedNew stats
    Just _ -> qDeletedSecured stats
  incStat (qDeletedAll stats)
  liftIO $ atomicModifyIORef'_ (qCount stats) (subtract 1)
|
|
|
|
-- Increments the counter by one using atomicModifyIORef'_.
incStat :: MonadIO m => IORef Int -> m ()
incStat = liftIO . (`atomicModifyIORef'_` (+ 1))
{-# INLINE incStat #-}
|
|
|
|
-- Runs the action and logs a debug message if it took longer than one second.
--
-- The message contains the given name, the base64-encoded entity ID and
-- the elapsed time in nanoseconds. The result of the action is returned
-- unchanged.
timed :: MonadIO m => T.Text -> RecipientId -> m a -> m a
timed name (EntityId qId) a = do
  t <- liftIO getSystemTime
  r <- a
  t' <- liftIO getSystemTime
  let int = diff t t'
  when (int > sec) . logDebug $ T.unwords [name, tshow $ encode qId, tshow int]
  pure r
  where
    -- Elapsed nanoseconds between two system times.
    -- systemNanoseconds is an unsigned 32-bit value, so each operand is widened
    -- BEFORE subtracting: subtracting first wraps around whenever the
    -- nanoseconds of the end time are smaller than those of the start time,
    -- which inflated the interval by ~4.29 seconds.
    diff t t' =
      (systemSeconds t' - systemSeconds t) * sec
        + (fromIntegral (systemNanoseconds t') - fromIntegral (systemNanoseconds t))
    -- one second in nanoseconds
    sec :: Int64
    sec = 1000_000000
|
|
|
|
-- Generates the requested number of random bytes using the random
-- generator from the server environment.
randomId' :: Int -> M s ByteString
randomId' n = do
  gVar <- asks random
  atomically $ C.randomBytes n gVar
|
|
|
|
-- Generates a random entity ID of the requested length in bytes.
randomId :: Int -> M s EntityId
randomId n = EntityId <$> randomId' n
{-# INLINE randomId #-}
|
|
|
|
-- Saves messages of the in-memory store to its configured file (when a
-- path is configured); for the journal store only logs a note.
saveServerMessages :: Bool -> MsgStore s -> IO ()
saveServerMessages drainMsgs store = case store of
  StoreJournal _ -> logNote "closed journal message storage"
  StoreMemory ms@STMMsgStore {storeConfig = STMStoreConfig {storePath}} ->
    maybe
      (logNote "undelivered messages are not saved")
      (\f -> exportMessages False ms f drainMsgs)
      storePath
|
|
|
|
-- Writes messages of all queues to the file, one encoded MLRv3 record per
-- line, and logs the number of saved messages.
--
-- `tty` and `drainMsgs` are passed through to unsafeWithAllMsgQueues and
-- getQueueMessages_ respectively. Any exception during export is logged
-- and the process exits with failure.
exportMessages :: MsgStoreClass s => Bool -> s -> FilePath -> Bool -> IO ()
exportMessages tty ms f drainMsgs = do
  logNote $ "saving messages to file " <> T.pack f
  withFile f WriteMode $ \h -> do
    saved <- tryAny $ unsafeWithAllMsgQueues tty True ms (saveQueueMsgs h)
    case saved of
      Left e -> do
        logError $ "error exporting messages: " <> tshow e
        exitFailure
      Right (Sum total) -> logNote $ "messages saved: " <> tshow total
  where
    -- writes all messages of one queue, returns their count
    saveQueueMsgs h q = do
      msgs <-
        unsafeRunStore q "saveQueueMsgs" $
          getMsgQueue ms q False >>= getQueueMessages_ drainMsgs q
      BLD.hPutBuilder h $ foldMap (encodeMessage $ recipientId q) msgs
      pure . Sum $ length msgs
    encodeMessage rId msg = BLD.byteString (strEncode $ MLRv3 rId msg) <> BLD.char8 '\n'
|
|
|
|
-- Processes stored messages on server start.
--
-- For the in-memory store, messages are imported from the configured file
-- if it exists (Nothing is returned otherwise). For the journal store,
-- when `expireMessagesOnStart` is set, all queues are traversed either to
-- delete expired messages (when message expiration is configured) or only
-- to count stored messages; otherwise a warning is logged and Nothing is
-- returned.
processServerMessages :: forall s'. StartOptions -> M s' (Maybe MessageStats)
processServerMessages StartOptions {skipWarnings} = do
  old_ <- asks (messageExpiration . config) $>>= (liftIO . fmap Just . expireBeforeEpoch)
  expire <- asks $ expireMessagesOnStart . config
  asks msgStore_ >>= liftIO . processMessages old_ expire
  where
    processMessages :: Maybe Int64 -> Bool -> MsgStore s' -> IO (Maybe MessageStats)
    processMessages old_ expire = \case
      StoreMemory ms@STMMsgStore {storeConfig = STMStoreConfig {storePath}} -> case storePath of
        Just f -> ifM (doesFileExist f) (Just <$> importMessages False ms f old_ skipWarnings) (pure Nothing)
        Nothing -> pure Nothing
      StoreJournal ms -> processJournalMessages old_ expire ms
    processJournalMessages :: forall s. Maybe Int64 -> Bool -> JournalMsgStore s -> IO (Maybe MessageStats)
    processJournalMessages old_ expire ms
      | expire = Just <$> case old_ of
          Just old -> do
            logNote "expiring journal store messages..."
            run $ processExpireQueue old
          Nothing -> do
            logNote "validating journal store messages..."
            run processValidateQueue
      | otherwise = logWarn "skipping message expiration" $> Nothing
      where
        -- Any exception while traversing the queues is fatal; it is logged
        -- before exiting so that the reason of the failure is not lost.
        run a =
          unsafeWithAllMsgQueues False False ms a `catchAny` \e -> do
            logError $ "error processing journal messages: " <> tshow e
            exitFailure
        -- deletes messages older than `old` and counts the remaining ones
        processExpireQueue :: Int64 -> JournalQueue s -> IO MessageStats
        processExpireQueue old q = unsafeRunStore q "processExpireQueue" $ do
          mq <- getMsgQueue ms q False
          expiredMsgsCount <- deleteExpireMsgs_ old q mq
          storedMsgsCount <- getQueueSize_ mq
          pure MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues = 1}
        -- only counts messages stored in the queue
        processValidateQueue :: JournalQueue s -> IO MessageStats
        processValidateQueue q = unsafeRunStore q "processValidateQueue" $ do
          storedMsgsCount <- getQueueSize_ =<< getMsgQueue ms q False
          pure newMessageStats {storedMsgsCount, storedQueues = 1}
|
|
|
|
-- Restores messages from the file written by exportMessages.
--
-- Every line is decoded as an MLRv3 record and written to its queue;
-- messages older than `old_` (when it is set) are counted as expired and
-- not restored. After all lines are processed the file is renamed to
-- "<file>.bak", queues that contained a quota marker are set over quota,
-- and the counts of stored and expired messages and of queues are returned.
--
-- A parsing error in the last line and a missing queue are warnings: they
-- are ignored when `skipWarnings` is set and terminate the process
-- otherwise. Other errors terminate the process.
importMessages :: forall s. MsgStoreClass s => Bool -> s -> FilePath -> Maybe Int64 -> Bool -> IO MessageStats
importMessages tty ms f old_ skipWarnings = do
  logNote $ "restoring messages from file " <> T.pack f
  (_, (storedMsgsCount, expiredMsgsCount, overQuota)) <-
    foldLogLines tty f restoreMsg (Nothing, (0, 0, M.empty))
  renameFile f (f <> ".bak")
  mapM_ setOverQuota_ overQuota
  logQueueStates ms
  EntityCounts {queueCount} <- getEntityCounts @(StoreQueue s) (queueStore ms)
  pure MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues = queueCount}
  where
    -- Fold step: the state is the queue of the previous message (if any) and
    -- the counters (stored, expired, queues that had a quota marker).
    restoreMsg :: (Maybe (RecipientId, StoreQueue s), (Int, Int, Map RecipientId (StoreQueue s))) -> Bool -> ByteString -> IO (Maybe (RecipientId, StoreQueue s), (Int, Int, Map RecipientId (StoreQueue s)))
    restoreMsg (lastQ, counts@(!stored, !expired, !overQuota)) eof line = case strDecode line of
      Left e
        | eof -> (lastQ, counts) <$ warnOrExit (parsingErr e)
        | otherwise -> exitErr $ parsingErr e
      Right (MLRv3 rId msg) ->
        runExceptT (addToMsgQueue rId msg) >>= \case
          Right acc -> pure acc
          Left e -> exitErr $ tshow e
      where
        exitErr e = do
          when tty $ putStrLn ""
          logError $ "error restoring messages: " <> e
          exitFailure
        parsingErr :: String -> Text
        parsingErr e = "parsing error (" <> T.pack e <> "): " <> safeDecodeUtf8 (B.take 100 line)
        addToMsgQueue rId msg = do
          found <- case lastQ of
            -- to avoid lookup when restoring the next message to the same queue
            Just (rId', q') | rId' == rId -> pure $ Right q'
            _ -> liftIO $ getQueue ms SRecipient rId
          case found of
            Right q -> addToQueue_ q rId msg
            Left AUTH -> liftIO $ do
              when tty $ putStrLn ""
              warnOrExit $ "queue " <> safeDecodeUtf8 (encode $ unEntityId rId) <> " does not exist"
              pure (Nothing, counts)
            Left e -> throwE e
        addToQueue_ q rId msg = do
          counts' <- case msg of
            MessageQuota {} -> do
              -- queue was over quota at some point,
              -- it will be set as over quota once fully imported
              mergeQuotaMsgs
              _ <- writeMsg ms q False msg
              pure (stored, expired, M.insert rId q overQuota)
            Message {msgTs}
              | all (systemSeconds msgTs >=) old_ ->
                  writeMsg ms q False msg >>= \case
                    Just _ -> pure (stored + 1, expired, overQuota)
                    Nothing -> liftIO $ do
                      when tty $ putStrLn ""
                      logError $ decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (messageId msg)
                      pure counts
              | otherwise -> pure (stored, expired + 1, overQuota)
          pure (Just (rId, q), counts')
          where
            -- if the first message in queue head is "quota", remove it.
            mergeQuotaMsgs =
              withPeekMsgQueue ms q "mergeQuotaMsgs" $ \case
                Just (mq, MessageQuota {}) -> tryDeleteMsg_ q mq False
                _ -> pure ()
    -- logs the warning; exits unless warnings are skipped
    warnOrExit e =
      if skipWarnings
        then logWarn warning
        else do
          logWarn $ warning <> ", start with --skip-warnings option to ignore this error"
          exitFailure
      where
        warning = "warning restoring messages: " <> e
|
|
|
|
-- Logs the counts of stored messages, expired messages and queues,
-- prefixed with the given name.
printMessageStats :: T.Text -> MessageStats -> IO ()
printMessageStats name MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues} =
  logNote $
    mconcat
      [ name,
        " stored: ",
        tshow storedMsgsCount,
        ", expired: ",
        tshow expiredMsgsCount,
        ", queues: ",
        tshow storedQueues
      ]
|
|
|
|
-- Saves stored notifications to the configured file (does nothing when no
-- file is configured), one encoded NLRv1 record per line.
saveServerNtfs :: M s ()
saveServerNtfs = do
  file_ <- asks (storeNtfsFile . config)
  forM_ file_ $ \f -> do
    logNote $ "saving notifications to file " <> T.pack f
    NtfStore ns <- asks ntfStore
    liftIO $ withFile f WriteMode $ \h -> do
      queues <- readTVarIO ns
      forM_ (M.assocs queues) $ \(nId, v) -> do
        ntfs <- readTVarIO v
        -- reverse on save, to save notifications in order, will become reversed again when restoring.
        BLD.hPutBuilder h $ foldMap (encodeNtf nId) (reverse ntfs)
    logNote "notifications saved"
  where
    encodeNtf nId ntf = BLD.byteString (strEncode $ NLRv1 nId ntf) <> BLD.char8 '\n'
|
|
|
|
-- Restores notifications from the configured file, if it is configured and
-- exists; otherwise returns empty statistics.
--
-- Every line is decoded as an NLRv1 record; notifications older than the
-- notification expiration boundary are counted as expired and not stored.
-- On success the file is renamed to "<file>.bak"; a decoding error is
-- logged and terminates the process.
restoreServerNtfs :: M s MessageStats
restoreServerNtfs = do
  file_ <- asks (storeNtfsFile . config)
  case file_ of
    Nothing -> pure newMessageStats
    Just f -> do
      exists <- doesFileExist f
      if exists then restoreNtfs f else pure newMessageStats
  where
    restoreNtfs f = do
      logNote $ "restoring notifications from file " <> T.pack f
      ns <- asks ntfStore
      old <- liftIO . expireBeforeEpoch =<< asks (notificationExpiration . config)
      liftIO $ do
        ntfLines <- LB.lines <$> LB.readFile f
        runExceptT (foldM (restoreNtf ns old) (0, 0, 0) ntfLines) >>= \case
          Left e -> do
            logError . T.pack $ "error restoring notifications: " <> e
            exitFailure
          Right (lineCount, storedMsgsCount, expiredMsgsCount) -> do
            renameFile f (f <> ".bak")
            let NtfStore nsVar = ns
            storedQueues <- M.size <$> readTVarIO nsVar
            logNote $ "notifications restored, " <> tshow lineCount <> " lines processed"
            pure MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues}
      where
        -- Fold step: the counters are (lines processed, stored, expired).
        restoreNtf :: NtfStore -> Int64 -> (Int, Int, Int) -> LB.ByteString -> ExceptT String IO (Int, Int, Int)
        restoreNtf ntfs expireBefore (!lineCount, !stored, !expired) lazyLine =
          case strDecode line of
            Left e -> throwE $ ntfErr "parsing" e
            Right (NLRv1 nId ntf@MsgNtf {ntfTs})
              | systemSeconds ntfTs < expireBefore -> pure (lineCount + 1, stored, expired + 1)
              | otherwise -> liftIO $ storeNtf ntfs nId ntf $> (lineCount + 1, stored + 1, expired)
          where
            line = LB.toStrict lazyLine
            ntfErr :: Show e => String -> e -> String
            ntfErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 line)
|
|
|
|
-- Saves the current server statistics to the configured backup file
-- (does nothing when no file is configured).
saveServerStats :: M s ()
saveServerStats = do
  file_ <- asks (serverStatsBackupFile . config)
  forM_ file_ $ \f -> do
    stats <- asks serverStats
    liftIO $ do
      statsData <- getServerStatsData stats
      logNote $ "saving server stats to file " <> T.pack f
      B.writeFile f $ strEncode statsData
      logNote "server stats saved"
|
|
|
|
-- Restores server statistics from the configured backup file, if it is
-- configured and exists.
--
-- The saved counts are replaced with actual ones: the queue count is taken
-- from the queue store, the message count from `msgStats_` (the saved count
-- is kept when it is Nothing), the notification count from `ntfStats`; the
-- counts of messages and notifications expired while restoring are added to
-- the saved expiration counters. The file is then renamed to "<file>.bak",
-- and a warning is logged for every count that differs from the saved one.
--
-- A decoding failure is fatal: it is logged as an error (as in the other
-- restore functions) and the process exits.
restoreServerStats :: forall s. MsgStoreClass s => Maybe MessageStats -> MessageStats -> M s ()
restoreServerStats msgStats_ ntfStats = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
  where
    restoreStats f = whenM (doesFileExist f) $ do
      logNote $ "restoring server stats from file " <> T.pack f
      liftIO (strDecode <$> B.readFile f) >>= \case
        Right d@ServerStatsData {_qCount = statsQCount, _msgCount = statsMsgCount, _ntfCount = statsNtfCount} -> do
          s <- asks serverStats
          st <- asks msgStore
          EntityCounts {queueCount = _qCount} <- liftIO $ getEntityCounts @(StoreQueue s) $ queueStore st
          let _msgCount = maybe statsMsgCount storedMsgsCount msgStats_
              _ntfCount = storedMsgsCount ntfStats
              _msgExpired' = _msgExpired d + maybe 0 expiredMsgsCount msgStats_
              _msgNtfExpired' = _msgNtfExpired d + expiredMsgsCount ntfStats
          liftIO $ setServerStats s d {_qCount, _msgCount, _ntfCount, _msgExpired = _msgExpired', _msgNtfExpired = _msgNtfExpired'}
          renameFile f $ f <> ".bak"
          logNote "server stats restored"
          compareCounts "Queue" statsQCount _qCount
          compareCounts "Message" statsMsgCount _msgCount
          compareCounts "Notification" statsNtfCount _ntfCount
        Left e -> do
          logError $ "error restoring server stats: " <> T.pack e
          liftIO exitFailure
    -- warns when the saved count differs from the actual one
    compareCounts name statsCnt storeCnt =
      when (statsCnt /= storeCnt) $ logWarn $ name <> " count differs: stats: " <> tshow statsCnt <> ", store: " <> tshow storeCnt