{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

-- |
-- Module      : Simplex.Messaging.Server
-- Copyright   : (c) simplex.chat
-- License     : AGPL-3
--
-- Maintainer  : chat@simplex.chat
-- Stability   : experimental
-- Portability : non-portable
--
-- This module defines SMP protocol server with in-memory persistence
-- and optional append only log of SMP queue records.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
module Simplex.Messaging.Server
  ( runSMPServer,
    runSMPServerBlocking,
    disconnectTransport,
    verifyCmdSignature,
    dummyVerifyCmd,
  )
where

import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
import Data.Bifunctor (first)
import Data.ByteString.Base64 (encode)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (fromRight, partitionEithers)
import Data.Functor (($>))
import Data.List (intercalate)
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (isNothing)
import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1)
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
import Data.Time.Format.ISO8601 (iso8601Show)
import Data.Type.Equality
import GHC.TypeLits (KnownNat)
import Network.Socket (ServiceName)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding (Encoding (smpEncode))
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.Env.STM
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.Server.MsgStore.STM (MsgQueue)
import Simplex.Messaging.Server.QueueStore
import Simplex.Messaging.Server.QueueStore.STM (QueueStore)
import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Server
import Simplex.Messaging.Util
import System.Exit (exitFailure)
import System.IO (hPutStrLn)
import System.Mem.Weak (deRefWeak)
import UnliftIO.Concurrent
import UnliftIO.Directory (doesFileExist, renameFile)
import UnliftIO.Exception
import UnliftIO.IO
import UnliftIO.STM

-- | Runs an SMP server using passed configuration.
--
-- See a full server here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-server/Main.hs
runSMPServer :: (MonadRandom m, MonadUnliftIO m) => ServerConfig -> m ()
runSMPServer cfg = do
  started <- newEmptyTMVarIO
  runSMPServerBlocking started cfg

-- | Runs an SMP server using passed configuration with signalling.
--
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
runSMPServerBlocking :: (MonadRandom m, MonadUnliftIO m) => TMVar Bool -> ServerConfig -> m ()
runSMPServerBlocking started cfg = newEnv cfg >>= runReaderT (smpServer started)

-- | Main server loop.
--
-- Restores persisted stats and messages, then races these threads until one of them exits:
--
-- * two subscription-arbitration threads (queue subscribers and notifiers),
-- * one transport listener per configured transport,
-- * optional message-expiration and stats-logging threads.
--
-- On exit, it closes the store log and saves messages and stats.
smpServer :: forall m. (MonadUnliftIO m, MonadReader Env m) => TMVar Bool -> m ()
smpServer started = do
  s <- asks server
  cfg@ServerConfig {transports} <- asks config
  restoreServerStats
  restoreServerMessages
  raceAny_
    ( serverThread s subscribedQ subscribers subscriptions cancelSub :
      serverThread s ntfSubscribedQ notifiers ntfSubscriptions (\_ -> pure ()) :
      map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg
    )
    `finally` (withLog closeStoreLog >> saveServerMessages >> saveServerStats)
  where
    -- Accept connections on one transport, running a client session for each.
    runServer :: (ServiceName, ATransport) -> m ()
    runServer (tcpPort, ATransport t) = do
      serverParams <- asks tlsServerParams
      runTransportServer started tcpPort serverParams (runClient t)

    -- Reads (queueId, client) pairs from the server subscription queue and records the client
    -- as the current subscriber. If a different, still-connected session was subscribed before,
    -- that session is sent END and its per-client subscription is removed and passed to `unsub`.
    serverThread ::
      forall s.
      Server ->
      (Server -> TBQueue (QueueId, Client)) ->
      (Server -> TMap QueueId Client) ->
      (Client -> TMap QueueId s) ->
      (s -> m ()) ->
      m ()
    serverThread s subQ subs clientSubs unsub = forever $ do
      atomically updateSubscribers $>>= endPreviousSubscriptions >>= mapM_ unsub
      where
        updateSubscribers :: STM (Maybe (QueueId, Client))
        updateSubscribers = do
          (qId, clnt) <- readTBQueue $ subQ s
          let clientToBeNotified c' =
                if sameClientSession clnt c'
                  then pure Nothing
                  else do
                    yes <- readTVar $ connected c'
                    pure $ if yes then Just (qId, c') else Nothing
          TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified
        endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s)
        endPreviousSubscriptions (qId, c) = do
          -- written from a separate thread so a full sndQ of the old client does not block this loop
          void . forkIO . atomically $ writeTBQueue (sndQ c) [(CorrId "", qId, END)]
          atomically $ TM.lookupDelete qId (clientSubs c)

    expireMessagesThread_ :: ServerConfig -> [m ()]
    expireMessagesThread_ ServerConfig {messageExpiration = Just msgExp} = [expireMessages msgExp]
    expireMessagesThread_ _ = []

    -- Every `checkInterval` seconds, delete expired messages from all queues in the message store.
    expireMessages :: ExpirationConfig -> m ()
    expireMessages expCfg = do
      ms <- asks msgStore
      quota <- asks $ msgQueueQuota . config
      let interval = checkInterval expCfg * 1000000
      forever $ do
        threadDelay interval
        old <- liftIO $ expireBeforeEpoch expCfg
        rIds <- M.keysSet <$> readTVarIO ms
        forM_ rIds $ \rId ->
          atomically (getMsgQueue ms rId quota) >>= atomically . (`deleteExpiredMsgs` old)

    serverStatsThread_ :: ServerConfig -> [m ()]
    serverStatsThread_ ServerConfig {logStatsInterval = Just interval, logStatsStartTime, serverStatsLogFile} =
      [logServerStats logStatsStartTime interval serverStatsLogFile]
    serverStatsThread_ _ = []

    -- Waits until `startAt` seconds past midnight (UTC, as returned by getCurrentTime),
    -- then appends one CSV line per `logInterval` seconds with the counters accumulated since
    -- the previous line (the counters are reset to 0 on each write).
    logServerStats :: Int -> Int -> FilePath -> m ()
    logServerStats startAt logInterval statsFilePath = do
      initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
      liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
      threadDelay $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
      ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues} <- asks serverStats
      let interval = 1000000 * logInterval
      withFile statsFilePath AppendMode $ \h -> liftIO $ do
        hSetBuffering h LineBuffering
        forever $ do
          ts <- getCurrentTime
          fromTime' <- atomically $ swapTVar fromTime ts
          qCreated' <- atomically $ swapTVar qCreated 0
          qSecured' <- atomically $ swapTVar qSecured 0
          qDeleted' <- atomically $ swapTVar qDeleted 0
          msgSent' <- atomically $ swapTVar msgSent 0
          msgRecv' <- atomically $ swapTVar msgRecv 0
          ps <- atomically $ periodStatCounts activeQueues ts
          hPutStrLn h $ intercalate "," [iso8601Show $ utctDay fromTime', show qCreated', show qSecured', show qDeleted', show msgSent', show msgRecv', dayCount ps, weekCount ps, monthCount ps]
          threadDelay interval

    -- Performs the SMP handshake; on failure the connection is silently dropped.
    runClient :: Transport c => TProxy c -> c -> m ()
    runClient _ h = do
      kh <- asks serverIdentity
      smpVRange <- asks $ smpServerVRange . config
      liftIO (runExceptT $ smpServerHandshake h kh smpVRange) >>= \case
        Right th -> runClientTransport th
        Left _ -> pure ()

-- | Runs one client session: sender, command processor and receiver threads
-- (plus an inactivity-disconnect thread when configured), cleaning up when any of them exits.
runClientTransport :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> m ()
runClientTransport th@THandle {thVersion, sessionId} = do
  q <- asks $ tbqSize . config
  ts <- liftIO getSystemTime
  c <- atomically $ newClient q thVersion sessionId ts
  s <- asks server
  expCfg <- asks $ inactiveClientExpiration . config
  raceAny_ ([send th c, client c s, receive th c] <> disconnectThread_ c expCfg)
    `finally` clientDisconnected c
  where
    disconnectThread_ c (Just expCfg) = [disconnectTransport th c activeAt expCfg]
    disconnectThread_ _ _ = []

-- | Cleans up after a client session ends: marks it disconnected, kills its subscription
-- threads, and removes this session from the server-wide subscriber and notifier maps
-- (entries already taken over by another session are left intact).
clientDisconnected :: (MonadUnliftIO m, MonadReader Env m) => Client -> m ()
clientDisconnected c@Client {subscriptions, ntfSubscriptions, connected} = do
  atomically $ writeTVar connected False
  subs <- atomically $ swapTVar subscriptions M.empty
  mapM_ cancelSub subs
  -- notifier subscriptions must be released too, otherwise `notifiers` keeps pointing
  -- at this dead client and notifications are queued to its sndQ
  ntfSubs <- atomically $ swapTVar ntfSubscriptions M.empty
  srv <- asks server
  atomically $ do
    removeSession (subscribers srv) (M.keys subs)
    removeSession (notifiers srv) (M.keys ntfSubs)
  where
    removeSession :: TMap QueueId Client -> [QueueId] -> STM ()
    removeSession cs = mapM_ (\qId -> TM.update deleteCurrentClient qId cs)
    deleteCurrentClient :: Client -> Maybe Client
    deleteCurrentClient c'
      | sameClientSession c c' = Nothing
      | otherwise = Just c'

-- | Two clients belong to the same session when their session IDs are equal.
sameClientSession :: Client -> Client -> Bool
sameClientSession Client {sessionId} Client {sessionId = s'} = sessionId == s'

-- | Kills the delivery thread of a subscription, if it has one and it is still alive.
cancelSub :: MonadUnliftIO m => TVar Sub -> m ()
cancelSub sub =
  readTVarIO sub >>= \case
    Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
    _ -> return ()

-- | Reads transmission batches from the transport, verifies them, and routes
-- failed/unparsable ones straight to sndQ as errors and verified ones to rcvQ.
receive :: forall c m. (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
receive th Client {rcvQ, sndQ, activeAt} = forever $ do
  ts <- L.toList <$> tGet th
  atomically . writeTVar activeAt =<< liftIO getSystemTime
  as <- partitionEithers <$> mapM cmdAction ts
  write sndQ $ fst as
  write rcvQ $ snd as
  where
    cmdAction :: SignedTransmission Cmd -> m (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
    cmdAction (sig, signed, (corrId, queueId, cmdOrError)) =
      case cmdOrError of
        Left e -> pure $ Left (corrId, queueId, ERR e)
        Right cmd -> verified <$> verifyTransmission sig signed queueId cmd
          where
            verified = \case
              VRVerified qr -> Right (qr, (corrId, queueId, cmd))
              VRFailed -> Left (corrId, queueId, ERR AUTH)
    -- empty batches are not written
    write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty

-- | Sends response batches from sndQ to the transport, MSG responses first.
send :: (Transport c, MonadUnliftIO m) => THandle c -> Client -> m ()
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
  ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ
  -- TODO the line below can return Lefts, but we ignore it and do not disconnect the client
  void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts
  atomically . writeTVar activeAt =<< liftIO getSystemTime
  where
    tOrder :: Transmission BrokerMsg -> Int
    tOrder (_, _, cmd) = case cmd of
      MSG {} -> 0
      _ -> 1

-- | Every `checkInterval` seconds, closes the connection if the client's last activity
-- time is earlier than the expiration threshold.
disconnectTransport :: (Transport c, MonadUnliftIO m) => THandle c -> client -> (client -> TVar SystemTime) -> ExpirationConfig -> m ()
disconnectTransport THandle {connection} c activeAt expCfg = do
  let interval = checkInterval expCfg * 1000000
  forever . liftIO $ do
    threadDelay interval
    old <- expireBeforeEpoch expCfg
    ts <- readTVarIO $ activeAt c
    when (systemSeconds ts < old) $ closeConnection connection

data VerificationResult = VRVerified (Maybe QueueRec) | VRFailed

-- | Verifies the signature of a command against the key stored for the party it is addressed to.
-- When the queue is not found, a dummy verification is still performed to mitigate timing attacks.
verifyTransmission :: forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m VerificationResult
verifyTransmission sig_ signed queueId cmd =
  case cmd of
    Cmd SRecipient (NEW k _) -> pure $ Nothing `verified` verifyCmdSignature sig_ signed k
    Cmd SRecipient _ -> verifyCmd SRecipient $ verifyCmdSignature sig_ signed . recipientKey
    Cmd SSender SEND {} -> verifyCmd SSender $ verifyMaybe . senderKey
    Cmd SSender PING -> pure $ VRVerified Nothing
    Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap notifierKey . notifier
  where
    verifyCmd :: SParty p -> (QueueRec -> Bool) -> m VerificationResult
    verifyCmd party f = do
      st <- asks queueStore
      q_ <- atomically (getQueue st party queueId)
      pure $ case q_ of
        Right q -> Just q `verified` f q
        _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
    -- no key: the command must be unsigned; key present: the signature must verify
    verifyMaybe :: Maybe C.APublicVerifyKey -> Bool
    verifyMaybe = maybe (isNothing sig_) $ verifyCmdSignature sig_ signed
    verified q cond = if cond then VRVerified q else VRFailed

-- | Verifies a signature with the given key; a missing signature, or one whose algorithm
-- or size does not match the key, fails (after a dummy verification for equal timing).
verifyCmdSignature :: Maybe C.ASignature -> ByteString -> C.APublicVerifyKey -> Bool
verifyCmdSignature sig_ signed key = maybe False (verify key) sig_
  where
    verify :: C.APublicVerifyKey -> C.ASignature -> Bool
    verify (C.APublicVerifyKey a k) sig@(C.ASignature a' s) =
      case (testEquality a a', C.signatureSize k == C.signatureSize s) of
        (Just Refl, True) -> C.verify' k s signed
        _ -> dummyVerifyCmd signed sig `seq` False

-- | Runs a signature verification against a fixed dummy key of the signature's algorithm.
dummyVerifyCmd :: ByteString -> C.ASignature -> Bool
dummyVerifyCmd signed (C.ASignature _ s) = C.verify' (dummyPublicKey s) s signed

-- These dummy keys are used with the `dummyVerifyCmd` function to mitigate timing attacks
-- by having the same time of the response whether a queue exists or not, for all valid key/signature sizes
dummyPublicKey :: C.Signature a -> C.PublicKey a
dummyPublicKey = \case
  C.SignatureEd25519 _ -> dummyKeyEd25519
  C.SignatureEd448 _ -> dummyKeyEd448

dummyKeyEd25519 :: C.PublicKey 'C.Ed25519
dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs="

dummyKeyEd448 :: C.PublicKey 'C.Ed448
dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/XopbOSaq9qyLhrgJWKOLyNrQPNVvpMA"

-- | Command processor of one client session: reads verified command batches from rcvQ,
-- executes each command and writes the batch of responses to sndQ.
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m ()
client clnt@Client {thVersion, sessionId, subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscribedQ, ntfSubscribedQ, notifiers} =
  forever $ atomically (readTBQueue rcvQ) >>= mapM processCommand >>= atomically . writeTBQueue sndQ
  where
    processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg)
    processCommand (qr_, (corrId, queueId, cmd)) = do
      st <- asks queueStore
      case cmd of
        Cmd SSender command ->
          case command of
            SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody
            PING -> pure (corrId, "", PONG)
        Cmd SNotifier NSUB -> subscribeNotifications
        Cmd SRecipient command ->
          case command of
            NEW rKey dhKey ->
              ifM
                (asks $ allowNewQueues . config)
                (createQueue st rKey dhKey)
                (pure (corrId, queueId, ERR AUTH))
            SUB -> withQueue (`subscribeQueue` queueId)
            GET -> withQueue getMessage
            ACK msgId -> withQueue (`acknowledgeMsg` msgId)
            KEY sKey -> secureQueue_ st sKey
            NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
            NDEL -> deleteQueueNotifier_ st
            OFF -> suspendQueue_ st
            DEL -> delQueueAndMsgs st
      where
        -- Creates a queue with random IDs (retrying up to 3 times on ID collision),
        -- logs it, counts it and subscribes this client to it.
        createQueue :: QueueStore -> RcvPublicVerifyKey -> RcvPublicDhKey -> m (Transmission BrokerMsg)
        createQueue st recipientKey dhKey = time "NEW" $ do
          (rcvPublicDhKey, privDhKey) <- liftIO C.generateKeyPair'
          let rcvDhSecret = C.dh' dhKey privDhKey
              qik (rcvId, sndId) = QIK {rcvId, sndId, rcvPublicDhKey}
              qRec (recipientId, senderId) =
                QueueRec
                  { recipientId,
                    senderId,
                    recipientKey,
                    rcvDhSecret,
                    senderKey = Nothing,
                    notifier = Nothing,
                    status = QueueActive
                  }
          (corrId,queueId,) <$> addQueueRetry 3 qik qRec
          where
            addQueueRetry ::
              Int ->
              ((RecipientId, SenderId) -> QueueIdsKeys) ->
              ((RecipientId, SenderId) -> QueueRec) ->
              m BrokerMsg
            addQueueRetry 0 _ _ = pure $ ERR INTERNAL
            addQueueRetry n qik qRec = do
              ids@(rId, _) <- getIds
              -- create QueueRec record with these ids and keys
              let qr = qRec ids
              atomically (addQueue st qr) >>= \case
                Left DUPLICATE_ -> addQueueRetry (n - 1) qik qRec
                Left e -> pure $ ERR e
                Right _ -> do
                  withLog (`logCreateById` rId)
                  stats <- asks serverStats
                  atomically $ modifyTVar' (qCreated stats) (+ 1)
                  subscribeQueue qr rId $> IDS (qik ids)

            logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
            logCreateById s rId =
              atomically (getQueue st SRecipient rId) >>= \case
                Right q -> logCreateQueue s q
                _ -> pure ()

            getIds :: m (RecipientId, SenderId)
            getIds = do
              n <- asks $ queueIdBytes . config
              liftM2 (,) (randomId n) (randomId n)

        -- Sets the sender key; the stats counter is only incremented when the store accepts it.
        secureQueue_ :: QueueStore -> SndPublicVerifyKey -> m (Transmission BrokerMsg)
        secureQueue_ st sKey = time "KEY" $ do
          withLog $ \s -> logSecureQueue s queueId sKey
          atomically (secureQueue st queueId sKey) >>= \case
            Left e -> pure $ err e
            Right _ -> do
              stats <- asks serverStats
              atomically $ modifyTVar' (qSecured stats) (+ 1)
              pure ok

        -- Adds notifier credentials with a random notifier ID (retrying up to 3 times on collision).
        addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg)
        addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do
          (rcvPublicDhKey, privDhKey) <- liftIO C.generateKeyPair'
          let rcvNtfDhSecret = C.dh' dhKey privDhKey
          (corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret
          where
            addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> m BrokerMsg
            addNotifierRetry 0 _ _ = pure $ ERR INTERNAL
            addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do
              notifierId <- randomId =<< asks (queueIdBytes . config)
              let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
              atomically (addQueueNotifier st queueId ntfCreds) >>= \case
                Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret
                Left e -> pure $ ERR e
                Right _ -> do
                  withLog $ \s -> logAddNotifier s queueId ntfCreds
                  pure $ NID notifierId rcvPublicDhKey

        deleteQueueNotifier_ :: QueueStore -> m (Transmission BrokerMsg)
        deleteQueueNotifier_ st = do
          withLog (`logDeleteNotifier` queueId)
          okResp <$> atomically (deleteQueueNotifier st queueId)

        suspendQueue_ :: QueueStore -> m (Transmission BrokerMsg)
        suspendQueue_ st = do
          withLog (`logSuspendQueue` queueId)
          okResp <$> atomically (suspendQueue st queueId)

        subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg)
        subscribeQueue qr rId =
          atomically (TM.lookup rId subscriptions) >>= \case
            Nothing -> newSub >>= deliver
            Just sub ->
              readTVarIO sub >>= \case
                Sub {subThread = ProhibitSub} ->
                  -- cannot use SUB in the same connection where GET was used
                  pure (corrId, rId, ERR $ CMD PROHIBITED)
                s -> atomically (tryTakeTMVar $ delivered s) >> deliver sub
          where
            newSub :: m (TVar Sub)
            newSub = time "SUB newSub" . atomically $ do
              writeTBQueue subscribedQ (rId, clnt)
              sub <- newTVar =<< newSubscription NoSub
              TM.insert rId sub subscriptions
              pure sub
            deliver :: TVar Sub -> m (Transmission BrokerMsg)
            deliver sub = do
              q <- getStoreMsgQueue "SUB" rId
              msg_ <- atomically $ tryPeekMsg q
              deliverMessage "SUB" qr rId sub q msg_

        getMessage :: QueueRec -> m (Transmission BrokerMsg)
        getMessage qr = time "GET" $ do
          atomically (TM.lookup queueId subscriptions) >>= \case
            Nothing -> atomically newSub >>= getMessage_
            Just sub ->
              readTVarIO sub >>= \case
                s@Sub {subThread = ProhibitSub} ->
                  atomically (tryTakeTMVar $ delivered s) >> getMessage_ s
                -- cannot use GET in the same connection where there is an active subscription
                _ -> pure (corrId, queueId, ERR $ CMD PROHIBITED)
          where
            newSub :: STM Sub
            newSub = do
              s <- newSubscription ProhibitSub
              sub <- newTVar s
              TM.insert queueId sub subscriptions
              pure s
            getMessage_ :: Sub -> m (Transmission BrokerMsg)
            getMessage_ s = do
              q <- getStoreMsgQueue "GET" queueId
              atomically $
                tryPeekMsg q >>= \case
                  Just msg ->
                    let encMsg = encryptMsg qr msg
                     in setDelivered s msg $> (corrId, queueId, MSG encMsg)
                  _ -> pure (corrId, queueId, OK)

        -- Runs the action with the queue found during verification, or responds AUTH.
        withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
        withQueue action = maybe (pure $ err AUTH) action qr_

        subscribeNotifications :: m (Transmission BrokerMsg)
        subscribeNotifications = time "NSUB" . atomically $ do
          unlessM (TM.member queueId ntfSubscriptions) $ do
            writeTBQueue ntfSubscribedQ (queueId, clnt)
            TM.insert queueId () ntfSubscriptions
          pure ok

        acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg)
        acknowledgeMsg qr msgId = time "ACK" $ do
          atomically (TM.lookup queueId subscriptions) >>= \case
            Nothing -> pure $ err NO_MSG
            Just sub ->
              atomically (getDelivered sub) >>= \case
                Just s -> do
                  q <- getStoreMsgQueue "ACK" queueId
                  case s of
                    Sub {subThread = ProhibitSub} -> do
                      msgDeleted <- atomically $ tryDelMsg q msgId
                      when msgDeleted updateStats
                      pure ok
                    _ -> do
                      (msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId
                      when msgDeleted updateStats
                      deliverMessage "ACK" qr queueId sub q msg_
                _ -> pure $ err NO_MSG
          where
            -- Takes the delivered message ID if it matches the acknowledged one
            -- (an empty msgId matches any); otherwise puts it back.
            getDelivered :: TVar Sub -> STM (Maybe Sub)
            getDelivered sub = do
              s@Sub {delivered} <- readTVar sub
              tryTakeTMVar delivered $>>= \msgId' ->
                if msgId == msgId' || B.null msgId
                  then pure $ Just s
                  else putTMVar delivered msgId' $> Nothing
            updateStats :: m ()
            updateStats = do
              stats <- asks serverStats
              atomically $ modifyTVar' (msgRecv stats) (+ 1)
              atomically $ updatePeriodStats (activeQueues stats) queueId

        sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
        sendMessage qr msgFlags msgBody
          | B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
          | otherwise =
            case status qr of
              QueueOff -> return $ err AUTH
              QueueActive ->
                mapM mkMessage (C.maxLenBS msgBody) >>= \case
                  Left _ -> pure $ err LARGE_MSG
                  Right msg -> do
                    resp@(_, _, sent) <- time "SEND" $ do
                      q <- getStoreMsgQueue "SEND" $ recipientId qr
                      expireMessages q
                      atomically $ ifM (isFull q) (pure $ err QUOTA) (writeMsg q msg $> ok)
                    when (sent == OK) . time "SEND ok" $ do
                      when (notification msgFlags) $
                        atomically . trySendNotification msg =<< asks idsDrg
                      stats <- asks serverStats
                      atomically $ modifyTVar' (msgSent stats) (+ 1)
                      atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
                    pure resp
          where
            mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
            mkMessage body = do
              msgId <- randomId =<< asks (msgIdBytes . config)
              msgTs <- liftIO getSystemTime
              pure $ Message msgId msgTs msgFlags body

            expireMessages :: MsgQueue -> m ()
            expireMessages q = do
              msgExp <- asks $ messageExpiration . config
              old <- liftIO $ mapM expireBeforeEpoch msgExp
              atomically $ mapM_ (deleteExpiredMsgs q) old

            trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
            trySendNotification msg ntfNonceDrg =
              forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
                mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers

            -- Notifications are dropped (not blocked on) when the notifier's send queue is full.
            writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
            writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
              unlessM (isFullTBQueue q) $ do
                (nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
                writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]

            mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
            mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
              cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
              let msgMeta = NMsgMeta {msgId, msgTs}
                  encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
              pure . (cbNonce,) $ fromRight "" encNMsgMeta

        -- Delivers an available message immediately, or (when none is available) forks a thread
        -- that waits for the next message and delivers it; does nothing if a delivery is in progress.
        deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
        deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do
          readTVarIO sub >>= \case
            s@Sub {subThread = NoSub} ->
              case msg_ of
                Just msg ->
                  let encMsg = encryptMsg qr msg
                   in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg)
                _ -> forkSub $> ok
            _ -> pure ok
          where
            forkSub :: m ()
            forkSub = do
              atomically . modifyTVar sub $ \s -> s {subThread = SubPending}
              t <- mkWeakThreadId =<< forkIO subscriber
              atomically . modifyTVar sub $ \case
                s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
                s -> s
              where
                subscriber = do
                  msg <- atomically $ peekMsg q
                  time "subscriber" . atomically $ do
                    let encMsg = encryptMsg qr msg
                    writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
                    s <- readTVar sub
                    void $ setDelivered s msg
                    writeTVar sub s {subThread = NoSub}

        time :: T.Text -> m a -> m a
        time name = timed name sessionId queueId

        -- Protocol versions 1 and 2 encrypt the raw body; later versions encrypt
        -- an encoded RcvMsgBody that also carries timestamp and flags.
        encryptMsg :: QueueRec -> Message -> RcvMessage
        encryptMsg qr Message {msgId, msgTs, msgFlags, msgBody}
          | thVersion == 1 || thVersion == 2 = encrypt msgBody
          | otherwise = encrypt $ encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody}
          where
            encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage
            encrypt body =
              let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId) body
               in RcvMessage msgId msgTs msgFlags encBody

        setDelivered :: Sub -> Message -> STM Bool
        setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId

        getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
        getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
          ms <- asks msgStore
          quota <- asks $ msgQueueQuota . config
          atomically $ getMsgQueue ms rId quota

        -- Deletes the queue and its messages; the stats counter is only incremented on success.
        delQueueAndMsgs :: QueueStore -> m (Transmission BrokerMsg)
        delQueueAndMsgs st = do
          withLog (`logDeleteQueue` queueId)
          ms <- asks msgStore
          r <- atomically $
            deleteQueue st queueId >>= \case
              Left e -> pure $ Left e
              Right _ -> delMsgQueue ms queueId $> Right ()
          case r of
            Left e -> pure $ err e
            Right () -> do
              stats <- asks serverStats
              atomically $ modifyTVar' (qDeleted stats) (+ 1)
              pure ok

        ok :: Transmission BrokerMsg
        ok = (corrId, queueId, OK)

        err :: ErrorType -> Transmission BrokerMsg
        err e = (corrId, queueId, ERR e)

        okResp :: Either ErrorType () -> Transmission BrokerMsg
        okResp = either err $ const ok

-- | Runs the action on the store log, if one is configured.
withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m ()
withLog action = do
  env <- ask
  liftIO . mapM_ action $ storeLog (env :: Env)

-- | Runs the action and logs (at debug level) its name, session and queue IDs
-- and duration in nanoseconds when it took longer than one second.
timed :: MonadUnliftIO m => T.Text -> ByteString -> RecipientId -> m a -> m a
timed name sessId qId a = do
  t <- liftIO getSystemTime
  r <- a
  t' <- liftIO getSystemTime
  let int = diff t t'
  when (int > sec) . logDebug $ T.unwords [name, tshow $ encode sessId, tshow $ encode qId, tshow int]
  pure r
  where
    -- systemNanoseconds is Word32: widen both operands before subtracting,
    -- otherwise a smaller end nanosecond part wraps around to a huge positive value
    diff t t' =
      (systemSeconds t' - systemSeconds t) * sec
        + (fromIntegral (systemNanoseconds t') - fromIntegral (systemNanoseconds t))
    sec = 1000_000000

-- | Generates n pseudo-random bytes from the server's ID generator.
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ByteString
randomId n = do
  gVar <- asks idsDrg
  atomically (C.pseudoRandomBytes n gVar)

-- | Flushes all message queues to the configured messages file (one line per message).
saveServerMessages :: (MonadUnliftIO m, MonadReader Env m) => m ()
saveServerMessages = asks (storeMsgsFile . config) >>= mapM_ saveMessages
  where
    saveMessages f = do
      logInfo $ "saving messages to file " <> T.pack f
      ms <- asks msgStore
      liftIO . withFile f WriteMode $ \h ->
        readTVarIO ms >>= mapM_ (saveQueueMsgs ms h) . M.keys
      logInfo "messages saved"
      where
        saveQueueMsgs ms h rId =
          atomically (flushMsgQueue ms rId) >>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId)

-- | Restores messages saved by 'saveServerMessages', converting v1 records to v3,
-- then renames the file to @.bak@. Any parsing or decryption error is fatal.
restoreServerMessages :: forall m. (MonadUnliftIO m, MonadReader Env m) => m ()
restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages
  where
    restoreMessages f = whenM (doesFileExist f) $ do
      logInfo $ "restoring messages from file " <> T.pack f
      st <- asks queueStore
      ms <- asks msgStore
      quota <- asks $ msgQueueQuota . config
      runExceptT (liftIO (B.readFile f) >>= mapM_ (restoreMsg st ms quota) . B.lines) >>= \case
        Left e -> do
          logError . T.pack $ "error restoring messages: " <> e
          liftIO exitFailure
        _ -> do
          renameFile f $ f <> ".bak"
          logInfo "messages restored"
      where
        restoreMsg st ms quota s = do
          r <- liftEither . first (msgErr "parsing") $ strDecode s
          case r of
            MLRv3 rId msg -> addToMsgQueue rId msg
            MLRv1 rId encMsg -> do
              qr <- liftEitherError (msgErr "queue unknown") . atomically $ getQueue st SRecipient rId
              msg' <- updateMsgV1toV3 qr encMsg
              addToMsgQueue rId msg'
          where
            addToMsgQueue rId msg = do
              full <- atomically $ do
                q <- getMsgQueue ms rId quota
                ifM (isFull q) (pure True) (writeMsg q msg $> False)
              when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message))
            updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do
              let nonce = C.cbNonce msgId
              msgBody <- liftEither . first (msgErr "v1 message decryption") $ C.maxLenBS =<< C.cbDecrypt rcvDhSecret nonce body
              pure Message {msgId, msgTs, msgFlags, msgBody}
            msgErr :: Show e => String -> e -> String
            msgErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s)

-- | Writes the current server stats to the configured backup file, if any.
saveServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
saveServerStats =
  asks (serverStatsBackupFile . config)
    >>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f)
  where
    saveStats f stats = do
      logInfo $ "saving server stats to file " <> T.pack f
      B.writeFile f $ strEncode stats
      logInfo "server stats saved"

-- | Restores server stats from the backup file (if it exists) and renames it to @.bak@.
-- A parsing error is fatal.
restoreServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
restoreServerStats = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
  where
    restoreStats f = whenM (doesFileExist f) $ do
      logInfo $ "restoring server stats from file " <> T.pack f
      liftIO (strDecode <$> B.readFile f) >>= \case
        Right d -> do
          s <- asks serverStats
          atomically $ setServerStats s d
          renameFile f $ f <> ".bak"
          logInfo "server stats restored"
        Left e -> do
          logError $ "error restoring server stats: " <> T.pack e
          liftIO exitFailure