mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 18:35:59 +00:00
1283 lines
62 KiB
Haskell
1283 lines
62 KiB
Haskell
{-# LANGUAGE BangPatterns #-}
|
|
{-# LANGUAGE CPP #-}
|
|
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE DuplicateRecordFields #-}
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE GADTs #-}
|
|
{-# LANGUAGE KindSignatures #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE MultiWayIf #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
{-# LANGUAGE NumericUnderscores #-}
|
|
{-# LANGUAGE OverloadedLists #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# LANGUAGE TupleSections #-}
|
|
|
|
-- |
|
|
-- Module : Simplex.Messaging.Server
|
|
-- Copyright : (c) simplex.chat
|
|
-- License : AGPL-3
|
|
--
|
|
-- Maintainer : chat@simplex.chat
|
|
-- Stability : experimental
|
|
-- Portability : non-portable
|
|
--
|
|
-- This module defines SMP protocol server with in-memory persistence
|
|
-- and optional append only log of SMP queue records.
|
|
--
|
|
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
|
module Simplex.Messaging.Server
|
|
( runSMPServer,
|
|
runSMPServerBlocking,
|
|
disconnectTransport,
|
|
verifyCmdAuthorization,
|
|
dummyVerifyCmd,
|
|
randomId,
|
|
)
|
|
where
|
|
|
|
import Control.Logger.Simple
|
|
import Control.Monad
|
|
import Control.Monad.Except
|
|
import Control.Monad.IO.Unlift
|
|
import Control.Monad.Reader
|
|
import Control.Monad.Trans.Except
|
|
import Crypto.Random
|
|
import Control.Monad.STM (retry)
|
|
import Data.Bifunctor (first)
|
|
import Data.ByteString.Base64 (encode)
|
|
import Data.ByteString.Char8 (ByteString)
|
|
import qualified Data.ByteString.Char8 as B
|
|
import qualified Data.ByteString.Lazy.Char8 as LB
|
|
import Data.Either (fromRight, partitionEithers)
|
|
import Data.Functor (($>))
|
|
import Data.Int (Int64)
|
|
import qualified Data.IntMap.Strict as IM
|
|
import Data.List (intercalate, mapAccumR)
|
|
import Data.List.NonEmpty (NonEmpty (..))
|
|
import qualified Data.List.NonEmpty as L
|
|
import qualified Data.Map.Strict as M
|
|
import Data.Maybe (catMaybes, fromMaybe, isNothing)
|
|
import qualified Data.Text as T
|
|
import Data.Text.Encoding (decodeLatin1)
|
|
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
|
|
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
|
import Data.Time.Format.ISO8601 (iso8601Show)
|
|
import Data.Type.Equality
|
|
import GHC.Stats (getRTSStats)
|
|
import GHC.TypeLits (KnownNat)
|
|
import Network.Socket (ServiceName, Socket, socketToHandle)
|
|
import Simplex.Messaging.Agent.Lock
|
|
import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPMessage, smpProxyError, temporaryClientError)
|
|
import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient)
|
|
import qualified Simplex.Messaging.Crypto as C
|
|
import Simplex.Messaging.Encoding
|
|
import Simplex.Messaging.Encoding.String
|
|
import Simplex.Messaging.Protocol
|
|
import Simplex.Messaging.Server.Control
|
|
import Simplex.Messaging.Server.Env.STM as Env
|
|
import Simplex.Messaging.Server.Expiration
|
|
import Simplex.Messaging.Server.MsgStore
|
|
import Simplex.Messaging.Server.MsgStore.STM
|
|
import Simplex.Messaging.Server.QueueStore
|
|
import Simplex.Messaging.Server.QueueStore.STM as QS
|
|
import Simplex.Messaging.Server.Stats
|
|
import Simplex.Messaging.Server.StoreLog
|
|
import Simplex.Messaging.TMap (TMap)
|
|
import qualified Simplex.Messaging.TMap as TM
|
|
import Simplex.Messaging.Transport
|
|
import Simplex.Messaging.Transport.Buffer (trimCR)
|
|
import Simplex.Messaging.Transport.Server
|
|
import Simplex.Messaging.Util
|
|
import Simplex.Messaging.Version
|
|
import System.Exit (exitFailure)
|
|
import System.IO (hPrint, hPutStrLn, hSetNewlineMode, universalNewlineMode)
|
|
import System.Mem.Weak (deRefWeak)
|
|
import UnliftIO (timeout)
|
|
import UnliftIO.Concurrent
|
|
import UnliftIO.Directory (doesFileExist, renameFile)
|
|
import UnliftIO.Exception
|
|
import UnliftIO.IO
|
|
import UnliftIO.STM
|
|
#if MIN_VERSION_base(4,18,0)
|
|
import Data.List (sort)
|
|
import GHC.Conc (listThreads, threadStatus)
|
|
import GHC.Conc.Sync (threadLabel)
|
|
#endif
|
|
|
|
-- | Runs an SMP server using passed configuration.
|
|
--
|
|
-- See a full server here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-server/Main.hs
|
|
runSMPServer :: ServerConfig -> IO ()
|
|
runSMPServer cfg = do
|
|
started <- newEmptyTMVarIO
|
|
runSMPServerBlocking started cfg
|
|
|
|
-- | Runs an SMP server using passed configuration with signalling.
|
|
--
|
|
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
|
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
|
runSMPServerBlocking :: TMVar Bool -> ServerConfig -> IO ()
|
|
runSMPServerBlocking started cfg = newEnv cfg >>= runReaderT (smpServer started cfg)
|
|
|
|
type M a = ReaderT Env IO a
|
|
|
|
smpServer :: TMVar Bool -> ServerConfig -> M ()
|
|
smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
|
s <- asks server
|
|
pa <- asks proxyAgent
|
|
expired <- restoreServerMessages
|
|
restoreServerStats expired
|
|
raceAny_
|
|
( serverThread s "server subscribedQ" subscribedQ subscribers subscriptions cancelSub
|
|
: serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ())
|
|
: receiveFromProxyAgent pa
|
|
: map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
|
|
)
|
|
`finally` withLock' (savingLock s) "final" (saveServer False >> closeServer)
|
|
where
|
|
runServer :: (ServiceName, ATransport) -> M ()
|
|
runServer (tcpPort, ATransport t) = do
|
|
serverParams <- asks tlsServerParams
|
|
ss <- asks sockets
|
|
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
|
|
env <- ask
|
|
liftIO $ runTransportServerState ss started tcpPort serverParams tCfg $ \h -> runClient serverSignKey t h `runReaderT` env
|
|
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
|
|
|
|
saveServer :: Bool -> M ()
|
|
saveServer keepMsgs = withLog closeStoreLog >> saveServerMessages keepMsgs >> saveServerStats
|
|
|
|
closeServer :: M ()
|
|
closeServer = asks (smpAgent . proxyAgent) >>= liftIO . closeSMPClientAgent
|
|
|
|
serverThread ::
|
|
forall s.
|
|
Server ->
|
|
String ->
|
|
(Server -> TQueue (QueueId, Client)) ->
|
|
(Server -> TMap QueueId Client) ->
|
|
(Client -> TMap QueueId s) ->
|
|
(s -> IO ()) ->
|
|
M ()
|
|
serverThread s label subQ subs clientSubs unsub = do
|
|
labelMyThread label
|
|
forever $
|
|
atomically updateSubscribers
|
|
$>>= endPreviousSubscriptions
|
|
>>= liftIO . mapM_ unsub
|
|
where
|
|
updateSubscribers :: STM (Maybe (QueueId, Client))
|
|
updateSubscribers = do
|
|
(qId, clnt) <- readTQueue $ subQ s
|
|
let clientToBeNotified c' =
|
|
if sameClientId clnt c'
|
|
then pure Nothing
|
|
else do
|
|
yes <- readTVar $ connected c'
|
|
pure $ if yes then Just (qId, c') else Nothing
|
|
TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified
|
|
endPreviousSubscriptions :: (QueueId, Client) -> M (Maybe s)
|
|
endPreviousSubscriptions (qId, c) = do
|
|
forkClient c (label <> ".endPreviousSubscriptions") $
|
|
atomically $ writeTBQueue (sndQ c) [(CorrId "", qId, END)]
|
|
atomically $ TM.lookupDelete qId (clientSubs c)
|
|
|
|
receiveFromProxyAgent :: ProxyAgent -> M ()
|
|
receiveFromProxyAgent ProxyAgent {smpAgent = SMPClientAgent {agentQ}} =
|
|
forever $
|
|
atomically (readTBQueue agentQ) >>= \case
|
|
CAConnected srv -> logInfo $ "SMP server connected " <> showServer' srv
|
|
CADisconnected srv [] -> logInfo $ "SMP server disconnected " <> showServer' srv
|
|
CADisconnected srv subs -> logError $ "SMP server disconnected " <> showServer' srv <> " / subscriptions: " <> tshow (length subs)
|
|
CAResubscribed srv subs -> logError $ "SMP server resubscribed " <> showServer' srv <> " / subscriptions: " <> tshow (length subs)
|
|
CASubError srv errs -> logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (length errs)
|
|
where
|
|
showServer' = decodeLatin1 . strEncode . host
|
|
|
|
expireMessagesThread_ :: ServerConfig -> [M ()]
|
|
expireMessagesThread_ ServerConfig {messageExpiration = Just msgExp} = [expireMessages msgExp]
|
|
expireMessagesThread_ _ = []
|
|
|
|
expireMessages :: ExpirationConfig -> M ()
|
|
expireMessages expCfg = do
|
|
ms <- asks msgStore
|
|
quota <- asks $ msgQueueQuota . config
|
|
let interval = checkInterval expCfg * 1000000
|
|
stats <- asks serverStats
|
|
labelMyThread "expireMessages"
|
|
forever $ do
|
|
liftIO $ threadDelay' interval
|
|
old <- liftIO $ expireBeforeEpoch expCfg
|
|
rIds <- M.keysSet <$> readTVarIO ms
|
|
forM_ rIds $ \rId -> do
|
|
q <- atomically (getMsgQueue ms rId quota)
|
|
deleted <- atomically $ deleteExpiredMsgs q old
|
|
atomically $ modifyTVar' (msgExpired stats) (+ deleted)
|
|
|
|
serverStatsThread_ :: ServerConfig -> [M ()]
|
|
serverStatsThread_ ServerConfig {logStatsInterval = Just interval, logStatsStartTime, serverStatsLogFile} =
|
|
[logServerStats logStatsStartTime interval serverStatsLogFile]
|
|
serverStatsThread_ _ = []
|
|
|
|
logServerStats :: Int64 -> Int64 -> FilePath -> M ()
|
|
logServerStats startAt logInterval statsFilePath = do
|
|
labelMyThread "logServerStats"
|
|
initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
|
|
liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
|
|
liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
|
|
ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedNew, qDeletedSecured, qSub, qSubAuth, qSubDuplicate, qSubProhibited, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgExpired, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv} <- asks serverStats
|
|
let interval = 1000000 * logInterval
|
|
forever $ do
|
|
withFile statsFilePath AppendMode $ \h -> liftIO $ do
|
|
hSetBuffering h LineBuffering
|
|
ts <- getCurrentTime
|
|
fromTime' <- atomically $ swapTVar fromTime ts
|
|
qCreated' <- atomically $ swapTVar qCreated 0
|
|
qSecured' <- atomically $ swapTVar qSecured 0
|
|
qDeletedAll' <- atomically $ swapTVar qDeletedAll 0
|
|
qDeletedNew' <- atomically $ swapTVar qDeletedNew 0
|
|
qDeletedSecured' <- atomically $ swapTVar qDeletedSecured 0
|
|
qSub' <- atomically $ swapTVar qSub 0
|
|
qSubAuth' <- atomically $ swapTVar qSubAuth 0
|
|
qSubDuplicate' <- atomically $ swapTVar qSubDuplicate 0
|
|
qSubProhibited' <- atomically $ swapTVar qSubProhibited 0
|
|
msgSent' <- atomically $ swapTVar msgSent 0
|
|
msgSentAuth' <- atomically $ swapTVar msgSentAuth 0
|
|
msgSentQuota' <- atomically $ swapTVar msgSentQuota 0
|
|
msgSentLarge' <- atomically $ swapTVar msgSentLarge 0
|
|
msgRecv' <- atomically $ swapTVar msgRecv 0
|
|
msgExpired' <- atomically $ swapTVar msgExpired 0
|
|
ps <- atomically $ periodStatCounts activeQueues ts
|
|
msgSentNtf' <- atomically $ swapTVar msgSentNtf 0
|
|
msgRecvNtf' <- atomically $ swapTVar msgRecvNtf 0
|
|
psNtf <- atomically $ periodStatCounts activeQueuesNtf ts
|
|
pRelays' <- atomically $ getResetProxyStatsData pRelays
|
|
pRelaysOwn' <- atomically $ getResetProxyStatsData pRelaysOwn
|
|
pMsgFwds' <- atomically $ getResetProxyStatsData pMsgFwds
|
|
pMsgFwdsOwn' <- atomically $ getResetProxyStatsData pMsgFwdsOwn
|
|
pMsgFwdsRecv' <- atomically $ swapTVar pMsgFwdsRecv 0
|
|
qCount' <- readTVarIO qCount
|
|
msgCount' <- readTVarIO msgCount
|
|
hPutStrLn h $
|
|
intercalate
|
|
","
|
|
( [ iso8601Show $ utctDay fromTime',
|
|
show qCreated',
|
|
show qSecured',
|
|
show qDeletedAll',
|
|
show msgSent',
|
|
show msgRecv',
|
|
dayCount ps,
|
|
weekCount ps,
|
|
monthCount ps,
|
|
show msgSentNtf',
|
|
show msgRecvNtf',
|
|
dayCount psNtf,
|
|
weekCount psNtf,
|
|
monthCount psNtf,
|
|
show qCount',
|
|
show msgCount',
|
|
show msgExpired',
|
|
show qDeletedNew',
|
|
show qDeletedSecured'
|
|
]
|
|
<> showProxyStats pRelays'
|
|
<> showProxyStats pRelaysOwn'
|
|
<> showProxyStats pMsgFwds'
|
|
<> showProxyStats pMsgFwdsOwn'
|
|
<> [ show pMsgFwdsRecv',
|
|
show qSub',
|
|
show qSubAuth',
|
|
show qSubDuplicate',
|
|
show qSubProhibited',
|
|
show msgSentAuth',
|
|
show msgSentQuota',
|
|
show msgSentLarge'
|
|
]
|
|
)
|
|
liftIO $ threadDelay' interval
|
|
where
|
|
showProxyStats ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} =
|
|
[show _pRequests, show _pSuccesses, show _pErrorsConnect, show _pErrorsCompat, show _pErrorsOther]
|
|
|
|
runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M ()
|
|
runClient signKey tp h = do
|
|
kh <- asks serverIdentity
|
|
ks <- atomically . C.generateKeyPair =<< asks random
|
|
ServerConfig {smpServerVRange, smpHandshakeTimeout} <- asks config
|
|
labelMyThread $ "smp handshake for " <> transportName tp
|
|
liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake signKey h ks kh smpServerVRange) >>= \case
|
|
Just (Right th) -> runClientTransport th
|
|
_ -> pure ()
|
|
|
|
controlPortThread_ :: ServerConfig -> [M ()]
|
|
controlPortThread_ ServerConfig {controlPort = Just port} = [runCPServer port]
|
|
controlPortThread_ _ = []
|
|
|
|
runCPServer :: ServiceName -> M ()
|
|
runCPServer port = do
|
|
srv <- asks server
|
|
cpStarted <- newEmptyTMVarIO
|
|
u <- askUnliftIO
|
|
liftIO $ do
|
|
labelMyThread "control port server"
|
|
runTCPServer cpStarted port $ runCPClient u srv
|
|
where
|
|
runCPClient :: UnliftIO (ReaderT Env IO) -> Server -> Socket -> IO ()
|
|
runCPClient u srv sock = do
|
|
labelMyThread "control port client"
|
|
h <- socketToHandle sock ReadWriteMode
|
|
hSetBuffering h LineBuffering
|
|
hSetNewlineMode h universalNewlineMode
|
|
hPutStrLn h "SMP server control port\n'help' for supported commands"
|
|
role <- newTVarIO CPRNone
|
|
cpLoop h role
|
|
where
|
|
cpLoop h role = do
|
|
s <- trimCR <$> B.hGetLine h
|
|
case strDecode s of
|
|
Right CPQuit -> hClose h
|
|
Right cmd -> logCmd s cmd >> processCP h role cmd >> cpLoop h role
|
|
Left err -> hPutStrLn h ("error: " <> err) >> cpLoop h role
|
|
logCmd s cmd = when shouldLog $ logWarn $ "ControlPort: " <> tshow s
|
|
where
|
|
shouldLog = case cmd of
|
|
CPAuth _ -> False
|
|
CPHelp -> False
|
|
CPQuit -> False
|
|
CPSkip -> False
|
|
_ -> True
|
|
processCP h role = \case
|
|
CPAuth auth -> atomically $ writeTVar role $! newRole cfg
|
|
where
|
|
newRole ServerConfig {controlPortUserAuth = user, controlPortAdminAuth = admin}
|
|
| Just auth == admin = CPRAdmin
|
|
| Just auth == user = CPRUser
|
|
| otherwise = CPRNone
|
|
CPSuspend -> withAdminRole $ hPutStrLn h "suspend not implemented"
|
|
CPResume -> withAdminRole $ hPutStrLn h "resume not implemented"
|
|
CPClients -> withAdminRole $ do
|
|
active <- unliftIO u (asks clients) >>= readTVarIO
|
|
hPutStrLn h "clientId,sessionId,connected,createdAt,rcvActiveAt,sndActiveAt,age,subscriptions"
|
|
forM_ (IM.toList active) $ \(cid, Client {sessionId, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions}) -> do
|
|
connected' <- bshow <$> readTVarIO connected
|
|
rcvActiveAt' <- strEncode <$> readTVarIO rcvActiveAt
|
|
sndActiveAt' <- strEncode <$> readTVarIO sndActiveAt
|
|
now <- liftIO getSystemTime
|
|
let age = systemSeconds now - systemSeconds createdAt
|
|
subscriptions' <- bshow . M.size <$> readTVarIO subscriptions
|
|
hPutStrLn h . B.unpack $ B.intercalate "," [bshow cid, encode sessionId, connected', strEncode createdAt, rcvActiveAt', sndActiveAt', bshow age, subscriptions']
|
|
CPStats -> withAdminRole $ do
|
|
ss <- unliftIO u $ asks serverStats
|
|
let putStat :: Show a => ByteString -> (ServerStats -> TVar a) -> IO ()
|
|
putStat label var = readTVarIO (var ss) >>= \v -> B.hPutStr h $ label <> ": " <> bshow v <> "\n"
|
|
putProxyStat :: ByteString -> (ServerStats -> ProxyStats) -> IO ()
|
|
putProxyStat label var = do
|
|
ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- atomically $ getProxyStatsData $ var ss
|
|
B.hPutStr h $ label <> ": requests=" <> bshow _pRequests <> ", successes=" <> bshow _pSuccesses <> ", errorsConnect=" <> bshow _pErrorsConnect <> ", errorsCompat=" <> bshow _pErrorsCompat <> ", errorsOther=" <> bshow _pErrorsOther <> "\n"
|
|
putStat "fromTime" fromTime
|
|
putStat "qCreated" qCreated
|
|
putStat "qSecured" qSecured
|
|
putStat "qDeletedAll" qDeletedAll
|
|
putStat "qDeletedNew" qDeletedNew
|
|
putStat "qDeletedSecured" qDeletedSecured
|
|
putStat "msgSent" msgSent
|
|
putStat "msgRecv" msgRecv
|
|
putStat "msgSentNtf" msgSentNtf
|
|
putStat "msgRecvNtf" msgRecvNtf
|
|
putStat "qCount" qCount
|
|
putStat "msgCount" msgCount
|
|
putProxyStat "pRelays" pRelays
|
|
putProxyStat "pRelaysOwn" pRelaysOwn
|
|
putProxyStat "pMsgFwds" pMsgFwds
|
|
putProxyStat "pMsgFwdsOwn" pMsgFwdsOwn
|
|
putStat "pMsgFwdsRecv" pMsgFwdsRecv
|
|
CPStatsRTS -> getRTSStats >>= hPrint h
|
|
CPThreads -> withAdminRole $ do
|
|
#if MIN_VERSION_base(4,18,0)
|
|
threads <- liftIO listThreads
|
|
hPutStrLn h $ "Threads: " <> show (length threads)
|
|
forM_ (sort threads) $ \tid -> do
|
|
label <- threadLabel tid
|
|
status <- threadStatus tid
|
|
hPutStrLn h $ show tid <> " (" <> show status <> ") " <> fromMaybe "" label
|
|
#else
|
|
hPutStrLn h "Not available on GHC 8.10"
|
|
#endif
|
|
CPSockets -> withAdminRole $ do
|
|
(accepted', closed', active') <- unliftIO u $ asks sockets
|
|
(accepted, closed, active) <- atomically $ (,,) <$> readTVar accepted' <*> readTVar closed' <*> readTVar active'
|
|
hPutStrLn h "Sockets: "
|
|
hPutStrLn h $ "accepted: " <> show accepted
|
|
hPutStrLn h $ "closed: " <> show closed
|
|
hPutStrLn h $ "active: " <> show (IM.size active)
|
|
hPutStrLn h $ "leaked: " <> show (accepted - closed - IM.size active)
|
|
CPSocketThreads -> withAdminRole $ do
|
|
#if MIN_VERSION_base(4,18,0)
|
|
(_, _, active') <- unliftIO u $ asks sockets
|
|
active <- readTVarIO active'
|
|
forM_ (IM.toList active) $ \(sid, tid') ->
|
|
deRefWeak tid' >>= \case
|
|
Nothing -> hPutStrLn h $ intercalate "," [show sid, "", "gone", ""]
|
|
Just tid -> do
|
|
label <- threadLabel tid
|
|
status <- threadStatus tid
|
|
hPutStrLn h $ intercalate "," [show sid, show tid, show status, fromMaybe "" label]
|
|
#else
|
|
hPutStrLn h "Not available on GHC 8.10"
|
|
#endif
|
|
CPDelete queueId' -> withUserRole $ unliftIO u $ do
|
|
st <- asks queueStore
|
|
ms <- asks msgStore
|
|
queueId <- atomically (getQueue st SSender queueId') >>= \case
|
|
Left _ -> pure queueId' -- fallback to using as recipientId directly
|
|
Right QueueRec {recipientId} -> pure recipientId
|
|
r <- atomically $
|
|
deleteQueue st queueId $>>= \q ->
|
|
Right . (q,) <$> delMsgQueueSize ms queueId
|
|
case r of
|
|
Left e -> liftIO . hPutStrLn h $ "error: " <> show e
|
|
Right (q, numDeleted) -> do
|
|
withLog (`logDeleteQueue` queueId)
|
|
updateDeletedStats q
|
|
liftIO . hPutStrLn h $ "ok, " <> show numDeleted <> " messages deleted"
|
|
CPSave -> withAdminRole $ withLock' (savingLock srv) "control" $ do
|
|
hPutStrLn h "saving server state..."
|
|
unliftIO u $ saveServer True
|
|
hPutStrLn h "server state saved!"
|
|
CPHelp -> hPutStrLn h "commands: stats, stats-rts, clients, sockets, socket-threads, threads, delete, save, help, quit"
|
|
CPQuit -> pure ()
|
|
CPSkip -> pure ()
|
|
where
|
|
withUserRole action = readTVarIO role >>= \case
|
|
CPRAdmin -> action
|
|
CPRUser -> action
|
|
_ -> do
|
|
logError "Unauthorized control port command"
|
|
hPutStrLn h "AUTH"
|
|
withAdminRole action = readTVarIO role >>= \case
|
|
CPRAdmin -> action
|
|
_ -> do
|
|
logError "Unauthorized control port command"
|
|
hPutStrLn h "AUTH"
|
|
|
|
runClientTransport :: Transport c => THandleSMP c 'TServer -> M ()
|
|
runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessionId}} = do
|
|
q <- asks $ tbqSize . config
|
|
ts <- liftIO getSystemTime
|
|
active <- asks clients
|
|
nextClientId <- asks clientSeq
|
|
c <- atomically $ do
|
|
new@Client {clientId} <- newClient nextClientId q thVersion sessionId ts
|
|
modifyTVar' active $ IM.insert clientId new
|
|
pure new
|
|
s <- asks server
|
|
expCfg <- asks $ inactiveClientExpiration . config
|
|
th <- newMVar h -- put TH under a fair lock to interleave messages and command responses
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId
|
|
raceAny_ ([liftIO $ send th c, liftIO $ sendMsg th c, client thParams c s, receive h c] <> disconnectThread_ c expCfg)
|
|
`finally` clientDisconnected c
|
|
where
|
|
disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport h (rcvActiveAt c) (sndActiveAt c) expCfg (noSubscriptions c)]
|
|
disconnectThread_ _ _ = []
|
|
noSubscriptions c = atomically $ (&&) <$> TM.null (subscriptions c) <*> TM.null (ntfSubscriptions c)
|
|
|
|
clientDisconnected :: Client -> M ()
|
|
clientDisconnected c@Client {clientId, subscriptions, connected, sessionId, endThreads} = do
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disc"
|
|
subs <- atomically $ do
|
|
writeTVar connected False
|
|
swapTVar subscriptions M.empty
|
|
liftIO $ mapM_ cancelSub subs
|
|
srvSubs <- asks $ subscribers . server
|
|
atomically $ modifyTVar' srvSubs $ \cs ->
|
|
M.foldrWithKey (\sub _ -> M.update deleteCurrentClient sub) cs subs
|
|
asks clients >>= atomically . (`modifyTVar'` IM.delete clientId)
|
|
tIds <- atomically $ swapTVar endThreads IM.empty
|
|
liftIO $ mapM_ (mapM_ killThread <=< deRefWeak) tIds
|
|
where
|
|
deleteCurrentClient :: Client -> Maybe Client
|
|
deleteCurrentClient c'
|
|
| sameClientId c c' = Nothing
|
|
| otherwise = Just c'
|
|
|
|
sameClientId :: Client -> Client -> Bool
|
|
sameClientId Client {clientId} Client {clientId = cId'} = clientId == cId'
|
|
|
|
cancelSub :: TVar Sub -> IO ()
|
|
cancelSub sub =
|
|
readTVarIO sub >>= \case
|
|
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
|
|
_ -> return ()
|
|
|
|
receive :: Transport c => THandleSMP c 'TServer -> Client -> M ()
|
|
receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
|
|
forever $ do
|
|
ts <- L.toList <$> liftIO (tGet h)
|
|
atomically . writeTVar rcvActiveAt =<< liftIO getSystemTime
|
|
stats <- asks serverStats
|
|
(errs, cmds) <- partitionEithers <$> mapM (cmdAction stats) ts
|
|
write sndQ errs
|
|
write rcvQ cmds
|
|
where
|
|
cmdAction :: ServerStats -> SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
|
|
cmdAction stats (tAuth, authorized, (corrId, entId, cmdOrError)) =
|
|
case cmdOrError of
|
|
Left e -> pure $ Left (corrId, entId, ERR e)
|
|
Right cmd -> verified =<< verifyTransmission ((,C.cbNonce (bs corrId)) <$> thAuth) tAuth authorized entId cmd
|
|
where
|
|
verified = \case
|
|
VRVerified qr -> pure $ Right (qr, (corrId, entId, cmd))
|
|
VRFailed -> do
|
|
case cmd of
|
|
Cmd _ SEND {} -> atomically $ modifyTVar' (msgSentAuth stats) (+ 1)
|
|
Cmd _ SUB -> atomically $ modifyTVar' (qSubAuth stats) (+ 1)
|
|
_ -> pure ()
|
|
pure $ Left (corrId, entId, ERR AUTH)
|
|
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
|
|
|
|
send :: Transport c => MVar (THandleSMP c 'TServer) -> Client -> IO ()
|
|
send th c@Client {sndQ, msgQ, sessionId} = do
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send"
|
|
forever $ atomically (readTBQueue sndQ) >>= sendTransmissions
|
|
where
|
|
sendTransmissions :: NonEmpty (Transmission BrokerMsg) -> IO ()
|
|
sendTransmissions ts
|
|
| L.length ts <= 2 = tSend th c ts
|
|
| otherwise = do
|
|
let (msgs_, ts') = mapAccumR splitMessages [] ts
|
|
-- If the request had batched subscriptions (L.length ts > 2)
|
|
-- this will reply OK to all SUBs in the first batched transmission,
|
|
-- to reduce client timeouts.
|
|
tSend th c ts'
|
|
-- After that all messages will be sent in separate transmissions,
|
|
-- without any client response timeouts, and allowing them to interleave
|
|
-- with other requests responses.
|
|
mapM_ (atomically . writeTBQueue msgQ) $ L.nonEmpty msgs_
|
|
where
|
|
splitMessages :: [Transmission BrokerMsg] -> Transmission BrokerMsg -> ([Transmission BrokerMsg], Transmission BrokerMsg)
|
|
splitMessages msgs t@(corrId, entId, cmd) = case cmd of
|
|
-- replace MSG response with OK, accumulating MSG in a separate list.
|
|
MSG {} -> ((CorrId "", entId, cmd) : msgs, (corrId, entId, OK))
|
|
_ -> (msgs, t)
|
|
|
|
sendMsg :: Transport c => MVar (THandleSMP c 'TServer) -> Client -> IO ()
|
|
sendMsg th c@Client {msgQ, sessionId} = do
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " sendMsg"
|
|
forever $ atomically (readTBQueue msgQ) >>= mapM_ (\t -> tSend th c [t])
|
|
|
|
tSend :: Transport c => MVar (THandleSMP c 'TServer) -> Client -> NonEmpty (Transmission BrokerMsg) -> IO ()
|
|
tSend th Client {sndActiveAt} ts = do
|
|
withMVar th $ \h@THandle {params} ->
|
|
void . tPut h $ L.map (\t -> Right (Nothing, encodeTransmission params t)) ts
|
|
atomically . writeTVar sndActiveAt =<< liftIO getSystemTime
|
|
|
|
disconnectTransport :: Transport c => THandle v c 'TServer -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
|
|
disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport"
|
|
loop
|
|
where
|
|
loop = do
|
|
threadDelay' $ checkInterval expCfg * 1000000
|
|
ifM noSubscriptions checkExpired loop
|
|
checkExpired = do
|
|
old <- expireBeforeEpoch expCfg
|
|
ts <- max <$> readTVarIO rcvActiveAt <*> readTVarIO sndActiveAt
|
|
if systemSeconds ts < old then closeConnection connection else loop
|
|
|
|
data VerificationResult = VRVerified (Maybe QueueRec) | VRFailed
|
|
|
|
-- This function verifies queue command authorization, with the objective to have constant time between the three AUTH error scenarios:
|
|
-- - the queue and party key exist, and the provided authorization has type matching queue key, but it is made with the different key.
|
|
-- - the queue and party key exist, but the provided authorization has incorrect type.
|
|
-- - the queue or party key do not exist.
|
|
-- In all cases, the time of the verification should depend only on the provided authorization type,
|
|
-- a dummy key is used to run verification in the last two cases, and failure is returned irrespective of the result.
|
|
verifyTransmission :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> QueueId -> Cmd -> M VerificationResult
|
|
verifyTransmission auth_ tAuth authorized queueId cmd =
|
|
case cmd of
|
|
Cmd SRecipient (NEW k _ _ _) -> pure $ Nothing `verifiedWith` k
|
|
Cmd SRecipient _ -> verifyQueue (\q -> Just q `verifiedWith` recipientKey q) <$> get SRecipient
|
|
-- SEND will be accepted without authorization before the queue is secured with KEY command
|
|
Cmd SSender SEND {} -> verifyQueue (\q -> Just q `verified` maybe (isNothing tAuth) verify (senderKey q)) <$> get SSender
|
|
Cmd SSender PING -> pure $ VRVerified Nothing
|
|
Cmd SSender RFWD {} -> pure $ VRVerified Nothing
|
|
-- NSUB will not be accepted without authorization
|
|
Cmd SNotifier NSUB -> verifyQueue (\q -> maybe dummyVerify (\n -> Just q `verifiedWith` notifierKey n) (notifier q)) <$> get SNotifier
|
|
Cmd SProxiedClient _ -> pure $ VRVerified Nothing
|
|
where
|
|
verify = verifyCmdAuthorization auth_ tAuth authorized
|
|
dummyVerify = verify (dummyAuthKey tAuth) `seq` VRFailed
|
|
verifyQueue :: (QueueRec -> VerificationResult) -> Either ErrorType QueueRec -> VerificationResult
|
|
verifyQueue = either (const dummyVerify)
|
|
verified q cond = if cond then VRVerified q else VRFailed
|
|
verifiedWith q k = q `verified` verify k
|
|
get :: DirectParty p => SParty p -> M (Either ErrorType QueueRec)
|
|
get party = do
|
|
st <- asks queueStore
|
|
atomically $ getQueue st party queueId
|
|
|
|
verifyCmdAuthorization :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> C.APublicAuthKey -> Bool
|
|
verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAuth
|
|
where
|
|
verify :: C.APublicAuthKey -> TransmissionAuth -> Bool
|
|
verify (C.APublicAuthKey a k) = \case
|
|
TASignature (C.ASignature a' s) -> case testEquality a a' of
|
|
Just Refl -> C.verify' k s authorized
|
|
_ -> C.verify' (dummySignKey a') s authorized `seq` False
|
|
TAAuthenticator s -> case a of
|
|
C.SX25519 -> verifyCmdAuth auth_ k s authorized
|
|
_ -> verifyCmdAuth auth_ dummyKeyX25519 s authorized `seq` False
|
|
|
|
verifyCmdAuth :: Maybe (THandleAuth 'TServer, C.CbNonce) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> Bool
|
|
verifyCmdAuth auth_ k authenticator authorized = case auth_ of
|
|
Just (THAuthServer {serverPrivKey = pk}, nonce) -> C.cbVerify k pk nonce authenticator authorized
|
|
Nothing -> False
|
|
|
|
dummyVerifyCmd :: Maybe (THandleAuth 'TServer, C.CbNonce) -> ByteString -> TransmissionAuth -> Bool
|
|
dummyVerifyCmd auth_ authorized = \case
|
|
TASignature (C.ASignature a s) -> C.verify' (dummySignKey a) s authorized
|
|
TAAuthenticator s -> verifyCmdAuth auth_ dummyKeyX25519 s authorized
|
|
|
|
-- These dummy keys are used with `dummyVerify` function to mitigate timing attacks
|
|
-- by having the same time of the response whether a queue exists or nor, for all valid key/signature sizes
|
|
dummySignKey :: C.SignatureAlgorithm a => C.SAlgorithm a -> C.PublicKey a
|
|
dummySignKey = \case
|
|
C.SEd25519 -> dummyKeyEd25519
|
|
C.SEd448 -> dummyKeyEd448
|
|
|
|
dummyAuthKey :: Maybe TransmissionAuth -> C.APublicAuthKey
|
|
dummyAuthKey = \case
|
|
Just (TASignature (C.ASignature a _)) -> case a of
|
|
C.SEd25519 -> C.APublicAuthKey C.SEd25519 dummyKeyEd25519
|
|
C.SEd448 -> C.APublicAuthKey C.SEd448 dummyKeyEd448
|
|
_ -> C.APublicAuthKey C.SX25519 dummyKeyX25519
|
|
|
|
dummyKeyEd25519 :: C.PublicKey 'C.Ed25519
|
|
dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs="
|
|
|
|
dummyKeyEd448 :: C.PublicKey 'C.Ed448
|
|
dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/XopbOSaq9qyLhrgJWKOLyNrQPNVvpMA"
|
|
|
|
dummyKeyX25519 :: C.PublicKey 'C.X25519
|
|
dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg="
|
|
|
|
forkClient :: Client -> String -> M () -> M ()
|
|
forkClient Client {endThreads, endThreadSeq} label action = do
|
|
tId <- atomically $ stateTVar endThreadSeq $ \next -> (next, next + 1)
|
|
t <- forkIO $ do
|
|
labelMyThread label
|
|
action `finally` atomically (modifyTVar' endThreads $ IM.delete tId)
|
|
mkWeakThreadId t >>= atomically . modifyTVar' endThreads . IM.insert tId
|
|
|
|
client :: THandleParams SMPVersion 'TServer -> Client -> Server -> M ()
|
|
client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, notifiers} = do
|
|
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands"
|
|
forever $
|
|
atomically (readTBQueue rcvQ)
|
|
>>= mapM processCommand
|
|
>>= mapM_ reply . L.nonEmpty . catMaybes . L.toList
|
|
where
|
|
reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m ()
|
|
reply = atomically . writeTBQueue sndQ
|
|
processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (Maybe (Transmission BrokerMsg))
|
|
processProxiedCmd (corrId, sessId, command) = (corrId,sessId,) <$$> case command of
|
|
PRXY srv auth -> ifM allowProxy getRelay (pure $ Just $ ERR $ PROXY BASIC_AUTH)
|
|
where
|
|
allowProxy = do
|
|
ServerConfig {allowSMPProxy, newQueueBasicAuth} <- asks config
|
|
pure $ allowSMPProxy && maybe True ((== auth) . Just) newQueueBasicAuth
|
|
getRelay = do
|
|
ProxyAgent {smpAgent = a} <- asks proxyAgent
|
|
liftIO (getConnectedSMPServerClient a srv) >>= \case
|
|
Just r -> Just <$> proxyServerResponse a r
|
|
Nothing ->
|
|
forkProxiedCmd $
|
|
liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError))
|
|
>>= proxyServerResponse a
|
|
proxyServerResponse :: SMPClientAgent -> Either SMPClientError (OwnServer, SMPClient) -> M BrokerMsg
|
|
proxyServerResponse a smp_ = do
|
|
ServerStats {pRelays, pRelaysOwn} <- asks serverStats
|
|
let inc = mkIncProxyStats pRelays pRelaysOwn
|
|
case smp_ of
|
|
Right (own, smp) -> do
|
|
inc own pRequests
|
|
case proxyResp smp of
|
|
r@PKEY {} -> r <$ inc own pSuccesses
|
|
r -> r <$ inc own pErrorsCompat
|
|
Left e -> do
|
|
let own = isOwnServer a srv
|
|
inc own pRequests
|
|
inc own $ if temporaryClientError e then pErrorsConnect else pErrorsOther
|
|
pure . ERR $ smpProxyError e
|
|
where
|
|
proxyResp smp =
|
|
let THandleParams {sessionId = srvSessId, thVersion, thServerVRange, thAuth} = thParams smp
|
|
in case compatibleVRange thServerVRange proxiedSMPRelayVRange of
|
|
-- Cap the destination relay version range to prevent client version fingerprinting.
|
|
-- See comment for proxiedSMPRelayVersion.
|
|
Just (Compatible vr) | thVersion >= sendingProxySMPVersion -> case thAuth of
|
|
Just THAuthClient {serverCertKey} -> PKEY srvSessId vr serverCertKey
|
|
Nothing -> ERR $ transportErr TENoServerAuth
|
|
_ -> ERR $ transportErr TEVersion
|
|
PFWD fwdV pubKey encBlock -> do
|
|
ProxyAgent {smpAgent = a} <- asks proxyAgent
|
|
ServerStats {pMsgFwds, pMsgFwdsOwn} <- asks serverStats
|
|
let inc = mkIncProxyStats pMsgFwds pMsgFwdsOwn
|
|
atomically (lookupSMPServerClient a sessId) >>= \case
|
|
Just (own, smp) -> do
|
|
inc own pRequests
|
|
if v >= sendingProxySMPVersion
|
|
then forkProxiedCmd $ do
|
|
liftIO (runExceptT (forwardSMPMessage smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case
|
|
Right r -> PRES r <$ inc own pSuccesses
|
|
Left e -> ERR (smpProxyError e) <$ case e of
|
|
PCEProtocolError {} -> inc own pSuccesses
|
|
_ -> inc own pErrorsOther
|
|
else Just (ERR $ transportErr TEVersion) <$ inc own pErrorsCompat
|
|
where
|
|
THandleParams {thVersion = v} = thParams smp
|
|
Nothing -> inc False pRequests >> inc False pErrorsConnect $> Just (ERR $ PROXY NO_SESSION)
|
|
where
|
|
forkProxiedCmd :: M BrokerMsg -> M (Maybe BrokerMsg)
|
|
forkProxiedCmd cmdAction = do
|
|
bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $ do
|
|
-- commands MUST be processed under a reasonable timeout or the client would halt
|
|
cmdAction >>= \t -> reply [(corrId, sessId, t)]
|
|
pure Nothing
|
|
where
|
|
wait = do
|
|
ServerConfig {serverClientConcurrency} <- asks config
|
|
atomically $ do
|
|
used <- readTVar procThreads
|
|
when (used >= serverClientConcurrency) retry
|
|
writeTVar procThreads $! used + 1
|
|
signal = atomically $ modifyTVar' procThreads (\t -> t - 1)
|
|
transportErr :: TransportError -> ErrorType
|
|
transportErr = PROXY . BROKER . TRANSPORT
|
|
mkIncProxyStats :: MonadIO m => ProxyStats -> ProxyStats -> OwnServer -> (ProxyStats -> TVar Int) -> m ()
|
|
mkIncProxyStats ps psOwn = \own sel -> do
|
|
atomically $ modifyTVar' (sel ps) (+ 1)
|
|
when own $ atomically $ modifyTVar' (sel psOwn) (+ 1)
|
|
processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Maybe (Transmission BrokerMsg))
|
|
processCommand (qr_, (corrId, queueId, cmd)) = case cmd of
|
|
Cmd SProxiedClient command -> processProxiedCmd (corrId, queueId, command)
|
|
Cmd SSender command -> Just <$> case command of
|
|
SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody
|
|
PING -> pure (corrId, "", PONG)
|
|
RFWD encBlock -> (corrId, "",) <$> processForwardedCommand encBlock
|
|
Cmd SNotifier NSUB -> Just <$> subscribeNotifications
|
|
Cmd SRecipient command -> do
|
|
st <- asks queueStore
|
|
Just <$> case command of
|
|
NEW rKey dhKey auth subMode ->
|
|
ifM
|
|
allowNew
|
|
(createQueue st rKey dhKey subMode)
|
|
(pure (corrId, queueId, ERR AUTH))
|
|
where
|
|
allowNew = do
|
|
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
|
|
pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth
|
|
SUB -> withQueue (`subscribeQueue` queueId)
|
|
GET -> withQueue getMessage
|
|
ACK msgId -> withQueue (`acknowledgeMsg` msgId)
|
|
KEY sKey -> secureQueue_ st sKey
|
|
NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
|
|
NDEL -> deleteQueueNotifier_ st
|
|
OFF -> suspendQueue_ st
|
|
DEL -> delQueueAndMsgs st
|
|
where
|
|
createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> M (Transmission BrokerMsg)
|
|
createQueue st recipientKey dhKey subMode = time "NEW" $ do
|
|
(rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random
|
|
let rcvDhSecret = C.dh' dhKey privDhKey
|
|
qik (rcvId, sndId) = QIK {rcvId, sndId, rcvPublicDhKey}
|
|
qRec (recipientId, senderId) =
|
|
QueueRec
|
|
{ recipientId,
|
|
senderId,
|
|
recipientKey,
|
|
rcvDhSecret,
|
|
senderKey = Nothing,
|
|
notifier = Nothing,
|
|
status = QueueActive
|
|
}
|
|
(corrId,queueId,) <$> addQueueRetry 3 qik qRec
|
|
where
|
|
addQueueRetry ::
|
|
Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> M BrokerMsg
|
|
addQueueRetry 0 _ _ = pure $ ERR INTERNAL
|
|
addQueueRetry n qik qRec = do
|
|
ids@(rId, _) <- getIds
|
|
-- create QueueRec record with these ids and keys
|
|
let qr = qRec ids
|
|
atomically (addQueue st qr) >>= \case
|
|
Left DUPLICATE_ -> addQueueRetry (n - 1) qik qRec
|
|
Left e -> pure $ ERR e
|
|
Right _ -> do
|
|
withLog (`logCreateById` rId)
|
|
stats <- asks serverStats
|
|
atomically $ modifyTVar' (qCreated stats) (+ 1)
|
|
atomically $ modifyTVar' (qCount stats) (+ 1)
|
|
case subMode of
|
|
SMOnlyCreate -> pure ()
|
|
SMSubscribe -> void $ subscribeQueue qr rId
|
|
pure $ IDS (qik ids)
|
|
|
|
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
|
|
logCreateById s rId =
|
|
atomically (getQueue st SRecipient rId) >>= \case
|
|
Right q -> logCreateQueue s q
|
|
_ -> pure ()
|
|
|
|
getIds :: M (RecipientId, SenderId)
|
|
getIds = do
|
|
n <- asks $ queueIdBytes . config
|
|
liftM2 (,) (randomId n) (randomId n)
|
|
|
|
secureQueue_ :: QueueStore -> SndPublicAuthKey -> M (Transmission BrokerMsg)
|
|
secureQueue_ st sKey = time "KEY" $ do
|
|
withLog $ \s -> logSecureQueue s queueId sKey
|
|
stats <- asks serverStats
|
|
atomically $ modifyTVar' (qSecured stats) (+ 1)
|
|
atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey
|
|
|
|
addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M (Transmission BrokerMsg)
|
|
addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do
|
|
(rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random
|
|
let rcvNtfDhSecret = C.dh' dhKey privDhKey
|
|
(corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret
|
|
where
|
|
addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M BrokerMsg
|
|
addNotifierRetry 0 _ _ = pure $ ERR INTERNAL
|
|
addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do
|
|
notifierId <- randomId =<< asks (queueIdBytes . config)
|
|
let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
|
|
atomically (addQueueNotifier st queueId ntfCreds) >>= \case
|
|
Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret
|
|
Left e -> pure $ ERR e
|
|
Right _ -> do
|
|
withLog $ \s -> logAddNotifier s queueId ntfCreds
|
|
pure $ NID notifierId rcvPublicDhKey
|
|
|
|
deleteQueueNotifier_ :: QueueStore -> M (Transmission BrokerMsg)
|
|
deleteQueueNotifier_ st = do
|
|
withLog (`logDeleteNotifier` queueId)
|
|
okResp <$> atomically (deleteQueueNotifier st queueId)
|
|
|
|
suspendQueue_ :: QueueStore -> M (Transmission BrokerMsg)
|
|
suspendQueue_ st = do
|
|
withLog (`logSuspendQueue` queueId)
|
|
okResp <$> atomically (suspendQueue st queueId)
|
|
|
|
subscribeQueue :: QueueRec -> RecipientId -> M (Transmission BrokerMsg)
|
|
subscribeQueue qr rId = do
|
|
stats <- asks serverStats
|
|
atomically (TM.lookup rId subscriptions) >>= \case
|
|
Nothing -> do
|
|
atomically $ modifyTVar' (qSub stats) (+ 1)
|
|
newSub >>= deliver
|
|
Just sub ->
|
|
readTVarIO sub >>= \case
|
|
Sub {subThread = ProhibitSub} -> do
|
|
-- cannot use SUB in the same connection where GET was used
|
|
atomically $ modifyTVar' (qSubProhibited stats) (+ 1)
|
|
pure (corrId, rId, ERR $ CMD PROHIBITED)
|
|
s -> do
|
|
atomically $ modifyTVar' (qSubDuplicate stats) (+ 1)
|
|
atomically (tryTakeTMVar $ delivered s) >> deliver sub
|
|
where
|
|
newSub :: M (TVar Sub)
|
|
newSub = time "SUB newSub" . atomically $ do
|
|
writeTQueue subscribedQ (rId, clnt)
|
|
sub <- newTVar =<< newSubscription NoSub
|
|
TM.insert rId sub subscriptions
|
|
pure sub
|
|
deliver :: TVar Sub -> M (Transmission BrokerMsg)
|
|
deliver sub = do
|
|
q <- getStoreMsgQueue "SUB" rId
|
|
msg_ <- atomically $ tryPeekMsg q
|
|
deliverMessage "SUB" qr rId sub q msg_
|
|
|
|
getMessage :: QueueRec -> M (Transmission BrokerMsg)
|
|
getMessage qr = time "GET" $ do
|
|
atomically (TM.lookup queueId subscriptions) >>= \case
|
|
Nothing ->
|
|
atomically newSub >>= getMessage_
|
|
Just sub ->
|
|
readTVarIO sub >>= \case
|
|
s@Sub {subThread = ProhibitSub} ->
|
|
atomically (tryTakeTMVar $ delivered s)
|
|
>> getMessage_ s
|
|
-- cannot use GET in the same connection where there is an active subscription
|
|
_ -> pure (corrId, queueId, ERR $ CMD PROHIBITED)
|
|
where
|
|
newSub :: STM Sub
|
|
newSub = do
|
|
s <- newSubscription ProhibitSub
|
|
sub <- newTVar s
|
|
TM.insert queueId sub subscriptions
|
|
pure s
|
|
getMessage_ :: Sub -> M (Transmission BrokerMsg)
|
|
getMessage_ s = do
|
|
q <- getStoreMsgQueue "GET" queueId
|
|
atomically $
|
|
tryPeekMsg q >>= \case
|
|
Just msg ->
|
|
let encMsg = encryptMsg qr msg
|
|
in setDelivered s msg $> (corrId, queueId, MSG encMsg)
|
|
_ -> pure (corrId, queueId, OK)
|
|
|
|
withQueue :: (QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg)
|
|
withQueue action = maybe (pure $ err AUTH) action qr_
|
|
|
|
subscribeNotifications :: M (Transmission BrokerMsg)
|
|
subscribeNotifications = time "NSUB" . atomically $ do
|
|
unlessM (TM.member queueId ntfSubscriptions) $ do
|
|
writeTQueue ntfSubscribedQ (queueId, clnt)
|
|
TM.insert queueId () ntfSubscriptions
|
|
pure ok
|
|
|
|
acknowledgeMsg :: QueueRec -> MsgId -> M (Transmission BrokerMsg)
|
|
acknowledgeMsg qr msgId = time "ACK" $ do
|
|
atomically (TM.lookup queueId subscriptions) >>= \case
|
|
Nothing -> pure $ err NO_MSG
|
|
Just sub ->
|
|
atomically (getDelivered sub) >>= \case
|
|
Just s -> do
|
|
q <- getStoreMsgQueue "ACK" queueId
|
|
case s of
|
|
Sub {subThread = ProhibitSub} -> do
|
|
deletedMsg_ <- atomically $ tryDelMsg q msgId
|
|
mapM_ updateStats deletedMsg_
|
|
pure ok
|
|
_ -> do
|
|
(deletedMsg_, msg_) <- atomically $ tryDelPeekMsg q msgId
|
|
mapM_ updateStats deletedMsg_
|
|
deliverMessage "ACK" qr queueId sub q msg_
|
|
_ -> pure $ err NO_MSG
|
|
where
|
|
getDelivered :: TVar Sub -> STM (Maybe Sub)
|
|
getDelivered sub = do
|
|
s@Sub {delivered} <- readTVar sub
|
|
tryTakeTMVar delivered $>>= \msgId' ->
|
|
if msgId == msgId' || B.null msgId
|
|
then pure $ Just s
|
|
else putTMVar delivered msgId' $> Nothing
|
|
updateStats :: Message -> M ()
|
|
updateStats = \case
|
|
MessageQuota {} -> pure ()
|
|
Message {msgFlags} -> do
|
|
stats <- asks serverStats
|
|
atomically $ modifyTVar' (msgRecv stats) (+ 1)
|
|
atomically $ modifyTVar' (msgCount stats) (subtract 1)
|
|
atomically $ updatePeriodStats (activeQueues stats) queueId
|
|
when (notification msgFlags) $ do
|
|
atomically $ modifyTVar' (msgRecvNtf stats) (+ 1)
|
|
atomically $ updatePeriodStats (activeQueuesNtf stats) queueId
|
|
|
|
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> M (Transmission BrokerMsg)
|
|
sendMessage qr msgFlags msgBody
|
|
| B.length msgBody > maxMessageLength thVersion = do
|
|
stats <- asks serverStats
|
|
atomically $ modifyTVar' (msgSentLarge stats) (+ 1)
|
|
pure $ err LARGE_MSG
|
|
| otherwise = do
|
|
stats <- asks serverStats
|
|
case status qr of
|
|
QueueOff -> do
|
|
atomically $ modifyTVar' (msgSentAuth stats) (+ 1)
|
|
pure $ err AUTH
|
|
QueueActive ->
|
|
case C.maxLenBS msgBody of
|
|
Left _ -> pure $ err LARGE_MSG
|
|
Right body -> do
|
|
msg_ <- time "SEND" $ do
|
|
q <- getStoreMsgQueue "SEND" $ recipientId qr
|
|
expireMessages q
|
|
atomically . writeMsg q =<< mkMessage body
|
|
case msg_ of
|
|
Nothing -> do
|
|
atomically $ modifyTVar' (msgSentQuota stats) (+ 1)
|
|
pure $ err QUOTA
|
|
Just msg -> time "SEND ok" $ do
|
|
when (notification msgFlags) $ do
|
|
atomically . trySendNotification msg =<< asks random
|
|
atomically $ modifyTVar' (msgSentNtf stats) (+ 1)
|
|
atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr)
|
|
atomically $ modifyTVar' (msgSent stats) (+ 1)
|
|
atomically $ modifyTVar' (msgCount stats) (+ 1)
|
|
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
|
|
pure ok
|
|
where
|
|
THandleParams {thVersion} = thParams'
|
|
mkMessage :: C.MaxLenBS MaxMessageLen -> M Message
|
|
mkMessage body = do
|
|
msgId <- randomId =<< asks (msgIdBytes . config)
|
|
msgTs <- liftIO getSystemTime
|
|
pure $ Message msgId msgTs msgFlags body
|
|
|
|
expireMessages :: MsgQueue -> M ()
|
|
expireMessages q = do
|
|
msgExp <- asks $ messageExpiration . config
|
|
old <- liftIO $ mapM expireBeforeEpoch msgExp
|
|
stats <- asks serverStats
|
|
deleted <- atomically $ sum <$> mapM (deleteExpiredMsgs q) old
|
|
atomically $ modifyTVar' (msgExpired stats) (+ deleted)
|
|
|
|
trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
|
|
trySendNotification msg ntfNonceDrg =
|
|
forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
|
|
mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
|
|
|
|
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
|
|
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
|
|
unlessM (isFullTBQueue q) $ case msg of
|
|
Message {msgId, msgTs} -> do
|
|
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg
|
|
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
|
|
_ -> pure ()
|
|
|
|
mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
|
|
mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg = do
|
|
cbNonce <- C.randomCbNonce ntfNonceDrg
|
|
let msgMeta = NMsgMeta {msgId, msgTs}
|
|
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
|
|
pure . (cbNonce,) $ fromRight "" encNMsgMeta
|
|
|
|
processForwardedCommand :: EncFwdTransmission -> M BrokerMsg
|
|
processForwardedCommand (EncFwdTransmission s) = fmap (either ERR id) . runExceptT $ do
|
|
THAuthServer {serverPrivKey, sessSecret'} <- maybe (throwE $ transportErr TENoServerAuth) pure (thAuth thParams')
|
|
sessSecret <- maybe (throwE $ transportErr TENoServerAuth) pure sessSecret'
|
|
let proxyNonce = C.cbNonce $ bs corrId
|
|
s' <- liftEitherWith (const CRYPTO) $ C.cbDecryptNoPad sessSecret proxyNonce s
|
|
FwdTransmission {fwdCorrId, fwdVersion, fwdKey, fwdTransmission = EncTransmission et} <- liftEitherWith (const $ CMD SYNTAX) $ smpDecode s'
|
|
let clientSecret = C.dh' fwdKey serverPrivKey
|
|
clientNonce = C.cbNonce $ bs fwdCorrId
|
|
b <- liftEitherWith (const CRYPTO) $ C.cbDecrypt clientSecret clientNonce et
|
|
let clntTHParams = smpTHParamsSetVersion fwdVersion thParams'
|
|
-- only allowing single forwarded transactions
|
|
t' <- case tParse clntTHParams b of
|
|
t :| [] -> pure $ tDecodeParseValidate clntTHParams t
|
|
_ -> throwE BLOCK
|
|
let clntThAuth = Just $ THAuthServer {serverPrivKey, sessSecret' = Just clientSecret}
|
|
-- process forwarded SEND
|
|
r <-
|
|
lift (rejectOrVerify clntThAuth t') >>= \case
|
|
Left r -> pure r
|
|
Right t''@(_, (corrId', entId', cmd')) -> case cmd' of
|
|
Cmd SSender SEND {} ->
|
|
-- Left will not be returned by processCommand, as only SEND command is allowed
|
|
fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand t'')
|
|
_ ->
|
|
pure (corrId', entId', ERR $ CMD PROHIBITED)
|
|
-- encode response
|
|
r' <- case batchTransmissions (batch clntTHParams) (blockSize clntTHParams) [Right (Nothing, encodeTransmission clntTHParams r)] of
|
|
[] -> throwE INTERNAL -- at least 1 item is guaranteed from NonEmpty/Right
|
|
TBError _ _ : _ -> throwE BLOCK
|
|
TBTransmission b' _ : _ -> pure b'
|
|
TBTransmissions b' _ _ : _ -> pure b'
|
|
-- encrypt to client
|
|
r2 <- liftEitherWith (const BLOCK) $ EncResponse <$> C.cbEncrypt clientSecret (C.reverseNonce clientNonce) r' paddedProxiedMsgLength
|
|
-- encrypt to proxy
|
|
let fr = FwdResponse {fwdCorrId, fwdResponse = r2}
|
|
r3 = EncFwdResponse $ C.cbEncryptNoPad sessSecret (C.reverseNonce proxyNonce) (smpEncode fr)
|
|
stats <- asks serverStats
|
|
atomically $ modifyTVar' (pMsgFwdsRecv stats) (+ 1)
|
|
pure $ RRES r3
|
|
where
|
|
rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
|
|
rejectOrVerify clntThAuth (tAuth, authorized, (corrId', entId', cmdOrError)) =
|
|
case cmdOrError of
|
|
Left e -> pure $ Left (corrId', entId', ERR e)
|
|
Right cmd'@(Cmd SSender SEND {}) -> verified <$> verifyTransmission ((,C.cbNonce (bs corrId')) <$> clntThAuth) tAuth authorized entId' cmd'
|
|
where
|
|
verified = \case
|
|
VRVerified qr -> Right (qr, (corrId', entId', cmd'))
|
|
VRFailed -> Left (corrId', entId', ERR AUTH)
|
|
Right _ -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED)
|
|
|
|
deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> M (Transmission BrokerMsg)
|
|
deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do
|
|
readTVarIO sub >>= \case
|
|
s@Sub {subThread = NoSub} ->
|
|
case msg_ of
|
|
Just msg ->
|
|
let encMsg = encryptMsg qr msg
|
|
in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg)
|
|
_ -> forkSub $> resp
|
|
_ -> pure resp
|
|
where
|
|
resp = (corrId, rId, OK)
|
|
forkSub :: M ()
|
|
forkSub = do
|
|
atomically . modifyTVar' sub $ \s -> s {subThread = SubPending}
|
|
t <- mkWeakThreadId =<< forkIO subscriber
|
|
atomically . modifyTVar' sub $ \case
|
|
s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
|
|
s -> s
|
|
where
|
|
subscriber = do
|
|
labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " subscriber/" <> T.unpack name
|
|
msg <- atomically $ peekMsg q
|
|
time "subscriber" . atomically $ do
|
|
let encMsg = encryptMsg qr msg
|
|
writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
|
|
s <- readTVar sub
|
|
void $ setDelivered s msg
|
|
writeTVar sub $! s {subThread = NoSub}
|
|
|
|
time :: T.Text -> M a -> M a
|
|
time name = timed name queueId
|
|
|
|
encryptMsg :: QueueRec -> Message -> RcvMessage
|
|
encryptMsg qr msg = encrypt . encodeRcvMsgBody $ case msg of
|
|
Message {msgFlags, msgBody} -> RcvMsgBody {msgTs = msgTs', msgFlags, msgBody}
|
|
MessageQuota {} -> RcvMsgQuota msgTs'
|
|
where
|
|
encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage
|
|
encrypt body = RcvMessage msgId' . EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body
|
|
msgId' = messageId msg
|
|
msgTs' = messageTs msg
|
|
|
|
setDelivered :: Sub -> Message -> STM Bool
|
|
setDelivered s msg = tryPutTMVar (delivered s) (messageId msg)
|
|
|
|
getStoreMsgQueue :: T.Text -> RecipientId -> M MsgQueue
|
|
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
|
|
ms <- asks msgStore
|
|
quota <- asks $ msgQueueQuota . config
|
|
atomically $ getMsgQueue ms rId quota
|
|
|
|
delQueueAndMsgs :: QueueStore -> M (Transmission BrokerMsg)
|
|
delQueueAndMsgs st = do
|
|
withLog (`logDeleteQueue` queueId)
|
|
ms <- asks msgStore
|
|
atomically (deleteQueue st queueId $>>= \q -> delMsgQueue ms queueId $> Right q) >>= \case
|
|
Right q -> updateDeletedStats q $> ok
|
|
Left e -> pure $ err e
|
|
|
|
ok :: Transmission BrokerMsg
|
|
ok = (corrId, queueId, OK)
|
|
|
|
err :: ErrorType -> Transmission BrokerMsg
|
|
err e = (corrId, queueId, ERR e)
|
|
|
|
okResp :: Either ErrorType () -> Transmission BrokerMsg
|
|
okResp = either err $ const ok
|
|
|
|
updateDeletedStats :: QueueRec -> M ()
|
|
updateDeletedStats q = do
|
|
stats <- asks serverStats
|
|
let delSel = if isNothing (senderKey q) then qDeletedNew else qDeletedSecured
|
|
atomically $ modifyTVar' (delSel stats) (+ 1)
|
|
atomically $ modifyTVar' (qDeletedAll stats) (+ 1)
|
|
atomically $ modifyTVar' (qCount stats) (subtract 1)
|
|
|
|
withLog :: (StoreLog 'WriteMode -> IO a) -> M ()
|
|
withLog action = do
|
|
env <- ask
|
|
liftIO . mapM_ action $ storeLog (env :: Env)
|
|
|
|
timed :: T.Text -> RecipientId -> M a -> M a
|
|
timed name qId a = do
|
|
t <- liftIO getSystemTime
|
|
r <- a
|
|
t' <- liftIO getSystemTime
|
|
let int = diff t t'
|
|
when (int > sec) . logDebug $ T.unwords [name, tshow $ encode qId, tshow int]
|
|
pure r
|
|
where
|
|
diff t t' = (systemSeconds t' - systemSeconds t) * sec + fromIntegral (systemNanoseconds t' - systemNanoseconds t)
|
|
sec = 1000_000000
|
|
|
|
randomId :: Int -> M ByteString
|
|
randomId n = atomically . C.randomBytes n =<< asks random
|
|
|
|
saveServerMessages :: Bool -> M ()
|
|
saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessages
|
|
where
|
|
saveMessages f = do
|
|
logInfo $ "saving messages to file " <> T.pack f
|
|
ms <- asks msgStore
|
|
liftIO . withFile f WriteMode $ \h ->
|
|
readTVarIO ms >>= mapM_ (saveQueueMsgs ms h) . M.keys
|
|
logInfo "messages saved"
|
|
where
|
|
getMessages = if keepMsgs then snapshotMsgQueue else flushMsgQueue
|
|
saveQueueMsgs ms h rId =
|
|
atomically (getMessages ms rId)
|
|
>>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId)
|
|
|
|
restoreServerMessages :: M Int
|
|
restoreServerMessages =
|
|
asks (storeMsgsFile . config) >>= \case
|
|
Just f -> ifM (doesFileExist f) (restoreMessages f) (pure 0)
|
|
Nothing -> pure 0
|
|
where
|
|
restoreMessages f = do
|
|
logInfo $ "restoring messages from file " <> T.pack f
|
|
ms <- asks msgStore
|
|
quota <- asks $ msgQueueQuota . config
|
|
old_ <- asks (messageExpiration . config) $>>= (liftIO . fmap Just . expireBeforeEpoch)
|
|
runExceptT (liftIO (LB.readFile f) >>= foldM (\expired -> restoreMsg expired ms quota old_) 0 . LB.lines) >>= \case
|
|
Left e -> do
|
|
logError . T.pack $ "error restoring messages: " <> e
|
|
liftIO exitFailure
|
|
Right expired -> do
|
|
renameFile f $ f <> ".bak"
|
|
logInfo "messages restored"
|
|
pure expired
|
|
where
|
|
restoreMsg !expired ms quota old_ s' = do
|
|
MLRv3 rId msg <- liftEither . first (msgErr "parsing") $ strDecode s
|
|
addToMsgQueue rId msg
|
|
where
|
|
s = LB.toStrict s'
|
|
addToMsgQueue rId msg = do
|
|
(isExpired, logFull) <- atomically $ do
|
|
q <- getMsgQueue ms rId quota
|
|
case msg of
|
|
Message {msgTs}
|
|
| maybe True (systemSeconds msgTs >=) old_ -> (False,) . isNothing <$> writeMsg q msg
|
|
| otherwise -> pure (True, False)
|
|
MessageQuota {} -> writeMsg q msg $> (False, False)
|
|
when logFull . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (messageId msg)
|
|
pure $ if isExpired then expired + 1 else expired
|
|
msgErr :: Show e => String -> e -> String
|
|
msgErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s)
|
|
|
|
saveServerStats :: M ()
|
|
saveServerStats =
|
|
asks (serverStatsBackupFile . config)
|
|
>>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f)
|
|
where
|
|
saveStats f stats = do
|
|
logInfo $ "saving server stats to file " <> T.pack f
|
|
B.writeFile f $ strEncode stats
|
|
logInfo "server stats saved"
|
|
|
|
restoreServerStats :: Int -> M ()
|
|
restoreServerStats expiredWhileRestoring = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
|
|
where
|
|
restoreStats f = whenM (doesFileExist f) $ do
|
|
logInfo $ "restoring server stats from file " <> T.pack f
|
|
liftIO (strDecode <$> B.readFile f) >>= \case
|
|
Right d@ServerStatsData {_qCount = statsQCount} -> do
|
|
s <- asks serverStats
|
|
_qCount <- fmap M.size . readTVarIO . queues =<< asks queueStore
|
|
_msgCount <- foldM (\(!n) q -> (n +) <$> readTVarIO (size q)) 0 =<< readTVarIO =<< asks msgStore
|
|
atomically $ setServerStats s d {_qCount, _msgCount, _msgExpired = _msgExpired d + expiredWhileRestoring}
|
|
renameFile f $ f <> ".bak"
|
|
logInfo "server stats restored"
|
|
when (_qCount /= statsQCount) $ logWarn $ "Queue count differs: stats: " <> tshow statsQCount <> ", store: " <> tshow _qCount
|
|
logInfo $ "Restored " <> tshow _msgCount <> " messages in " <> tshow _qCount <> " queues"
|
|
Left e -> do
|
|
logInfo $ "error restoring server stats: " <> T.pack e
|
|
liftIO exitFailure
|