diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index 163b312c8..084603ad5 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -3,6 +3,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -283,7 +284,10 @@ runServer IniOptions {enableStoreLog, port, enableWebsockets} = do caCertificateFile = caCrtFile, privateKeyFile = serverKeyFile, certificateFile = serverCrtFile, - storeLog + storeLog, + allowNewQueues = True, + messageTTL = Just $ 7 * 86400, -- 7 days + expireMessagesInterval = Just 21600_000000 -- microseconds, 6 hours } openStoreLog :: IO (Maybe (StoreLog 'ReadMode)) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 938ba9aa1..560df2e2a 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -33,9 +33,10 @@ import Crypto.Random import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) +import Data.Int (Int64) import qualified Data.Map.Strict as M import Data.Maybe (isNothing) -import Data.Time.Clock.System (getSystemTime) +import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Type.Equality import Network.Socket (ServiceName) import qualified Simplex.Messaging.Crypto as C @@ -69,34 +70,32 @@ runSMPServer cfg = do -- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True) -- and when it is disconnected from the TCP socket once the server thread is killed (False). runSMPServerBlocking :: (MonadRandom m, MonadUnliftIO m) => TMVar Bool -> ServerConfig -> m () -runSMPServerBlocking started cfg@ServerConfig {transports} = do - env <- newEnv cfg - runReaderT smpServer env - where - smpServer :: (MonadUnliftIO m', MonadReader Env m') => m' () - smpServer = do - s <- asks server - raceAny_ - ( serverThread s subscribedQ subscribers subscriptions cancelSub : - serverThread s ntfSubscribedQ notifiers ntfSubscriptions (\_ -> pure ()) : - map runServer transports - ) - `finally` withLog closeStoreLog +runSMPServerBlocking started cfg = newEnv cfg >>= runReaderT (smpServer started) - runServer :: (MonadUnliftIO m', MonadReader Env m') => (ServiceName, ATransport) -> m' () +smpServer :: forall m. (MonadUnliftIO m, MonadReader Env m) => TMVar Bool -> m () +smpServer started = do + s <- asks server + cfg@ServerConfig {transports} <- asks config + raceAny_ + ( serverThread s subscribedQ subscribers subscriptions cancelSub : + serverThread s ntfSubscribedQ notifiers ntfSubscriptions (\_ -> pure ()) : + map runServer transports <> expireMessagesThread_ cfg + ) + `finally` withLog closeStoreLog + where + runServer :: (ServiceName, ATransport) -> m () runServer (tcpPort, ATransport t) = do serverParams <- asks tlsServerParams runTransportServer started tcpPort serverParams (runClient t) serverThread :: - forall m' s. - MonadUnliftIO m' => + forall s. Server -> (Server -> TBQueue (QueueId, Client)) -> (Server -> TMap QueueId Client) -> (Client -> TMap QueueId s) -> - (s -> m' ()) -> - m' () + (s -> m ()) -> + m () serverThread s subQ subs clientSubs unsub = forever $ do atomically updateSubscribers >>= fmap join . mapM endPreviousSubscriptions @@ -113,13 +112,31 @@ runSMPServerBlocking started cfg@ServerConfig {transports} = do pure $ if yes then Just (qId, c') else Nothing TM.lookupInsert qId clnt (subs s) >>= fmap join . mapM clientToBeNotified - endPreviousSubscriptions :: (QueueId, Client) -> m' (Maybe s) + endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s) endPreviousSubscriptions (qId, c) = do void . forkIO . atomically $ writeTBQueue (sndQ c) (CorrId "", qId, END) atomically $ TM.lookupDelete qId (clientSubs c) - runClient :: (Transport c, MonadUnliftIO m, MonadReader Env m) => TProxy c -> c -> m () + expireMessagesThread_ :: ServerConfig -> [m ()] + expireMessagesThread_ ServerConfig {messageTTL, expireMessagesInterval} = + case (messageTTL, expireMessagesInterval) of + (Just ttl, Just int) -> [expireMessages ttl int] + _ -> [] + + expireMessages :: Int64 -> Int -> m () + expireMessages ttl interval = do + ms <- asks msgStore + quota <- asks $ msgQueueQuota . config + forever $ do + threadDelay interval + old <- subtract ttl . systemSeconds <$> liftIO getSystemTime + rIds <- M.keysSet <$> readTVarIO ms + forM_ rIds $ \rId -> + atomically (getMsgQueue ms rId quota) + >>= atomically . (`deleteExpiredMsgs` old) + + runClient :: Transport c => TProxy c -> c -> m () runClient _ h = do kh <- asks serverIdentity liftIO (runExceptT $ serverHandshake h kh) >>= \case @@ -231,7 +248,11 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri Cmd SNotifier NSUB -> subscribeNotifications Cmd SRecipient command -> case command of - NEW rKey dhKey -> createQueue st rKey dhKey + NEW rKey dhKey -> + ifM + (asks $ allowNewQueues . config) + (createQueue st rKey dhKey) + (pure (corrId, queueId, ERR AUTH)) SUB -> subscribeQueue queueId ACK -> acknowledgeMsg KEY sKey -> secureQueue_ st sKey @@ -350,9 +371,11 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri Left _ -> pure $ err LARGE_MSG Right msg -> do ms <- asks msgStore - quota <- asks $ msgQueueQuota . config + ServerConfig {messageTTL, msgQueueQuota} <- asks config + old <- forM messageTTL $ \ttl -> subtract ttl . systemSeconds <$> liftIO getSystemTime atomically $ do - q <- getMsgQueue ms (recipientId qr) quota + q <- getMsgQueue ms (recipientId qr) msgQueueQuota + mapM_ (deleteExpiredMsgs q) old ifM (isFull q) (pure $ err QUOTA) $ do trySendNotification writeMsg q msg diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 3c4599a97..8a13121b9 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -9,6 +9,7 @@ import Control.Concurrent (ThreadId) import Control.Monad.IO.Unlift import Crypto.Random import Data.ByteString.Char8 (ByteString) +import Data.Int (Int64) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.X509.Validation (Fingerprint (..)) @@ -36,6 +37,12 @@ data ServerConfig = ServerConfig queueIdBytes :: Int, msgIdBytes :: Int, storeLog :: Maybe (StoreLog 'ReadMode), + -- | set to False to prohibit creating new queues + allowNewQueues :: Bool, + -- | time after which the messages can be removed from the queues, seconds + messageTTL :: Maybe Int64, + -- | interval to periodically remove expired messages (when no messages are sent to the queue), microseconds + expireMessagesInterval :: Maybe Int, -- CA certificate private key is not needed for initialization caCertificateFile :: FilePath, privateKeyFile :: FilePath, diff --git a/src/Simplex/Messaging/Server/MsgStore.hs b/src/Simplex/Messaging/Server/MsgStore.hs index 9da8492f7..a7180311f 100644 --- a/src/Simplex/Messaging/Server/MsgStore.hs +++ b/src/Simplex/Messaging/Server/MsgStore.hs @@ -2,6 +2,7 @@ module Simplex.Messaging.Server.MsgStore where +import Data.Int (Int64) import Data.Time.Clock.System (SystemTime) import Numeric.Natural import Simplex.Messaging.Protocol (MsgBody, MsgId, RecipientId) @@ -22,3 +23,4 @@ class MonadMsgQueue q m where tryPeekMsg :: q -> m (Maybe Message) -- non blocking peekMsg :: q -> m Message -- blocking tryDelPeekMsg :: q -> m (Maybe Message) -- atomic delete (== read) last and peek next message, if available + deleteExpiredMsgs :: q -> Int64 -> m () diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 86d6db996..be395f918 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -3,9 +3,13 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} module Simplex.Messaging.Server.MsgStore.STM where +import Control.Monad (when) +import Data.Int (Int64) +import Data.Time.Clock.System (SystemTime (systemSeconds)) import Numeric.Natural import Simplex.Messaging.Protocol (RecipientId) import Simplex.Messaging.Server.MsgStore @@ -48,3 +52,11 @@ instance MonadMsgQueue MsgQueue STM where -- atomic delete (== read) last and peek next message if available tryDelPeekMsg :: MsgQueue -> STM (Maybe Message) tryDelPeekMsg (MsgQueue q) = tryReadTBQueue q >> tryPeekTBQueue q + + deleteExpiredMsgs :: MsgQueue -> Int64 -> STM () + deleteExpiredMsgs (MsgQueue q) old = loop + where + loop = tryPeekTBQueue q >>= mapM_ delOldMsg + delOldMsg Message {ts} = + when (systemSeconds ts < old) $ + tryReadTBQueue q >> loop diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 3abf2c4d9..02c4024ff 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -61,6 +61,9 @@ cfg = queueIdBytes = 24, msgIdBytes = 24, storeLog = Nothing, + allowNewQueues = True, + messageTTL = Just $ 7 * 86400, -- seconds, 7 days + expireMessagesInterval = Just 21600_000000, -- microseconds, 6 hours caCertificateFile = "tests/fixtures/ca.crt", privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt" @@ -69,16 +72,16 @@ cfg = withSmpServerStoreLogOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a withSmpServerStoreLogOn t port' client = do s <- liftIO $ openReadStoreLog testStoreLogFile + withSmpServerConfigOn t cfg {storeLog = Just s} port' client + +withSmpServerConfigOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServerConfig -> ServiceName -> (ThreadId -> m a) -> m a +withSmpServerConfigOn t cfg' port' = serverBracket - (\started -> runSMPServerBlocking started cfg {transports = [(port', t)], storeLog = Just s}) + (\started -> runSMPServerBlocking started cfg' {transports = [(port', t)]}) (pure ()) - client withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a -withSmpServerThreadOn t port' = - serverBracket - (\started -> runSMPServerBlocking started cfg {transports = [(port', t)]}) - (pure ()) +withSmpServerThreadOn t = withSmpServerConfigOn t cfg serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a serverBracket process afterProcess f = do diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 418c86a5e..1719caab9 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -9,7 +9,7 @@ module ServerTests where -import Control.Concurrent (ThreadId, killThread) +import Control.Concurrent (ThreadId, killThread, threadDelay) import Control.Concurrent.STM import Control.Exception (SomeException, try) import Control.Monad.Except (forM, forM_, runExceptT) @@ -21,6 +21,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol +import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Transport import System.Directory (removeFile) import System.TimeIt (timeItT) @@ -29,18 +30,23 @@ import Test.HUnit import Test.Hspec serverTests :: ATransport -> Spec -serverTests t = do +serverTests t@(ATransport t') = do describe "SMP syntax" $ syntaxTests t describe "SMP queues" $ do describe "NEW and KEY commands, SEND messages" $ testCreateSecure t describe "NEW, OFF and DEL commands, SEND messages" $ testCreateDelete t describe "Stress test" $ stressTest t + describe "allowNewQueues setting" $ testAllowNewQueues t' describe "SMP messages" $ do describe "duplex communication over 2 SMP connections" $ testDuplex t describe "switch subscription to another TCP connection" $ testSwitchSub t describe "Store log" $ testWithStoreLog t describe "Timing of AUTH error" $ testTiming t describe "Message notifications" $ testMessageNotifications t + describe "Message expiration" $ do + testMsgExpireOnSend t' + testMsgExpireOnInterval t' + testMsgNOTExpireOnInterval t' pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission BrokerMsg pattern Resp corrId queueId command <- (_, _, (corrId, queueId, Right command)) @@ -204,6 +210,16 @@ stressTest (ATransport t) = closeConnection $ connection h2 subscribeQueues h3 +testAllowNewQueues :: forall c. Transport c => TProxy c -> Spec +testAllowNewQueues t = + it "should prohibit creating new queues with allowNewQueues = False" $ do + withSmpServerConfigOn (ATransport t) cfg {allowNewQueues = False} testPort $ \_ -> + testSMPClient @c $ \h -> do + (rPub, rKey) <- C.generateSignatureKeyPair C.SEd448 + (dhPub, _ :: C.PrivateKeyX25519) <- C.generateKeyPair' + Resp "abcd" "" (ERR AUTH) <- signSendRecv h rKey ("abcd", "", NEW rPub dhPub) + pure () + testDuplex :: ATransport -> Spec testDuplex (ATransport t) = it "should create 2 simplex connections and exchange messages" $ @@ -466,6 +482,56 @@ testMessageNotifications (ATransport t) = Nothing -> return () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" +testMsgExpireOnSend :: forall c. Transport c => TProxy c -> Spec +testMsgExpireOnSend t = + it "should expire messages that are not received before messageTTL on SEND" $ do + (sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519 + withSmpServerConfigOn (ATransport t) cfg {messageTTL = Just 1} testPort $ \_ -> + testSMPClient @c $ \sh -> do + (sId, rId, rKey, dhShared) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub + let dec nonce = C.cbDecrypt dhShared (C.cbNonce nonce) + Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, SEND "hello (should expire)") + threadDelay 2000000 + Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, SEND "hello (should NOT expire)") + testSMPClient @c $ \rh -> do + Resp "3" _ (MSG mId _ msg) <- signSendRecv rh rKey ("3", rId, SUB) + (dec mId msg, Right "hello (should NOT expire)") #== "delivered" + 1000 `timeout` tGet @BrokerMsg rh >>= \case + Nothing -> return () + Just _ -> error "nothing else should be delivered" + +testMsgExpireOnInterval :: forall c. Transport c => TProxy c -> Spec +testMsgExpireOnInterval t = + it "should expire messages that are not received before messageTTL after expiry interval" $ do + (sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519 + withSmpServerConfigOn (ATransport t) cfg {messageTTL = Just 1, expireMessagesInterval = Just 1000000} testPort $ \_ -> + testSMPClient @c $ \sh -> do + (sId, rId, rKey, _) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub + Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, SEND "hello (should expire)") + threadDelay 2000000 + testSMPClient @c $ \rh -> do + Resp "2" _ OK <- signSendRecv rh rKey ("2", rId, SUB) + 1000 `timeout` tGet @BrokerMsg rh >>= \case + Nothing -> return () + Just _ -> error "nothing should be delivered" + +testMsgNOTExpireOnInterval :: forall c. Transport c => TProxy c -> Spec +testMsgNOTExpireOnInterval t = + it "should NOT expire messages that are not received before messageTTL if expiry interval is not set" $ do + (sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519 + withSmpServerConfigOn (ATransport t) cfg {messageTTL = Just 1, expireMessagesInterval = Nothing} testPort $ \_ -> + testSMPClient @c $ \sh -> do + (sId, rId, rKey, dhShared) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub + let dec nonce = C.cbDecrypt dhShared (C.cbNonce nonce) + Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, SEND "hello (should NOT expire)") + threadDelay 2000000 + testSMPClient @c $ \rh -> do + Resp "2" _ (MSG mId _ msg) <- signSendRecv rh rKey ("2", rId, SUB) + (dec mId msg, Right "hello (should NOT expire)") #== "delivered" + 1000 `timeout` tGet @BrokerMsg rh >>= \case + Nothing -> return () + Just _ -> error "nothing else should be delivered" + samplePubKey :: C.APublicVerifyKey samplePubKey = C.APublicVerifyKey C.SEd25519 "MCowBQYDK2VwAyEAfAOflyvbJv1fszgzkQ6buiZJVgSpQWsucXq7U6zjMgY="