Files
simplexmq/src/Simplex/Messaging/Server.hs
Evgeny Poberezkin 4fae7dcaee server: control port (#804)
* server: control port

* do not remove messages when saving via control port

* remove unused record fields

* fix tests
2023-07-15 13:33:00 +01:00

863 lines
38 KiB
Haskell

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
-- |
-- Module : Simplex.Messaging.Server
-- Copyright : (c) simplex.chat
-- License : AGPL-3
--
-- Maintainer : chat@simplex.chat
-- Stability : experimental
-- Portability : non-portable
--
-- This module defines SMP protocol server with in-memory persistence
-- and optional append only log of SMP queue records.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
module Simplex.Messaging.Server
( runSMPServer,
runSMPServerBlocking,
disconnectTransport,
verifyCmdSignature,
dummyVerifyCmd,
)
where
import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
import Data.Bifunctor (first)
import Data.ByteString.Base64 (encode)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (fromRight, partitionEithers)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (intercalate)
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (isNothing)
import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1)
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
import Data.Time.Format.ISO8601 (iso8601Show)
import Data.Type.Equality
import GHC.TypeLits (KnownNat)
import Network.Socket (ServiceName, Socket, socketToHandle)
import Simplex.Messaging.Agent.Lock
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding (Encoding (smpEncode))
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.Control
import Simplex.Messaging.Server.Env.STM
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.Server.MsgStore.STM
import Simplex.Messaging.Server.QueueStore
import Simplex.Messaging.Server.QueueStore.STM
import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Buffer (trimCR)
import Simplex.Messaging.Transport.Server
import Simplex.Messaging.Util
import System.Exit (exitFailure)
import System.IO (hPutStrLn, hSetNewlineMode, universalNewlineMode)
import System.Mem.Weak (deRefWeak)
import UnliftIO.Concurrent
import UnliftIO.Directory (doesFileExist, renameFile)
import UnliftIO.Exception
import UnliftIO.IO
import UnliftIO.STM
-- | Starts an SMP server with the passed configuration.
--
-- The readiness signal is created here and not observed by the caller;
-- use 'runSMPServerBlocking' to receive it.
--
-- See a full server here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-server/Main.hs
runSMPServer :: ServerConfig -> IO ()
runSMPServer cfg =
  newEmptyTMVarIO >>= \started -> runSMPServerBlocking started cfg
-- | Starts an SMP server with the passed configuration and a caller-supplied signal.
--
-- The passed TMVar is used to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
runSMPServerBlocking :: TMVar Bool -> ServerConfig -> IO ()
runSMPServerBlocking started cfg = do
  -- the environment is created once and shared by all server threads
  env <- newEnv cfg
  runReaderT (smpServer started cfg) env
-- | Monad used by the server code: IO with read access to the server 'Env'.
type M a = ReaderT Env IO a
-- Main server routine: restores saved messages and then saved stats, runs all server
-- threads via raceAny_, and saves the server state when that call returns or throws.
smpServer :: TMVar Bool -> ServerConfig -> M ()
smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
s <- asks server
-- messages are restored before stats, restoreServerStats recomputes the message count from the message store
restoreServerMessages
restoreServerStats
raceAny_
-- threads: SMP subscriptions, notification subscriptions, one transport server per configured
-- transport, and the optional expiration, stats logging and control port threads
( serverThread s subscribedQ subscribers subscriptions cancelSub :
serverThread s ntfSubscribedQ notifiers ntfSubscriptions (\_ -> pure ()) :
map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
)
-- the final save runs under the saving lock and passes False as keepMsgs (see saveServer)
`finally` withLock (savingLock s) "final" (saveServer False)
where
-- Runs the transport server for one configured (port, transport) pair,
-- handing each accepted connection to runClient.
runServer :: (ServiceName, ATransport) -> M ()
runServer (tcpPort, ATransport t) =
  asks tlsServerParams >>= \serverParams ->
    runTransportServer started tcpPort serverParams tCfg (runClient t)
-- Persists the server state: closes the store log (when there is one),
-- then saves messages and stats. keepMsgs is passed on to saveServerMessages.
saveServer :: Bool -> M ()
saveServer keepMsgs = do
  withLog closeStoreLog
  saveServerMessages keepMsgs
  saveServerStats
-- Processes subscription events from one of the server's subscription queues, forever.
--
-- For each (queue, client) event the client is stored in the subscribers map; when the
-- map lookup yields a different, still connected client, that client is sent END for the
-- queue and its subscription (if any) is removed and passed to `unsub`.
serverThread ::
  forall s.
  Server ->
  (Server -> TQueue (QueueId, Client)) ->
  (Server -> TMap QueueId Client) ->
  (Client -> TMap QueueId s) ->
  (s -> IO ()) ->
  M ()
serverThread s subQ subs clientSubs unsub = forever $ do
  atomically updateSubscribers
    $>>= endPreviousSubscriptions
    >>= liftIO . mapM_ unsub
  where
    -- takes the next event and registers its client as the subscriber of the queue
    updateSubscribers :: STM (Maybe (QueueId, Client))
    updateSubscribers = do
      (qId, subscriber) <- readTQueue $ subQ s
      TM.lookupInsert qId subscriber (subs s) $>>= previousToNotify qId subscriber
    -- a client found in the map is notified only if it is another session that is still connected
    previousToNotify :: QueueId -> Client -> Client -> STM (Maybe (QueueId, Client))
    previousToNotify qId subscriber prev
      | sameClientSession subscriber prev = pure Nothing
      | otherwise = do
          stillConnected <- readTVar $ connected prev
          pure $ if stillConnected then Just (qId, prev) else Nothing
    endPreviousSubscriptions :: (QueueId, Client) -> M (Maybe s)
    endPreviousSubscriptions (qId, prev) = do
      -- END is written from a forked thread so that this loop does not wait on the client's queue
      _ <- forkIO . atomically $ writeTBQueue (sndQ prev) [(CorrId "", qId, END)]
      atomically $ TM.lookupDelete qId (clientSubs prev)
-- The message expiration thread, present only when message expiration is configured.
expireMessagesThread_ :: ServerConfig -> [M ()]
expireMessagesThread_ ServerConfig {messageExpiration} =
  case messageExpiration of
    Just msgExp -> [expireMessages msgExp]
    Nothing -> []
-- Periodically deletes expired messages from every queue in the message store.
-- checkInterval is multiplied by 1000000 before being passed to threadDelay'.
expireMessages :: ExpirationConfig -> M ()
expireMessages expCfg = do
  ms <- asks msgStore
  quota <- asks $ msgQueueQuota . config
  let interval = checkInterval expCfg * 1000000
  forever $ do
    liftIO $ threadDelay' interval
    cutoff <- liftIO $ expireBeforeEpoch expCfg
    -- the set of queue IDs is read once per pass
    queueIds <- M.keysSet <$> readTVarIO ms
    forM_ queueIds $ \rId -> do
      -- getting the queue and deleting from it are two separate transactions
      q <- atomically $ getMsgQueue ms rId quota
      atomically $ deleteExpiredMsgs q cutoff
-- The stats logging thread, present only when the stats logging interval is configured.
serverStatsThread_ :: ServerConfig -> [M ()]
serverStatsThread_ ServerConfig {logStatsInterval, logStatsStartTime, serverStatsLogFile} =
  case logStatsInterval of
    Just interval -> [logServerStats logStatsStartTime interval serverStatsLogFile]
    Nothing -> []
-- Appends one comma-separated line of server stats to the stats file every `logInterval` seconds.
--
-- The first line is written `startAt` seconds after the start of the current UTC day,
-- or of the next day if that moment has already passed. The period counters are reset
-- to 0 each time they are logged; qCount and msgCount are only read.
logServerStats :: Int64 -> Int64 -> FilePath -> M ()
logServerStats startAt logInterval statsFilePath = do
  now <- liftIO getCurrentTime
  let secondsToday = fromIntegral (diffTimeToPicoseconds (utctDayTime now) `div` 1000000_000000) :: Int64
      initialDelay = startAt - secondsToday
      -- a negative delay means that the start time is already in the past today
      firstDelay = if initialDelay < 0 then initialDelay + 86400 else initialDelay
  liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
  liftIO $ threadDelay' $ 1000000 * firstDelay
  ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount} <- asks serverStats
  let interval = 1000000 * logInterval
  withFile statsFilePath AppendMode $ \h -> liftIO $ do
    hSetBuffering h LineBuffering
    forever $ do
      ts <- getCurrentTime
      fromTime' <- atomically $ swapTVar fromTime ts
      qCreated' <- resetCounter qCreated
      qSecured' <- resetCounter qSecured
      qDeleted' <- resetCounter qDeleted
      msgSent' <- resetCounter msgSent
      msgRecv' <- resetCounter msgRecv
      ps <- atomically $ periodStatCounts activeQueues ts
      msgSentNtf' <- resetCounter msgSentNtf
      msgRecvNtf' <- resetCounter msgRecvNtf
      psNtf <- atomically $ periodStatCounts activeQueuesNtf ts
      qCount' <- readTVarIO qCount
      msgCount' <- readTVarIO msgCount
      -- the order of the columns is the format of the stats file
      hPutStrLn h $
        intercalate
          ","
          [ iso8601Show $ utctDay fromTime',
            show qCreated',
            show qSecured',
            show qDeleted',
            show msgSent',
            show msgRecv',
            dayCount ps,
            weekCount ps,
            monthCount ps,
            show msgSentNtf',
            show msgRecvNtf',
            dayCount psNtf,
            weekCount psNtf,
            monthCount psNtf,
            show qCount',
            show msgCount'
          ]
      threadDelay' interval
  where
    -- returns the current value of the counter and sets it to 0, in one transaction
    resetCounter :: Num a => TVar a -> IO a
    resetCounter v = atomically $ swapTVar v 0
-- Performs the SMP server handshake on a new connection and, if it succeeds,
-- serves the client; a failed handshake ends the connection handler silently.
runClient :: Transport c => TProxy c -> c -> M ()
runClient _ h = do
  kh <- asks serverIdentity
  smpVRange <- asks $ smpServerVRange . config
  handshake <- liftIO . runExceptT $ smpServerHandshake h kh smpVRange
  case handshake of
    Left _ -> pure ()
    Right th -> runClientTransport th
-- The control port server thread, present only when the control port is configured.
controlPortThread_ :: ServerConfig -> [M ()]
controlPortThread_ ServerConfig {controlPort} =
  case controlPort of
    Just port -> [runCPServer port]
    Nothing -> []
-- Runs the control port TCP server on the given port; each accepted socket is served by runCPClient.
runCPServer :: ServiceName -> M ()
runCPServer port = do
srv <- asks server
cpStarted <- newEmptyTMVarIO
-- the unlift function allows running M actions from the IO connection handler
u <- askUnliftIO
liftIO $ runTCPServer cpStarted port $ runCPClient u srv
where
-- Serves one control port connection: prints the greeting, then processes commands line by line.
runCPClient :: UnliftIO (ReaderT Env IO) -> Server -> Socket -> IO ()
runCPClient u srv sock = do
h <- socketToHandle sock ReadWriteMode
hSetBuffering h LineBuffering
hSetNewlineMode h universalNewlineMode
hPutStrLn h "SMP server control port\n'help' for supported commands"
cpLoop h
where
-- reads and executes commands until quit is received; a line that fails to parse is reported and the loop continues
cpLoop h = do
s <- B.hGetLine h
case strDecode $ trimCR s of
Right CPQuit -> hClose h
Right cmd -> processCP h cmd >> cpLoop h
Left err -> hPutStrLn h ("error: " <> err) >> cpLoop h
processCP h = \case
CPSuspend -> hPutStrLn h "suspend not implemented"
CPResume -> hPutStrLn h "resume not implemented"
CPClients -> hPutStrLn h "clients not implemented"
-- prints the current values of the server stats, one "label: value" line per stat
CPStats -> do
ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, msgSentNtf, msgRecvNtf, qCount, msgCount} <- unliftIO u $ asks serverStats
putStat "fromTime" fromTime
putStat "qCreated" qCreated
putStat "qSecured" qSecured
putStat "qDeleted" qDeleted
putStat "msgSent" msgSent
putStat "msgRecv" msgRecv
putStat "msgSentNtf" msgSentNtf
putStat "msgRecvNtf" msgRecvNtf
putStat "qCount" qCount
putStat "msgCount" msgCount
where
putStat :: Show a => String -> TVar a -> IO ()
putStat label var = readTVarIO var >>= \v -> hPutStrLn h $ label <> ": " <> show v
-- saves the server state under the same lock as the final save in smpServer; True is passed as keepMsgs
CPSave -> withLock (savingLock srv) "control" $ do
hPutStrLn h "saving server state..."
unliftIO u $ saveServer True
hPutStrLn h "server state saved!"
CPHelp -> hPutStrLn h "commands: stats, save, help, quit"
-- quit is handled in cpLoop, where the handle is closed
CPQuit -> pure ()
-- Serves one client connection after a successful handshake.
--
-- Creates the client record and runs its send, command processing and receive threads
-- (plus the inactivity disconnect thread when it is configured) via raceAny_;
-- clientDisconnected always runs afterwards.
runClientTransport :: Transport c => THandle c -> M ()
runClientTransport th@THandle {thVersion, sessionId} = do
  qSize <- asks $ tbqSize . config
  startedAt <- liftIO getSystemTime
  c <- atomically $ newClient qSize thVersion sessionId startedAt
  srv <- asks server
  expCfg_ <- asks $ inactiveClientExpiration . config
  raceAny_ ([liftIO $ send th c, client c srv, receive th c] <> disconnectThread_ c expCfg_)
    `finally` clientDisconnected c
  where
    -- activeAt is passed as the accessor of the client's last activity time
    disconnectThread_ c = \case
      Just expCfg -> [liftIO $ disconnectTransport th c activeAt expCfg]
      Nothing -> []
-- Cleans up after a client connection ends: marks the client as disconnected, cancels
-- its subscriptions, empties its subscription map and removes it from the server's
-- subscribers map for every queue it was subscribed to.
clientDisconnected :: Client -> M ()
clientDisconnected c@Client {subscriptions, connected} = do
  atomically $ writeTVar connected False
  subs <- readTVarIO subscriptions
  liftIO $ mapM_ cancelSub subs
  atomically $ writeTVar subscriptions M.empty
  cs <- asks $ subscribers . server
  atomically . forM_ (M.keys subs) $ \rId -> TM.update removeThisClient rId cs
  where
    -- only this client's own entry is removed, another session subscribed to the queue is kept
    removeThisClient :: Client -> Maybe Client
    removeThisClient c' = if sameClientSession c c' then Nothing else Just c'
-- Two client records belong to the same session when their session IDs are equal.
sameClientSession :: Client -> Client -> Bool
sameClientSession Client {sessionId = a} Client {sessionId = b} = a == b
-- Kills the subscriber thread of the subscription, if it has one and
-- the weak reference to the thread can still be dereferenced.
cancelSub :: TVar Sub -> IO ()
cancelSub sub = do
  s <- readTVarIO sub
  case s of
    Sub {subThread = SubThread t} -> deRefWeak t >>= mapM_ killThread
    _ -> pure ()
-- Receive loop of a client connection.
--
-- Reads a batch of transmissions, updates the client's last activity time and verifies
-- each transmission; error responses go to the send queue and verified commands to the
-- receive queue. Nothing is written to a queue when its part of the batch is empty.
receive :: Transport c => THandle c -> Client -> M ()
receive th Client {rcvQ, sndQ, activeAt} = forever $ do
  ts <- L.toList <$> liftIO (tGet th)
  now <- liftIO getSystemTime
  atomically $ writeTVar activeAt now
  (errs, cmds) <- partitionEithers <$> mapM cmdAction ts
  write sndQ errs
  write rcvQ cmds
  where
    cmdAction :: SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
    cmdAction (sig, signed, (corrId, queueId, cmdOrError)) = case cmdOrError of
      -- a transmission that failed to parse is answered with its error
      Left e -> pure $ Left (corrId, queueId, ERR e)
      Right cmd -> do
        result <- verifyTransmission sig signed queueId cmd
        pure $ case result of
          VRVerified qr -> Right (qr, (corrId, queueId, cmd))
          VRFailed -> Left (corrId, queueId, ERR AUTH)
    write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
-- Send loop of a client connection.
--
-- Takes a batch of responses from the send queue, sorts it so that MSG and NMSG
-- come first, writes it to the transport and updates the client's last activity time.
-- The result of tPut is discarded.
send :: Transport c => THandle c -> Client -> IO ()
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
  ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ
  void . tPut h Nothing $ L.map (\t -> (Nothing, encodeTransmission v sessionId t)) ts
  getSystemTime >>= atomically . writeTVar activeAt
  where
    tOrder :: Transmission BrokerMsg -> Int
    tOrder (_, _, MSG {}) = 0
    tOrder (_, _, NMSG {}) = 0
    tOrder _ = 1
-- | Closes the connection of an inactive client.
--
-- Every check interval compares the client's last activity time, read via the
-- passed accessor, with the expiration epoch and closes the connection when the
-- activity time is older. The loop itself never exits.
disconnectTransport :: Transport c => THandle c -> client -> (client -> TVar SystemTime) -> ExpirationConfig -> IO ()
disconnectTransport THandle {connection} c activeAt expCfg = forever $ do
  threadDelay' interval
  old <- expireBeforeEpoch expCfg
  lastActive <- readTVarIO $ activeAt c
  when (systemSeconds lastActive < old) $ closeConnection connection
  where
    interval = checkInterval expCfg * 1000000
-- | Result of verifying a transmission: 'VRVerified' carries the queue record when one was looked up
-- (it is 'Nothing' for NEW and PING, which are verified without a queue lookup), 'VRFailed' means rejection.
data VerificationResult = VRVerified (Maybe QueueRec) | VRFailed
-- Verifies the signature of a command against the key that authorizes it.
--
-- NEW is verified with the key from the command itself and PING is always accepted;
-- all other commands are verified with the key of the respective party stored in the
-- queue record. When the queue is not found, a dummy verification is still evaluated
-- before failing.
verifyTransmission :: Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> M VerificationResult
verifyTransmission sig_ signed queueId = \case
  Cmd SRecipient (NEW k _ _) -> pure $ Nothing `verified` verifyCmdSignature sig_ signed k
  Cmd SRecipient _ -> verifyCmd SRecipient $ verifyCmdSignature sig_ signed . recipientKey
  Cmd SSender SEND {} -> verifyCmd SSender $ verifyMaybe . senderKey
  Cmd SSender PING -> pure $ VRVerified Nothing
  Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap notifierKey . notifier
  where
    verifyCmd :: SParty p -> (QueueRec -> Bool) -> M VerificationResult
    verifyCmd party check = do
      st <- asks queueStore
      q_ <- atomically $ getQueue st party queueId
      pure $ case q_ of
        Right q -> Just q `verified` check q
        -- the dummy verification is forced with seq so that its cost is paid before failing
        _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
    -- when no key is set, only an unsigned transmission is accepted
    verifyMaybe :: Maybe C.APublicVerifyKey -> Bool
    verifyMaybe = maybe (isNothing sig_) $ verifyCmdSignature sig_ signed
    verified q valid
      | valid = VRVerified q
      | otherwise = VRFailed
-- | Verifies the signature of the signed part of a transmission with the passed key.
--
-- A missing signature fails verification. When the algorithms or the signature sizes
-- of the key and the signature do not match, a dummy verification is evaluated and
-- the result is False.
verifyCmdSignature :: Maybe C.ASignature -> ByteString -> C.APublicVerifyKey -> Bool
verifyCmdSignature sig_ signed key = case sig_ of
  Nothing -> False
  Just sig -> verify key sig
  where
    verify :: C.APublicVerifyKey -> C.ASignature -> Bool
    verify (C.APublicVerifyKey a k) sig@(C.ASignature a' s)
      | Just Refl <- testEquality a a', C.signatureSize k == C.signatureSize s = C.verify' k s signed
      | otherwise = dummyVerifyCmd signed sig `seq` False
-- | Verifies the signature with the dummy public key chosen by dummyPublicKey for this signature.
dummyVerifyCmd :: ByteString -> C.ASignature -> Bool
dummyVerifyCmd signed (C.ASignature _ s) =
  let key = dummyPublicKey s
   in C.verify' key s signed
-- These dummy keys are used with the `dummyVerifyCmd` function to mitigate timing attacks
-- by making the response time the same whether a queue exists or not, for all valid key/signature sizes
-- Chooses the dummy public key for the algorithm of the signature.
dummyPublicKey :: C.Signature a -> C.PublicKey a
dummyPublicKey (C.SignatureEd25519 _) = dummyKeyEd25519
dummyPublicKey (C.SignatureEd448 _) = dummyKeyEd448

-- Fixed Ed25519 public key, written as a string literal.
dummyKeyEd25519 :: C.PublicKey 'C.Ed25519
dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs="

-- Fixed Ed448 public key, written as a string literal.
dummyKeyEd448 :: C.PublicKey 'C.Ed448
dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/XopbOSaq9qyLhrgJWKOLyNrQPNVvpMA"
-- Command processing loop of a client connection: takes a batch of verified commands
-- from the client's receive queue, processes each command and writes the batch of
-- responses to the client's send queue.
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m ()
client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscribedQ, ntfSubscribedQ, notifiers} =
forever $
atomically (readTBQueue rcvQ)
>>= mapM processCommand
>>= atomically . writeTBQueue sndQ
where
-- Dispatches one command to its handler. qr_ is the queue record attached to the command
-- by `receive` (taken from the verification result); the handlers that need it use withQueue.
processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg)
processCommand (qr_, (corrId, queueId, cmd)) = do
st <- asks queueStore
case cmd of
Cmd SSender command ->
case command of
SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody
-- PONG is sent with an empty queue ID
PING -> pure (corrId, "", PONG)
Cmd SNotifier NSUB -> subscribeNotifications
Cmd SRecipient command ->
case command of
NEW rKey dhKey auth ->
ifM
allowNew
(createQueue st rKey dhKey)
(pure (corrId, queueId, ERR AUTH))
where
-- creating queues must be enabled and, when basic auth is configured, the passed auth must be equal to it
allowNew = do
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth
SUB -> withQueue (`subscribeQueue` queueId)
GET -> withQueue getMessage
ACK msgId -> withQueue (`acknowledgeMsg` msgId)
KEY sKey -> secureQueue_ st sKey
NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
NDEL -> deleteQueueNotifier_ st
OFF -> suspendQueue_ st
DEL -> delQueueAndMsgs st
where
-- Handles NEW: generates the server DH key pair and the queue IDs, stores the new
-- queue record and subscribes the client to the new queue.
-- Responds with IDS on success and with ERR otherwise.
createQueue :: QueueStore -> RcvPublicVerifyKey -> RcvPublicDhKey -> m (Transmission BrokerMsg)
createQueue st recipientKey dhKey = time "NEW" $ do
  (rcvPublicDhKey, privDhKey) <- liftIO C.generateKeyPair'
  let rcvDhSecret = C.dh' dhKey privDhKey
      mkIds (rcvId, sndId) = QIK {rcvId, sndId, rcvPublicDhKey}
      mkRec (recipientId, senderId) =
        QueueRec
          { recipientId,
            senderId,
            recipientKey,
            rcvDhSecret,
            senderKey = Nothing,
            notifier = Nothing,
            status = QueueActive
          }
  resp <- addQueueRetry 3 mkIds mkRec
  pure (corrId, queueId, resp)
  where
    -- new random IDs are tried until the counter reaches 0, only a duplicate ID is retried
    addQueueRetry ::
      Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> m BrokerMsg
    addQueueRetry 0 _ _ = pure $ ERR INTERNAL
    addQueueRetry n mkIds mkRec = do
      ids@(rId, _) <- getIds
      -- create QueueRec record with these ids and keys
      let qr = mkRec ids
      added <- atomically $ addQueue st qr
      case added of
        Left DUPLICATE_ -> addQueueRetry (n - 1) mkIds mkRec
        Left e -> pure $ ERR e
        Right _ -> do
          withLog (`logCreateById` rId)
          stats <- asks serverStats
          atomically $ modifyTVar' (qCreated stats) (+ 1)
          atomically $ modifyTVar' (qCount stats) (+ 1)
          subscribeQueue qr rId $> IDS (mkIds ids)
    -- the queue is read back from the store to write it to the store log
    logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
    logCreateById s rId = do
      q_ <- atomically $ getQueue st SRecipient rId
      case q_ of
        Right q -> logCreateQueue s q
        _ -> pure ()
    getIds :: m (RecipientId, SenderId)
    getIds = do
      n <- asks $ queueIdBytes . config
      (,) <$> randomId n <*> randomId n
-- Handles KEY: writes the sender key to the store log, counts the command in the
-- qSecured stat and sets the key in the queue store. Responds with OK or ERR.
secureQueue_ :: QueueStore -> SndPublicVerifyKey -> m (Transmission BrokerMsg)
secureQueue_ st sKey = time "KEY" $ do
  withLog $ \s -> logSecureQueue s queueId sKey
  stats <- asks serverStats
  atomically $ modifyTVar' (qSecured stats) (+ 1)
  result <- atomically $ secureQueue st queueId sKey
  pure (corrId, queueId, either ERR (const OK) result)
-- Handles NKEY: generates the server DH key pair for notifications and adds notifier
-- credentials with a random notifier ID to the queue.
-- Responds with NID on success and with ERR otherwise.
addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg)
addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do
  (rcvPublicDhKey, privDhKey) <- liftIO C.generateKeyPair'
  let rcvNtfDhSecret = C.dh' dhKey privDhKey
  resp <- addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret
  pure (corrId, queueId, resp)
  where
    -- a new random notifier ID is tried until the counter reaches 0, only a duplicate ID is retried
    addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> m BrokerMsg
    addNotifierRetry 0 _ _ = pure $ ERR INTERNAL
    addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do
      idSize <- asks $ queueIdBytes . config
      notifierId <- randomId idSize
      let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
      added <- atomically $ addQueueNotifier st queueId ntfCreds
      case added of
        Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret
        Left e -> pure $ ERR e
        Right _ -> do
          withLog $ \s -> logAddNotifier s queueId ntfCreds
          pure $ NID notifierId rcvPublicDhKey
-- Handles NDEL: writes the deletion to the store log and removes the notifier from the queue.
deleteQueueNotifier_ :: QueueStore -> m (Transmission BrokerMsg)
deleteQueueNotifier_ st =
  withLog (`logDeleteNotifier` queueId)
    >> (okResp <$> atomically (deleteQueueNotifier st queueId))
-- Handles OFF: writes the suspension to the store log and suspends the queue.
suspendQueue_ :: QueueStore -> m (Transmission BrokerMsg)
suspendQueue_ st =
  withLog (`logSuspendQueue` queueId)
    >> (okResp <$> atomically (suspendQueue st queueId))
-- Handles SUB (and is also used after NEW): makes sure that the client has a
-- subscription to the queue and delivers the first available message, if any.
subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg)
subscribeQueue qr rId = do
  existing <- atomically $ TM.lookup rId subscriptions
  case existing of
    Nothing -> newSub >>= deliver
    Just sub -> do
      s <- readTVarIO sub
      case s of
        -- cannot use SUB in the same connection where GET was used
        Sub {subThread = ProhibitSub} -> pure (corrId, rId, ERR $ CMD PROHIBITED)
        -- a repeated SUB clears the delivered message ID so that the message can be delivered again
        _ -> atomically (tryTakeTMVar $ delivered s) >> deliver sub
  where
    -- in one transaction: announces the subscription to the server thread and registers it for the client
    newSub :: m (TVar Sub)
    newSub = time "SUB newSub" . atomically $ do
      writeTQueue subscribedQ (rId, clnt)
      sub <- newSubscription NoSub >>= newTVar
      TM.insert rId sub subscriptions
      pure sub
    deliver :: TVar Sub -> m (Transmission BrokerMsg)
    deliver sub = do
      q <- getStoreMsgQueue "SUB" rId
      msg_ <- atomically $ tryPeekMsg q
      deliverMessage "SUB" qr rId sub q msg_
-- Handles GET: responds with the first message of the queue, or with OK when there is none.
-- The first GET registers a subscription with ProhibitSub, after which SUB cannot be used
-- in this connection for this queue.
getMessage :: QueueRec -> m (Transmission BrokerMsg)
getMessage qr = time "GET" $ do
  existing <- atomically $ TM.lookup queueId subscriptions
  case existing of
    Nothing -> atomically newSub >>= getMessage_
    Just sub -> do
      s <- readTVarIO sub
      case s of
        Sub {subThread = ProhibitSub} -> do
          -- the delivered message ID is cleared before the message is returned again
          _ <- atomically . tryTakeTMVar $ delivered s
          getMessage_ s
        -- cannot use GET in the same connection where there is an active subscription
        _ -> pure (corrId, queueId, ERR $ CMD PROHIBITED)
  where
    newSub :: STM Sub
    newSub = do
      s <- newSubscription ProhibitSub
      newTVar s >>= \sub -> TM.insert queueId sub subscriptions
      pure s
    getMessage_ :: Sub -> m (Transmission BrokerMsg)
    getMessage_ s = do
      q <- getStoreMsgQueue "GET" queueId
      -- peeking the message and marking it as delivered happen in one transaction
      atomically $ do
        msg_ <- tryPeekMsg q
        case msg_ of
          Just msg -> setDelivered s msg $> (corrId, queueId, MSG $ encryptMsg qr msg)
          Nothing -> pure (corrId, queueId, OK)
-- Runs the action with the queue record of the command, responds with ERR AUTH when there is none.
withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
withQueue action = case qr_ of
  Just qr -> action qr
  Nothing -> pure $ err AUTH
-- Handles NSUB: registers the notification subscription of the client, unless
-- it is already registered, and always responds with OK.
subscribeNotifications :: m (Transmission BrokerMsg)
subscribeNotifications = time "NSUB" . atomically $ do
  subscribed <- TM.member queueId ntfSubscriptions
  unless subscribed $ do
    writeTQueue ntfSubscribedQ (queueId, clnt)
    TM.insert queueId () ntfSubscriptions
  pure ok
-- Handles ACK: deletes the acknowledged message from the queue and updates the stats.
--
-- Responds with ERR NO_MSG when the client has no subscription to the queue or when
-- the message ID does not match the delivered one. For a subscription created by GET
-- (ProhibitSub) the response is OK; otherwise the next message, if any, is delivered.
acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg)
acknowledgeMsg qr msgId = time "ACK" $ do
  atomically (TM.lookup queueId subscriptions) >>= \case
    Nothing -> pure $ err NO_MSG
    Just sub ->
      atomically (getDelivered sub) >>= \case
        Just s -> do
          q <- getStoreMsgQueue "ACK" queueId
          case s of
            Sub {subThread = ProhibitSub} -> do
              deletedMsg_ <- atomically $ tryDelMsg q msgId
              mapM_ updateStats deletedMsg_
              pure ok
            _ -> do
              (deletedMsg_, msg_) <- atomically $ tryDelPeekMsg q msgId
              mapM_ updateStats deletedMsg_
              deliverMessage "ACK" qr queueId sub q msg_
        _ -> pure $ err NO_MSG
  where
    -- Takes the delivered message ID from the subscription when it matches the
    -- acknowledged one (an empty ID in ACK matches any); otherwise puts it back.
    getDelivered :: TVar Sub -> STM (Maybe Sub)
    getDelivered sub = do
      s@Sub {delivered} <- readTVar sub
      tryTakeTMVar delivered $>>= \msgId' ->
        if msgId == msgId' || B.null msgId
          then pure $ Just s
          else putTMVar delivered msgId' $> Nothing
    -- Stats are updated only for a deleted regular message, not for a quota marker.
    updateStats :: Message -> m ()
    updateStats = \case
      MessageQuota {} -> pure ()
      Message {msgFlags} -> do
        stats <- asks serverStats
        atomically $ modifyTVar' (msgRecv stats) (+ 1)
        -- msgCount is the number of stored messages, so a deleted message decreases it
        atomically $ modifyTVar' (msgCount stats) (subtract 1)
        atomically $ updatePeriodStats (activeQueues stats) queueId
        when (notification msgFlags) $ do
          atomically $ modifyTVar' (msgRecvNtf stats) (+ 1)
          atomically $ updatePeriodStats (activeQueuesNtf stats) queueId
-- Handles SEND: stores the message in the recipient's message queue.
--
-- Responds with ERR LARGE_MSG when the body is too long, ERR AUTH when the queue is
-- suspended and ERR QUOTA when writeMsg does not return the stored message. On success
-- the stats are updated, a notification is attempted when the message flags request it,
-- and the response is OK.
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
sendMessage qr msgFlags msgBody
  | B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
  | otherwise = case status qr of
      QueueOff -> return $ err AUTH
      QueueActive ->
        case C.maxLenBS msgBody of
          Left _ -> pure $ err LARGE_MSG
          Right body -> do
            msg_ <- time "SEND" $ do
              q <- getStoreMsgQueue "SEND" $ recipientId qr
              -- expired messages are deleted first, before the new message is written
              expireMessages q
              atomically . writeMsg q =<< mkMessage body
            case msg_ of
              Nothing -> pure $ err QUOTA
              Just msg -> time "SEND ok" $ do
                stats <- asks serverStats
                when (notification msgFlags) $ do
                  atomically . trySendNotification msg =<< asks idsDrg
                  atomically $ modifyTVar' (msgSentNtf stats) (+ 1)
                  atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr)
                atomically $ modifyTVar' (msgSent stats) (+ 1)
                -- msgCount is the number of stored messages, so a stored message increases it
                atomically $ modifyTVar' (msgCount stats) (+ 1)
                atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
                pure ok
  where
    -- creates the message with a random ID and the current time as its timestamp
    mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
    mkMessage body = do
      msgId <- randomId =<< asks (msgIdBytes . config)
      msgTs <- liftIO getSystemTime
      pure $ Message msgId msgTs msgFlags body
    -- deletes the expired messages of the queue, only when message expiration is configured
    expireMessages :: MsgQueue -> m ()
    expireMessages q = do
      msgExp <- asks $ messageExpiration . config
      old <- liftIO $ mapM expireBeforeEpoch msgExp
      atomically $ mapM_ (deleteExpiredMsgs q) old
    -- sends the notification only when the queue has a notifier and a client is subscribed to it
    trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
    trySendNotification msg ntfNonceDrg =
      forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
        mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
    -- the notification is dropped when the send queue of the subscribed client is full
    writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
    writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
      unlessM (isFullTBQueue q) $ case msg of
        Message {msgId, msgTs} -> do
          (nmsgNonce, encNMsgMeta) <- mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg
          writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
        _ -> pure ()
    -- encrypts the message ID and timestamp for the notification;
    -- an encryption failure results in empty encrypted metadata
    mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
    mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg = do
      cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
      let msgMeta = NMsgMeta {msgId, msgTs}
          encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
      pure . (cbNonce,) $ fromRight "" encNMsgMeta
-- Delivers a message to a subscribed client in response to a command.
-- When the subscription has no subscriber thread (NoSub): an available message is marked as
-- delivered and returned as MSG, otherwise a subscriber thread is forked and OK is returned.
-- In any other subscription state OK is returned and nothing else is done.
deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do
readTVarIO sub >>= \case
s@Sub {subThread = NoSub} ->
case msg_ of
Just msg ->
let encMsg = encryptMsg qr msg
in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg)
_ -> forkSub $> ok
_ -> pure ok
where
forkSub :: m ()
forkSub = do
-- SubPending is set before forking; the weak thread ID is stored below only if the state
-- is still SubPending, so a state written by the subscriber in the meantime is kept
atomically . modifyTVar' sub $ \s -> s {subThread = SubPending}
t <- mkWeakThreadId =<< forkIO subscriber
atomically . modifyTVar' sub $ \case
s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
s -> s
where
subscriber = do
-- NOTE(review): peekMsg presumably blocks until the queue has a message - confirm
msg <- atomically $ peekMsg q
-- in one transaction: the message is written to the client's send queue, marked as
-- delivered and the subscription state is reset to NoSub
time "subscriber" . atomically $ do
let encMsg = encryptMsg qr msg
writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
s <- readTVar sub
void $ setDelivered s msg
writeTVar sub $! s {subThread = NoSub}
-- Runs the action via `timed`, with the queue ID of the current command.
time :: T.Text -> m a -> m a
time name action = timed name queueId action
-- Encrypts a stored message for delivery to the recipient.
--
-- For transport versions 1 and 2 only the message body is encrypted; for other
-- versions the encoded body also includes the timestamp and flags. A quota marker
-- is sent as an encrypted quota message with its timestamp and no flags.
encryptMsg :: QueueRec -> Message -> RcvMessage
encryptMsg qr msg = case msg of
  Message {msgFlags, msgBody} ->
    if thVersion == 1 || thVersion == 2
      then encrypt msgFlags msgBody
      else encrypt msgFlags $ encodeRcvMsgBody RcvMsgBody {msgTs = msgTs', msgFlags, msgBody}
  MessageQuota {} -> encrypt noMsgFlags $ encodeRcvMsgBody (RcvMsgQuota msgTs')
  where
    -- the nonce is made from the message ID
    encrypt :: KnownNat i => MsgFlags -> C.MaxLenBS i -> RcvMessage
    encrypt flags body =
      RcvMessage msgId' msgTs' flags . EncRcvMsgBody $
        C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body
    msgId' = msgId (msg :: Message)
    msgTs' = msgTs (msg :: Message)
-- Records the ID of the message as delivered for the subscription.
-- The result is that of tryPutTMVar, so it is False when an ID is already recorded.
setDelivered :: Sub -> Message -> STM Bool
setDelivered s msg = tryPutTMVar (delivered s) deliveredId
  where
    deliveredId = msgId (msg :: Message)
-- Gets the message queue of the recipient from the message store,
-- passing the configured queue quota to getMsgQueue.
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
getStoreMsgQueue name rId =
  time (name <> " getMsgQueue") $
    asks msgStore >>= \ms ->
      asks (msgQueueQuota . config) >>= \quota ->
        atomically $ getMsgQueue ms rId quota
-- Handles DEL: writes the deletion to the store log, then deletes the queue record and
-- its message queue in one transaction. Responds with OK, or with ERR when the queue
-- could not be deleted.
--
-- The stats are updated only after a successful deletion, so that a failed DEL
-- does not change the queue count.
delQueueAndMsgs :: QueueStore -> m (Transmission BrokerMsg)
delQueueAndMsgs st = do
  withLog (`logDeleteQueue` queueId)
  ms <- asks msgStore
  deleted <-
    atomically $
      deleteQueue st queueId >>= \case
        Left e -> pure $ Left e
        Right _ -> delMsgQueue ms queueId $> Right ()
  case deleted of
    Left e -> pure $ err e
    Right () -> do
      stats <- asks serverStats
      atomically $ modifyTVar' (qDeleted stats) (+ 1)
      atomically $ modifyTVar' (qCount stats) (subtract 1)
      pure ok
-- OK response to the current command.
ok :: Transmission BrokerMsg
ok = (corrId, queueId, OK)

-- ERR response to the current command.
err :: ErrorType -> Transmission BrokerMsg
err e = (corrId, queueId, ERR e)

-- Converts the result of a store operation to the response to the current command.
okResp :: Either ErrorType () -> Transmission BrokerMsg
okResp r = case r of
  Left e -> err e
  Right _ -> ok
-- Runs the action with the store log from the environment, when there is one;
-- the result of the action is discarded.
withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m ()
withLog action = do
  env <- ask
  liftIO $ forM_ (storeLog (env :: Env)) action
-- Runs the action and writes a debug log entry with the operation name, the
-- base64-encoded queue ID and the duration in nanoseconds when the action took
-- longer than one second. Returns the result of the action.
timed :: MonadUnliftIO m => T.Text -> RecipientId -> m a -> m a
timed name qId a = do
  t <- liftIO getSystemTime
  r <- a
  t' <- liftIO getSystemTime
  let int = diff t t'
  when (int > sec) . logDebug $ T.unwords [name, tshow $ encode qId, tshow int]
  pure r
  where
    diff :: SystemTime -> SystemTime -> Int64
    diff t t' = (systemSeconds t' - systemSeconds t) * sec + (nanos t' - nanos t)
    -- systemNanoseconds is an unsigned Word32: it is converted before subtracting, as the
    -- unsigned difference wraps around when the end time has fewer nanoseconds than the start time
    nanos :: SystemTime -> Int64
    nanos = fromIntegral . systemNanoseconds
    -- nanoseconds in one second
    sec :: Int64
    sec = 1000_000000
-- Generates n pseudo-random bytes using the generator from the environment.
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ByteString
randomId n = asks idsDrg >>= atomically . C.pseudoRandomBytes n
-- Writes the messages of all queues to the configured messages file, one encoded
-- message per line; does nothing when the file is not configured.
-- With keepMsgs the queues are read with snapshotMsgQueue, otherwise with flushMsgQueue.
saveServerMessages :: (MonadUnliftIO m, MonadReader Env m) => Bool -> m ()
saveServerMessages keepMsgs = do
  file_ <- asks $ storeMsgsFile . config
  forM_ file_ saveMessages
  where
    saveMessages f = do
      logInfo $ "saving messages to file " <> T.pack f
      ms <- asks msgStore
      liftIO . withFile f WriteMode $ \h -> do
        msgQueues <- readTVarIO ms
        forM_ (M.keys msgQueues) $ saveQueueMsgs ms h
      logInfo "messages saved"
      where
        getMessages = if keepMsgs then snapshotMsgQueue else flushMsgQueue
        -- the messages of each queue are read in a separate transaction
        saveQueueMsgs ms h rId = do
          msgs <- atomically $ getMessages ms rId
          forM_ msgs $ B.hPutStrLn h . strEncode . MLRv3 rId
-- Restores messages from the configured messages file, if it is configured and exists.
-- Any parsing or restoring error is logged and the process exits with failure;
-- after a successful restore the file is renamed with the ".bak" extension.
restoreServerMessages :: forall m. (MonadUnliftIO m, MonadReader Env m) => m ()
restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages
where
restoreMessages f = whenM (doesFileExist f) $ do
logInfo $ "restoring messages from file " <> T.pack f
st <- asks queueStore
ms <- asks msgStore
quota <- asks $ msgQueueQuota . config
-- the expiration epoch is computed once, only when message expiration is configured
old_ <- asks (messageExpiration . config) $>>= (liftIO . fmap Just . expireBeforeEpoch)
-- the whole file is read and restored line by line, the first error stops the restore
runExceptT (liftIO (B.readFile f) >>= mapM_ (restoreMsg st ms quota old_) . B.lines) >>= \case
Left e -> do
logError . T.pack $ "error restoring messages: " <> e
liftIO exitFailure
_ -> do
renameFile f $ f <> ".bak"
logInfo "messages restored"
where
-- Restores one line of the file: v3 records contain the message itself,
-- v1 records contain an encrypted message that is decrypted with the queue's DH secret.
restoreMsg st ms quota old_ s = do
r <- liftEither . first (msgErr "parsing") $ strDecode s
case r of
MLRv3 rId msg -> addToMsgQueue rId msg
MLRv1 rId encMsg -> do
qr <- liftEitherError (msgErr "queue unknown") . atomically $ getQueue st SRecipient rId
msg' <- updateMsgV1toV3 qr encMsg
addToMsgQueue rId msg'
where
-- Writes the message to its queue; messages older than the expiration epoch are skipped.
-- A message that could not be written is logged as not restored, this is not a restore error.
addToMsgQueue rId msg = do
logFull <- atomically $ do
q <- getMsgQueue ms rId quota
case msg of
Message {msgTs}
| maybe True (systemSeconds msgTs >=) old_ -> isNothing <$> writeMsg q msg
| otherwise -> pure False
MessageQuota {} -> writeMsg q msg $> False
when logFull . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message))
-- the nonce for decryption is made from the message ID
updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do
let nonce = C.cbNonce msgId
msgBody <- liftEither . first (msgErr "v1 message decryption") $ C.maxLenBS =<< C.cbDecrypt rcvDhSecret nonce body
pure Message {msgId, msgTs, msgFlags, msgBody}
-- the error message includes at most the first 100 characters of the line
msgErr :: Show e => String -> e -> String
msgErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s)
-- Writes the encoded server stats to the configured backup file;
-- does nothing when the file is not configured.
saveServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
saveServerStats = do
  file_ <- asks $ serverStatsBackupFile . config
  forM_ file_ $ \f -> do
    stats <- asks serverStats
    -- the stats data is read in one transaction
    d <- atomically $ getServerStatsData stats
    liftIO $ do
      logInfo $ "saving server stats to file " <> T.pack f
      B.writeFile f $ strEncode d
      logInfo "server stats saved"
-- Restores the server stats from the configured backup file, if it is configured and exists.
--
-- The queue and message counts are not taken from the file: they are computed from the
-- current queue store and message store. After a successful restore the file is renamed
-- with the ".bak" extension; when the file cannot be decoded, the error is logged and the
-- process exits with failure.
restoreServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
restoreServerStats = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
  where
    restoreStats f = whenM (doesFileExist f) $ do
      logInfo $ "restoring server stats from file " <> T.pack f
      liftIO (strDecode <$> B.readFile f) >>= \case
        Right d -> do
          s <- asks serverStats
          _qCount <- fmap M.size . readTVarIO . queues =<< asks queueStore
          _msgCount <- foldM (\n q -> (n +) <$> readTVarIO (size q)) 0 =<< readTVarIO =<< asks msgStore
          atomically $ setServerStats s d {_qCount, _msgCount}
          renameFile f $ f <> ".bak"
          logInfo "server stats restored"
        Left e -> do
          -- this failure terminates the server, so it is logged as an error
          -- (as the failure to restore messages is)
          logError $ "error restoring server stats: " <> T.pack e
          liftIO exitFailure