diff --git a/CHANGELOG.md b/CHANGELOG.md index 90706a19e..9696a64b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,26 @@ +# 6.0.0 + +Version 6.0.0.8 + +Agent: +- enabled fast handshake support. +- batch-send multiple messages in each connection. +- resume subscriptions as soon as agent moves to foreground or as network connection resumes. +- "known" servers to determine whether to use SMP proxy. +- retry on SMP proxy NO_SESSION error. +- fixes to notification subscriptions. +- persistent server statistics. +- better concurrency. + +SMP server: +- reduce threads usage. +- additional statistics. +- improve disabling inactive clients. +- additional control port commands for monitoring. + +Notification server: +- support onion-only SMP servers. + # 5.8.2 Agent: diff --git a/package.yaml b/package.yaml index 94c2b6db0..26cdcc51a 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 5.8.2.0 +version: 6.0.0.8 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, diff --git a/simplexmq.cabal b/simplexmq.cabal index d1fa32d43..d557ac509 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 5.8.2.0 +version: 6.0.0.8 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and @@ -133,8 +133,8 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240417_rcv_files_approved_relays - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240518_servers_stats Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240624_snd_secure + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240702_servers_stats Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 90dda5cbc..d6ee75ae9 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -12,6 +12,7 @@ module Simplex.FileTransfer.Agent ( startXFTPWorkers, + startXFTPSndWorkers, closeXFTPAgent, toFSFilePath, -- Receiving files @@ -42,13 +43,14 @@ import Data.Either (partitionEithers, rights) import Data.Int (Int64) import Data.List (foldl', partition, sortOn) import qualified Data.List.NonEmpty as L -import Data.Map (Map) +import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (mapMaybe) import qualified Data.Set as S import Data.Text (Text) import Data.Time.Clock (getCurrentTime) import Data.Time.Format (defaultTimeLocale, formatTime) +import Simplex.FileTransfer.Chunks (toKB) import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) import Simplex.FileTransfer.Client.Main import Simplex.FileTransfer.Crypto @@ -81,13 +83,21 @@ import UnliftIO.Directory import qualified UnliftIO.Exception as E startXFTPWorkers :: AgentClient -> Maybe FilePath -> AM () -startXFTPWorkers c workDir = do +startXFTPWorkers = startXFTPWorkers_ True +{-# INLINE startXFTPWorkers #-} + +startXFTPSndWorkers :: AgentClient -> Maybe FilePath -> AM () +startXFTPSndWorkers = startXFTPWorkers_ False +{-# INLINE startXFTPSndWorkers #-} + +startXFTPWorkers_ :: Bool -> AgentClient -> Maybe FilePath -> AM () +startXFTPWorkers_ allWorkers c workDir = do wd <- asks $ xftpWorkDir . xftpAgent atomically $ writeTVar wd workDir cfg <- asks config - startRcvFiles cfg + when allWorkers $ startRcvFiles cfg startSndFiles cfg - startDelFiles cfg + when allWorkers $ startDelFiles cfg where startRcvFiles :: AgentConfig -> AM () startRcvFiles AgentConfig {rcvFilesTTL} = do @@ -174,7 +184,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -184,6 +194,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do (fc@RcvFileChunk {userId, rcvFileId, rcvFileEntityId, digest, fileTmpPath, replicas = replica@RcvFileChunkReplica {rcvChunkReplicaId, server, delay} : _}, approvedRelays) -> do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv downloadAttempts downloadFileChunk fc replica approvedRelays @@ -194,7 +205,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do when (serverHostError e) $ notify c rcvFileEntityId $ RFWARN e liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop retryDone e = do atomically . incXFTPServerStat c userId srv $ case e of @@ -206,10 +217,11 @@ runXFTPRcvWorker c srv Worker {doWork} = do unlessM ((approvedRelays ||) <$> ipAddressProtected') $ throwE $ FILE NOT_APPROVED fsFileTmpPath <- lift $ toFSFilePath fileTmpPath chunkPath <- uniqueCombine fsFileTmpPath $ show chunkNo - let chunkSpec = XFTPRcvChunkSpec chunkPath (unFileSize chunkSize) (unFileDigest digest) + let chSize = unFileSize chunkSize + chunkSpec = XFTPRcvChunkSpec chunkPath chSize (unFileDigest digest) relChunkPath = fileTmpPath takeFileName chunkPath agentXFTPDownloadChunk c userId digest replica chunkSpec - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c (entityId, complete, progress) <- withStore c $ \db -> runExceptT $ do liftIO $ updateRcvFileChunkReceived db (rcvChunkReplicaId replica) rcvChunkId relChunkPath RcvFile {size = FileSize currentSize, chunks, redirect} <- ExceptT $ getRcvFile db rcvFileId @@ -221,13 +233,14 @@ runXFTPRcvWorker c srv Worker {doWork} = do liftIO . when complete $ updateRcvFileStatus db rcvFileId RFSReceived pure (entityId, complete, RFPROG rcvd total) atomically $ incXFTPServerStat c userId srv downloads + atomically $ incXFTPServerSizeStat c userId srv downloadsSize (fromIntegral $ toKB chSize) notify c entityId progress when complete . lift . void $ getXFTPRcvWorker True c Nothing where ipAddressProtected' :: AM Bool ipAddressProtected' = do - cfg <- liftIO $ getNetworkConfig' c + cfg <- liftIO $ getFastNetworkConfig c pure $ ipAddressProtected cfg srv receivedSize :: [RcvFileChunk] -> Int64 receivedSize = foldl' (\sz ch -> sz + receivedChunkSize ch) 0 @@ -260,7 +273,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -286,12 +299,12 @@ runXFTPRcvLocalWorker c Worker {doWork} = do Nothing -> do notify c rcvFileEntityId $ RFDONE fsSavePath lift $ forM_ tmpPath (removePath <=< toFSFilePath) - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) Just RcvFileRedirect {redirectFileInfo, redirectDbId} -> do let RedirectFileInfo {size = redirectSize, digest = redirectDigest} = redirectFileInfo lift $ forM_ tmpPath (removePath <=< toFSFilePath) - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) -- proceed with redirect yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath) @@ -379,7 +392,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -441,7 +454,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do SndFileChunkReplica {server} : _ -> Right server createChunk :: Int -> SndFileChunk -> AM (ProtocolServer 'PXFTP) createChunk numRecipients' ch = do - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c (replica, ProtoServerWithAuth srv _) <- tryCreate withStore' c $ \db -> createSndFileReplica db ch replica pure srv @@ -449,8 +462,9 @@ runXFTPSndPrepareWorker c Worker {doWork} = do tryCreate = do usedSrvs <- newTVarIO ([] :: [XFTPServer]) let AgentClient {xftpServers} = c - userSrvCount <- length <$> atomically (TM.lookup userId xftpServers) + userSrvCount <- liftIO $ length <$> TM.lookupIO userId xftpServers withRetryIntervalCount (riFast ri) $ \n _ loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c let triedAllSrvs = n > userSrvCount createWithNextSrv usedSrvs @@ -460,7 +474,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do retryLoop loop triedAllSrvs e = do flip catchAgentError (\_ -> pure ()) $ do when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop createWithNextSrv usedSrvs = do deleted <- withStore' c $ \db -> getSndFileDeleted db sndFileId @@ -480,7 +494,7 @@ runXFTPSndWorker c srv Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -490,6 +504,7 @@ runXFTPSndWorker c srv Worker {doWork} = do fc@SndFileChunk {userId, sndFileId, sndFileEntityId, filePrefixPath, digest, replicas = replica@SndFileChunkReplica {sndChunkReplicaId, server, delay} : _} -> do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv uploadAttempts uploadFileChunk cfg fc replica @@ -500,20 +515,20 @@ runXFTPSndWorker c srv Worker {doWork} = do when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop retryDone e = do atomically $ incXFTPServerStat c userId srv uploadErrs sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) e uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> AM () - uploadFileChunk AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients} sndFileChunk@SndFileChunk {sndFileId, userId, chunkSpec = chunkSpec@XFTPChunkSpec {filePath}, digest = chunkDigest} replica = do + uploadFileChunk AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients} sndFileChunk@SndFileChunk {sndFileId, userId, chunkSpec = chunkSpec@XFTPChunkSpec {filePath, chunkSize = chSize}, digest = chunkDigest} replica = do replica'@SndFileChunkReplica {sndChunkReplicaId} <- addRecipients sndFileChunk replica fsFilePath <- lift $ toFSFilePath filePath unlessM (doesFileExist fsFilePath) $ throwE $ FILE NO_FILE let chunkSpec' = chunkSpec {filePath = fsFilePath} :: XFTPChunkSpec - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c agentXFTPUploadChunk c userId chunkDigest replica' chunkSpec' - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c sf@SndFile {sndFileEntityId, prefixPath, chunks} <- withStore c $ \db -> do updateSndChunkReplicaStatus db sndChunkReplicaId SFRSUploaded getSndFile db sndFileId @@ -521,6 +536,7 @@ runXFTPSndWorker c srv Worker {doWork} = do total = totalSize chunks complete = all chunkUploaded chunks atomically $ incXFTPServerStat c userId srv uploads + atomically $ incXFTPServerSizeStat c userId srv uploadsSize (fromIntegral $ toKB chSize) notify c sndFileEntityId $ SFPROG uploaded total when complete $ do (sndDescr, rcvDescrs) <- sndFileToDescrs sf @@ -650,7 +666,7 @@ runXFTPDelWorker c srv Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -661,6 +677,7 @@ runXFTPDelWorker c srv Worker {doWork} = do processDeletedReplica replica@DeletedSndChunkReplica {deletedSndChunkReplicaId, userId, server, chunkDigest, delay} = do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv deleteAttempts deleteChunkReplica @@ -671,7 +688,7 @@ runXFTPDelWorker c srv Worker {doWork} = do when (serverHostError e) $ notify c "" $ SFWARN e liftIO $ closeXFTPServerClient c userId server chunkDigest withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop retryDone e = do atomically $ incXFTPServerStat c userId srv deleteErrs @@ -686,7 +703,7 @@ delWorkerInternalError c deletedSndChunkReplicaId e = do withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId notify c "" $ SFERR e -assertAgentForeground :: AgentClient -> STM () +assertAgentForeground :: AgentClient -> IO () assertAgentForeground c = do throwWhenInactive c waitUntilForeground c diff --git a/src/Simplex/FileTransfer/Chunks.hs b/src/Simplex/FileTransfer/Chunks.hs index 0b35649c5..d8890944d 100644 --- a/src/Simplex/FileTransfer/Chunks.hs +++ b/src/Simplex/FileTransfer/Chunks.hs @@ -26,6 +26,10 @@ kb :: Integral a => a -> a kb n = 1024 * n {-# INLINE kb #-} +toKB :: Integral a => a -> a +toKB n = n `div` 1024 +{-# INLINE toKB #-} + mb :: Integral a => a -> a mb n = 1024 * kb n {-# INLINE mb #-} diff --git a/src/Simplex/FileTransfer/Client/Agent.hs b/src/Simplex/FileTransfer/Client/Agent.hs index 86b093ee7..863a91ce1 100644 --- a/src/Simplex/FileTransfer/Client/Agent.hs +++ b/src/Simplex/FileTransfer/Client/Agent.hs @@ -53,9 +53,9 @@ defaultXFTPClientAgentConfig = data XFTPClientAgentError = XFTPClientAgentError XFTPServer XFTPClientError deriving (Show, Exception) -newXFTPAgent :: XFTPClientAgentConfig -> STM XFTPClientAgent +newXFTPAgent :: XFTPClientAgentConfig -> IO XFTPClientAgent newXFTPAgent config = do - xftpClients <- TM.empty + xftpClients <- TM.emptyIO pure XFTPClientAgent {xftpClients, config} type ME a = ExceptT XFTPClientAgentError IO a diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index aeac956e6..1eea6ef5a 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -43,7 +43,7 @@ import Data.Int (Int64) import Data.List (foldl', sortOn) import Data.List.NonEmpty (NonEmpty (..), nonEmpty) import qualified Data.List.NonEmpty as L -import Data.Map (Map) +import Data.Map.Strict (Map) import qualified Data.Map as M import Data.Maybe (fromMaybe, listToMaybe) import qualified Data.Text as T @@ -313,7 +313,7 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re pure (encPath, fdRcv, fdSnd, chunkSpecs, encSize) uploadFile :: TVar ChaChaDRG -> [XFTPChunkSpec] -> TVar [Int64] -> Int64 -> ExceptT CLIError IO [SentFileChunk] uploadFile g chunks uploadedChunks encSize = do - a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig + a <- liftIO $ newXFTPAgent defaultXFTPClientAgentConfig gen <- newTVarIO =<< liftIO newStdGen let xftpSrvs = fromMaybe defaultXFTPServers (nonEmpty xftpServers) srvs <- liftIO $ replicateM (length chunks) $ getXFTPServer gen xftpSrvs @@ -429,7 +429,7 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath, receive (ValidFileDescription FileDescription {size, digest, key, nonce, chunks}) = do encPath <- getEncPath tempPath "xftp" createDirectory encPath - a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig + a <- liftIO $ newXFTPAgent defaultXFTPClientAgentConfig liftIO $ printNoNewLine "Downloading file..." downloadedChunks <- newTVarIO [] let srv FileChunk {replicas} = case replicas of @@ -494,7 +494,7 @@ cliDeleteFile DeleteOptions {fileDescription, retryCount, yes} = do where deleteFile :: ValidFileDescription 'FSender -> ExceptT CLIError IO () deleteFile (ValidFileDescription FileDescription {chunks}) = do - a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig + a <- liftIO $ newXFTPAgent defaultXFTPClientAgentConfig forM_ chunks $ deleteFileChunk a liftIO $ do printNoNewLine "File deleted!" diff --git a/src/Simplex/FileTransfer/Description.hs b/src/Simplex/FileTransfer/Description.hs index d5b5e5105..c702a177f 100644 --- a/src/Simplex/FileTransfer/Description.hs +++ b/src/Simplex/FileTransfer/Description.hs @@ -52,7 +52,7 @@ import Data.Int (Int64) import Data.List (foldl', sortOn) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Map (Map) +import Data.Map.Strict (Map) import qualified Data.Map as M import Data.Maybe (fromMaybe) import Data.String diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 24dcc5e38..819be9a81 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -112,7 +112,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira Right pk' -> pure pk' Left e -> putStrLn ("servers has no valid key: " <> show e) >> exitFailure env <- ask - sessions <- atomically TM.empty + sessions <- liftIO TM.emptyIO let cleanup sessionId = atomically $ TM.delete sessionId sessions liftIO . runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig inactiveClientExpiration cleanup $ \sessionId sessionALPN r sendResponse -> do reqBody <- getHTTP2Body r xftpBlockSize @@ -576,7 +576,7 @@ incFileStat statSel = do saveServerStats :: M () saveServerStats = asks (serverStatsBackupFile . config) - >>= mapM_ (\f -> asks serverStats >>= atomically . getFileServerStatsData >>= liftIO . saveStats f) + >>= mapM_ (\f -> asks serverStats >>= liftIO . getFileServerStatsData >>= liftIO . saveStats f) where saveStats f stats = do logInfo $ "saving server stats to file " <> T.pack f diff --git a/src/Simplex/FileTransfer/Server/Env.hs b/src/Simplex/FileTransfer/Server/Env.hs index f8a6bc996..1fa399a2a 100644 --- a/src/Simplex/FileTransfer/Server/Env.hs +++ b/src/Simplex/FileTransfer/Server/Env.hs @@ -11,7 +11,6 @@ module Simplex.FileTransfer.Server.Env where import Control.Logger.Simple import Control.Monad -import Control.Monad.IO.Unlift import Crypto.Random import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) @@ -105,17 +104,17 @@ supportedXFTPhandshakes = ["xftp/1"] newXFTPServerEnv :: XFTPServerConfig -> IO XFTPEnv newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertificateFile, certificateFile, privateKeyFile, transportConfig} = do - random <- liftIO C.newRandom - store <- atomically newFileStore - storeLog <- liftIO $ mapM (`readWriteFileStore` store) storeLogFile + random <- C.newRandom + store <- newFileStore + storeLog <- mapM (`readWriteFileStore` store) storeLogFile used <- countUsedStorage <$> readTVarIO (files store) atomically $ writeTVar (usedStorage store) used forM_ fileSizeQuota $ \quota -> do logInfo $ "Total / available storage: " <> tshow quota <> " / " <> tshow (quota - used) when (quota < used) $ logInfo "WARNING: storage quota is less than used storage, no files can be uploaded!" - tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) - Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile - serverStats <- atomically . newFileServerStats =<< liftIO getCurrentTime + tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) + Fingerprint fp <- loadFingerprint caCertificateFile + serverStats <- newFileServerStats =<< getCurrentTime pure XFTPEnv {config, store, storeLog, random, tlsServerParams, serverIdentity = C.KeyHash fp, serverStats} countUsedStorage :: M.Map k FileRec -> Int64 diff --git a/src/Simplex/FileTransfer/Server/Stats.hs b/src/Simplex/FileTransfer/Server/Stats.hs index 08813dc2a..1178dd5f6 100644 --- a/src/Simplex/FileTransfer/Server/Stats.hs +++ b/src/Simplex/FileTransfer/Server/Stats.hs @@ -43,34 +43,34 @@ data FileServerStatsData = FileServerStatsData } deriving (Show) -newFileServerStats :: UTCTime -> STM FileServerStats +newFileServerStats :: UTCTime -> IO FileServerStats newFileServerStats ts = do - fromTime <- newTVar ts - filesCreated <- newTVar 0 - fileRecipients <- newTVar 0 - filesUploaded <- newTVar 0 - filesExpired <- newTVar 0 - filesDeleted <- newTVar 0 + fromTime <- newTVarIO ts + filesCreated <- newTVarIO 0 + fileRecipients <- newTVarIO 0 + filesUploaded <- newTVarIO 0 + filesExpired <- newTVarIO 0 + filesDeleted <- newTVarIO 0 filesDownloaded <- newPeriodStats - fileDownloads <- newTVar 0 - fileDownloadAcks <- newTVar 0 - filesCount <- newTVar 0 - filesSize <- newTVar 0 + fileDownloads <- newTVarIO 0 + fileDownloadAcks <- newTVarIO 0 + filesCount <- newTVarIO 0 + filesSize <- newTVarIO 0 pure FileServerStats {fromTime, filesCreated, fileRecipients, filesUploaded, filesExpired, filesDeleted, filesDownloaded, fileDownloads, fileDownloadAcks, filesCount, filesSize} -getFileServerStatsData :: FileServerStats -> STM FileServerStatsData +getFileServerStatsData :: FileServerStats -> IO FileServerStatsData getFileServerStatsData s = do - _fromTime <- readTVar $ fromTime (s :: FileServerStats) - _filesCreated <- readTVar $ filesCreated s - _fileRecipients <- readTVar $ fileRecipients s - _filesUploaded <- readTVar $ filesUploaded s - _filesExpired <- readTVar $ filesExpired s - _filesDeleted <- readTVar $ filesDeleted s + _fromTime <- readTVarIO $ fromTime (s :: FileServerStats) + _filesCreated <- readTVarIO $ filesCreated s + _fileRecipients <- readTVarIO $ fileRecipients s + _filesUploaded <- readTVarIO $ filesUploaded s + _filesExpired <- readTVarIO $ filesExpired s + _filesDeleted <- readTVarIO $ filesDeleted s _filesDownloaded <- getPeriodStatsData $ filesDownloaded s - _fileDownloads <- readTVar $ fileDownloads s - _fileDownloadAcks <- readTVar $ fileDownloadAcks s - _filesCount <- readTVar $ filesCount s - _filesSize <- readTVar $ filesSize s + _fileDownloads <- readTVarIO $ fileDownloads s + _fileDownloadAcks <- readTVarIO $ fileDownloadAcks s + _filesCount <- readTVarIO $ filesCount s + _filesSize <- readTVarIO $ filesSize s pure FileServerStatsData {_fromTime, _filesCreated, _fileRecipients, _filesUploaded, _filesExpired, _filesDeleted, _filesDownloaded, _fileDownloads, _fileDownloadAcks, _filesCount, _filesSize} setFileServerStats :: FileServerStats -> FileServerStatsData -> STM () diff --git a/src/Simplex/FileTransfer/Server/Store.hs b/src/Simplex/FileTransfer/Server/Store.hs index aa8eaa932..b56b516aa 100644 --- a/src/Simplex/FileTransfer/Server/Store.hs +++ b/src/Simplex/FileTransfer/Server/Store.hs @@ -55,11 +55,11 @@ instance StrEncoding FileRecipient where strEncode (FileRecipient rId rKey) = strEncode rId <> ":" <> strEncode rKey strP = FileRecipient <$> strP <* A.char ':' <*> strP -newFileStore :: STM FileStore +newFileStore :: IO FileStore newFileStore = do - files <- TM.empty - recipients <- TM.empty - usedStorage <- newTVar 0 + files <- TM.emptyIO + recipients <- TM.emptyIO + usedStorage <- newTVarIO 0 pure FileStore {files, recipients, usedStorage} addFile :: FileStore -> SenderId -> FileInfo -> SystemTime -> STM (Either XFTPErrorType ()) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 32bfa4198..672375aaf 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -33,6 +33,7 @@ module Simplex.Messaging.Agent AgentClient (..), AE, SubscriptionsInfo (..), + MsgReq, getSMPAgentClient, getSMPAgentClient_, disconnectAgentClient, @@ -77,6 +78,7 @@ module Simplex.Messaging.Agent getConnectionServers, getConnectionRatchetAdHash, setProtocolServers, + checkUserServers, testProtocolServer, setNtfServers, setNetworkConfig, @@ -91,6 +93,7 @@ module Simplex.Messaging.Agent getNtfTokenData, toggleConnectionNtfs, xftpStartWorkers, + xftpStartSndWorkers, xftpReceiveFile, xftpDeleteRcvFile, xftpDeleteRcvFiles, @@ -104,6 +107,7 @@ module Simplex.Messaging.Agent rcConnectHost, rcConnectCtrl, rcDiscoverCtrl, + getAgentSubsTotal, getAgentServersSummary, resetAgentServersStats, foregroundAgent, @@ -145,7 +149,7 @@ import Data.Time.Clock import Data.Time.Clock.System (systemToUTCTime) import Data.Traversable (mapAccumL) import Data.Word (Word16) -import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFileRemote, deleteSndFilesInternal, deleteSndFilesRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile') +import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFileRemote, deleteSndFilesInternal, deleteSndFilesRemote, startXFTPSndWorkers, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile') import Simplex.FileTransfer.Description (ValidFileDescription) import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Types (RcvFileId, SndFileId) @@ -172,9 +176,8 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, Cmd (..), EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SParty (..), SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, VersionSMPC, XFTPServerWithAuth, sndAuthKeySMPClientVersion) +import Simplex.Messaging.Protocol (BrokerMsg, Cmd (..), EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolType (..), ProtocolTypeI (..), SMPMsgMeta, SParty (..), SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, VersionSMPC, sndAuthKeySMPClientVersion) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Server.QueueStore.QueueInfo import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion, THandleParams (sessionId)) @@ -198,26 +201,29 @@ getSMPAgentClient = getSMPAgentClient_ 1 {-# INLINE getSMPAgentClient #-} getSMPAgentClient_ :: Int -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Bool -> IO AgentClient -getSMPAgentClient_ clientId cfg initServers store backgroundMode = - liftIO $ newSMPAgentEnv cfg store >>= runReaderT runAgent +getSMPAgentClient_ clientId cfg initServers@InitialAgentServers {smp, xftp} store backgroundMode = + newSMPAgentEnv cfg store >>= runReaderT runAgent where runAgent = do + liftIO $ checkServers "SMP" smp >> checkServers "XFTP" xftp currentTs <- liftIO getCurrentTime - c@AgentClient {acThread} <- atomically . newAgentClient clientId initServers currentTs =<< ask + c@AgentClient {acThread} <- liftIO . newAgentClient clientId initServers currentTs =<< ask t <- runAgentThreads c `forkFinally` const (liftIO $ disconnectAgentClient c) atomically . writeTVar acThread . Just =<< mkWeakThreadId t pure c + checkServers protocol srvs = + forM_ (M.assocs srvs) $ \(userId, srvs') -> checkUserServers ("getSMPAgentClient " <> protocol <> " " <> tshow userId) srvs' runAgentThreads c | backgroundMode = run c "subscriber" $ subscriber c | otherwise = do - -- restoreServersStats c + restoreServersStats c raceAny_ [ run c "subscriber" $ subscriber c, run c "runNtfSupervisor" $ runNtfSupervisor c, - run c "cleanupManager" $ cleanupManager c - -- run c "logServersStats" $ logServersStats c + run c "cleanupManager" $ cleanupManager c, + run c "logServersStats" $ logServersStats c ] - -- `E.finally` saveServersStats c + `E.finally` saveServersStats c run AgentClient {subQ, acThread} name a = a `E.catchAny` \e -> whenM (isJust <$> readTVarIO acThread) $ do logError $ "Agent thread " <> name <> " crashed: " <> tshow e @@ -229,30 +235,30 @@ logServersStats c = do liftIO $ threadDelay' delay int <- asks (logStatsInterval . config) forever $ do + liftIO $ waitUntilActive c saveServersStats c liftIO $ threadDelay' int saveServersStats :: AgentClient -> AM' () -saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats} = do - -- sss <- mapM (lift . getAgentSMPServerStats) =<< readTVarIO smpServersStats - -- xss <- mapM (lift . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats - -- let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss} - -- tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case - -- Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) - -- Right () -> pure () - pure () +saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats, ntfServersStats} = do + sss <- mapM (liftIO . getAgentSMPServerStats) =<< readTVarIO smpServersStats + xss <- mapM (liftIO . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats + nss <- mapM (liftIO . getAgentNtfServerStats) =<< readTVarIO ntfServersStats + let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss} + tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case + Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) + Right () -> pure () restoreServersStats :: AgentClient -> AM' () -restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, srvStatsStartedAt} = do +restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt} = do tryAgentError' (withStore c getServersStats) >>= \case Left e -> atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) Right (startedAt, Nothing) -> atomically $ writeTVar srvStatsStartedAt startedAt - Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss}) -> do + Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}) -> do atomically $ writeTVar srvStatsStartedAt startedAt - sss' <- mapM (atomically . newAgentSMPServerStats') sss - atomically $ writeTVar smpServersStats sss' - xss' <- mapM (atomically . newAgentXFTPServerStats') xss - atomically $ writeTVar xftpServersStats xss' + atomically . writeTVar smpServersStats =<< mapM (atomically . newAgentSMPServerStats') sss + atomically . writeTVar xftpServersStats =<< mapM (atomically . newAgentXFTPServerStats') xss + atomically . writeTVar ntfServersStats =<< mapM (atomically . newAgentNtfServerStats') nss disconnectAgentClient :: AgentClient -> IO () disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns, xftpAgent = xa}} = do @@ -273,7 +279,7 @@ resumeAgentClient :: AgentClient -> IO () resumeAgentClient c = atomically $ writeTVar (active c) True {-# INLINE resumeAgentClient #-} -createUser :: AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> AE UserId +createUser :: AgentClient -> NonEmpty (ServerCfg 'PSMP) -> NonEmpty (ServerCfg 'PXFTP) -> AE UserId createUser c = withAgentEnv c .: createUser' c {-# INLINE createUser #-} @@ -336,7 +342,7 @@ prepareConnectionToJoin :: AgentClient -> UserId -> Bool -> ConnectionRequestUri prepareConnectionToJoin c userId enableNtfs = withAgentEnv c .: newConnToJoin c userId "" enableNtfs -- | Join SMP agent connection (JOIN command). -joinConnection :: AgentClient -> UserId -> Maybe ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId +joinConnection :: AgentClient -> UserId -> Maybe ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (ConnId, SndQueueSecured) joinConnection c userId Nothing enableNtfs = withAgentEnv c .:: joinConn c userId "" False enableNtfs joinConnection c userId (Just connId) enableNtfs = withAgentEnv c .:: joinConn c userId connId True enableNtfs {-# INLINE joinConnection #-} @@ -347,7 +353,7 @@ allowConnection c = withAgentEnv c .:. allowConnection' c {-# INLINE allowConnection #-} -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId +acceptContact :: AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (ConnId, SndQueueSecured) acceptContact c enableNtfs = withAgentEnv c .:: acceptContact' c "" enableNtfs {-# INLINE acceptContact #-} @@ -372,7 +378,7 @@ getConnectionMessage c = withAgentEnv c . getConnectionMessage' c {-# INLINE getConnectionMessage #-} -- | Get connection message for received notification -getNotificationMessage :: AgentClient -> C.CbNonce -> ByteString -> AE (NotificationInfo, [SMPMsgMeta]) +getNotificationMessage :: AgentClient -> C.CbNonce -> ByteString -> AE (NotificationInfo, Maybe SMPMsgMeta) getNotificationMessage c = withAgentEnv c .: getNotificationMessage' c {-# INLINE getNotificationMessage #-} @@ -389,6 +395,10 @@ sendMessage :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> A sendMessage c = withAgentEnv c .:: sendMessage' c {-# INLINE sendMessage #-} +-- When sending multiple messages to the same connection, +-- only the first MsgReq for this connection should have non-empty ConnId. +-- All subsequent MsgReq in traversable for this connection must be empty. +-- This is done to optimize processing by grouping all messages to one connection together. type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) @@ -404,7 +414,7 @@ ackMessage :: AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AE ackMessage c = withAgentEnv c .:. ackMessage' c {-# INLINE ackMessage #-} -getConnectionQueueInfo :: AgentClient -> ConnId -> AE QueueInfo +getConnectionQueueInfo :: AgentClient -> ConnId -> AE ServerQueueInfo getConnectionQueueInfo c = withAgentEnv c . getConnectionQueueInfo' c {-# INLINE getConnectionQueueInfo #-} @@ -520,6 +530,10 @@ xftpStartWorkers :: AgentClient -> Maybe FilePath -> AE () xftpStartWorkers c = withAgentEnv c . startXFTPWorkers c {-# INLINE xftpStartWorkers #-} +xftpStartSndWorkers :: AgentClient -> Maybe FilePath -> AE () +xftpStartSndWorkers c = withAgentEnv c . startXFTPSndWorkers c +{-# INLINE xftpStartSndWorkers #-} + -- | Receive XFTP file xftpReceiveFile :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> Bool -> AE RcvFileId xftpReceiveFile c = withAgentEnv c .:: xftpReceiveFile' c @@ -602,11 +616,13 @@ logConnection c connected = let event = if connected then "connected to" else "disconnected from" in logInfo $ T.unwords ["client", tshow (clientId c), event, "Agent"] -createUser' :: AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> AM UserId +createUser' :: AgentClient -> NonEmpty (ServerCfg 'PSMP) -> NonEmpty (ServerCfg 'PXFTP) -> AM UserId createUser' c smp xftp = do + liftIO $ checkUserServers "createUser SMP" smp + liftIO $ checkUserServers "createUser XFTP" xftp userId <- withStore' c createUserRecord - atomically $ TM.insert userId smp $ smpServers c - atomically $ TM.insert userId xftp $ xftpServers c + atomically $ TM.insert userId (mkUserServers smp) $ smpServers c + atomically $ TM.insert userId (mkUserServers xftp) $ xftpServers c pure userId deleteUser' :: AgentClient -> UserId -> Bool -> AM () @@ -739,7 +755,7 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv (SCMContact, CR.IKUsePQ) -> throwE $ CMD PROHIBITED "newRcvConnSrv" _ -> pure () AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - let sndSecure = False -- case cMode of SCMInvitation -> True; SCMContact -> False + let sndSecure = case cMode of SCMInvitation -> True; SCMContact -> False (rq, qUri, tSess, sessId) <- newRcvQueue c userId connId srvWithAuth smpClientVRange subMode sndSecure `catchAgentError` \e -> liftIO (print e) >> throwE e atomically $ incSMPServerStat c userId srv connCreated rq' <- withStore c $ \db -> updateNewConnRcv db connId rq @@ -774,7 +790,7 @@ newConnToJoin c userId connId enableNtfs cReq pqSup = case cReq of cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} withStore c $ \db -> createNewConn db g cData SCMInvitation -joinConn :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId +joinConn :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (ConnId, SndQueueSecured) joinConn c userId connId hasNewConn enableNtfs cReq cInfo pqSupport subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> @@ -833,7 +849,7 @@ versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ {-# INLINE versionPQSupport_ #-} -joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM ConnId +joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, SndQueueSecured) joinConnSrv c userId connId hasNewConn enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do (cData, q, _, rc, e2eSndParams) <- startJoinInvitation userId connId Nothing enableNtfs inv pqSup @@ -850,7 +866,7 @@ joinConnSrv c userId connId hasNewConn enableNtfs inv@CRInvitationUri {} cInfo p -- otherwise we would need to manage retries here to avoid SndQueue recreated with a different key, -- similar to how joinConnAsync does that. tryError (secureConfirmQueue c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case - Right _ -> pure connId' + Right sqSecured -> pure (connId', sqSecured) Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md void $ withStore' c $ \db -> deleteConn db Nothing connId' @@ -860,10 +876,10 @@ joinConnSrv c userId connId hasNewConn enableNtfs cReqUri@CRContactUri {} cInfo Just (qInfo, vrsn) -> do (connId', cReq) <- newConnSrv c userId connId hasNewConn enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv void $ sendInvitation c userId qInfo vrsn cReq cInfo - pure connId' + pure (connId', False) Nothing -> throwE $ AGENT A_VERSION -joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM () +joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM SndQueueSecured joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do SomeConn cType conn <- withStore c (`getConn` connId) case conn of @@ -871,7 +887,7 @@ joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSuppo SndConnection _ sq -> doJoin $ Just sq _ -> throwE $ CMD PROHIBITED $ "joinConnSrvAsync: bad connection " <> show cType where - doJoin :: Maybe SndQueue -> AM () + doJoin :: Maybe SndQueue -> AM SndQueueSecured doJoin sq_ = do (cData, sq, _, rc, e2eSndParams) <- startJoinInvitation userId connId sq_ enableNtfs inv pqSupport sq' <- withStore c $ \db -> runExceptT $ do @@ -883,8 +899,9 @@ joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode createReplyQueue :: AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM SMPQueueInfo createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} subMode srv = do - let sndSecure = False -- smpClientVersion >= sndAuthKeySMPClientVersion + let sndSecure = smpClientVersion >= sndAuthKeySMPClientVersion (rq, qUri, tSess, sessId) <- newRcvQueue c userId connId srv (versionToRange smpClientVersion) subMode sndSecure + atomically $ incSMPServerStat c userId (qServer rq) connCreated let qInfo = toVersionT qUri smpClientVersion rq' <- withStore c $ \db -> upgradeSndConnToDuplex db connId rq lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId @@ -897,18 +914,14 @@ createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVers allowConnection' :: AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> AM () allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConnection" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (RcvConnection _ rq@RcvQueue {server, rcvId, e2ePrivKey, smpClientVersion = v}) -> do - senderKey <- withStore c $ \db -> runExceptT $ do - AcceptedConfirmation {ratchetState, senderConf = SMPConfirmation {senderKey, e2ePubKey, smpClientVersion = v'}} <- ExceptT $ acceptConfirmation db confId ownConnInfo - liftIO $ createRatchet db connId ratchetState - let dhSecret = C.dh' e2ePubKey e2ePrivKey - liftIO $ setRcvQueueConfirmedE2E db rq dhSecret $ min v v' - pure senderKey + SomeConn _ (RcvConnection _ RcvQueue {server, rcvId}) -> do + AcceptedConfirmation {senderConf = SMPConfirmation {senderKey}} <- + withStore c $ \db -> acceptConfirmation db confId ownConnInfo enqueueCommand c "" connId (Just server) . AInternalCommand $ ICAllowSecure rcvId senderKey _ -> throwE $ CMD PROHIBITED "allowConnection" -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId +acceptContact' :: AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (ConnId, SndQueueSecured) acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case @@ -1026,7 +1039,7 @@ getConnectionMessage' c connId = do SndConnection _ _ -> throwE $ CONN SIMPLEX NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection" -getNotificationMessage' :: AgentClient -> C.CbNonce -> ByteString -> AM (NotificationInfo, [SMPMsgMeta]) +getNotificationMessage' :: AgentClient -> C.CbNonce -> ByteString -> AM (NotificationInfo, Maybe SMPMsgMeta) getNotificationMessage' c nonce encNtfInfo = do withStore' c getActiveNtfToken >>= \case Just NtfToken {ntfDhSecret = Just dhSecret} -> do @@ -1034,22 +1047,9 @@ getNotificationMessage' c nonce encNtfInfo = do PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} <- liftEither (parse strP (INTERNAL "error parsing PNMessageData") ntfData) (ntfConnId, rcvNtfDhSecret) <- withStore c (`getNtfRcvQueue` smpQueue) ntfMsgMeta <- (eitherToMaybe . smpDecode <$> agentCbDecrypt rcvNtfDhSecret nmsgNonce encNMsgMeta) `catchAgentError` \_ -> pure Nothing - maxMsgs <- asks $ ntfMaxMessages . config - (NotificationInfo {ntfConnId, ntfTs, ntfMsgMeta},) <$> getNtfMessages ntfConnId ntfMsgMeta maxMsgs + msgMeta <- getConnectionMessage' c ntfConnId + pure (NotificationInfo {ntfConnId, ntfTs, ntfMsgMeta}, msgMeta) _ -> throwE $ CMD PROHIBITED "getNotificationMessage" - where - getNtfMessages ntfConnId nMeta = getMsg - where - getMsg 0 = pure [] - getMsg n = - getConnectionMessage' c ntfConnId >>= \case - Just m - | lastMsg m -> pure [m] - | otherwise -> (m :) <$> getMsg (n - 1) - Nothing -> pure [] - lastMsg SMP.SMPMsgMeta {msgId, msgTs, msgFlags} = case nMeta of - Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} -> msgId == msgId' || msgTs > msgTs' - Nothing -> SMP.notification msgFlags -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption) @@ -1063,38 +1063,49 @@ sendMessages' c = sendMessagesB' c . map Right sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB' c reqs = do - connIds <- liftEither $ foldl' addConnId (Right S.empty) reqs + (_, connIds) <- liftEither $ foldl' addConnId (Right ("", S.empty)) reqs lift $ sendMessagesB_ c reqs connIds where - addConnId s@(Right s') (Right (connId, _, _, _)) - | B.null connId = s - | connId `S.notMember` s' = Right $ S.insert connId s' - | otherwise = Left $ INTERNAL "sendMessages: duplicate connection ID" - addConnId s _ = s + addConnId acc@(Right (prevId, s)) (Right (connId, _, _, _)) + | B.null connId = if B.null prevId then Left $ INTERNAL "sendMessages: empty first connId" else acc + | connId `S.member` s = Left $ INTERNAL "sendMessages: duplicate connId" + | otherwise = Right (connId, S.insert connId s) + addConnId acc _ = acc sendMessagesB_ :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> Set ConnId -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do - reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) + prev <- newTVarIO Nothing + reqs' <- withStoreBatch c $ \db -> fmap (bindRight $ getConn_ db prev) reqs let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' - void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable + void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) $ S.toList toEnable enqueueMessagesB c reqs'' where - prepareConn :: [ConnId] -> Either AgentErrorType (MsgReq, SomeConn) -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) - prepareConn acc (Left e) = (acc, Left e) - prepareConn acc (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of + getConn_ :: DB.Connection -> TVar (Maybe (Either AgentErrorType SomeConn)) -> MsgReq -> IO (Either AgentErrorType (MsgReq, SomeConn)) + getConn_ db prev req@(connId, _, _, _) = + (req,) <$$> + if B.null connId + then fromMaybe (Left $ INTERNAL "sendMessagesB_: empty prev connId") <$> readTVarIO prev + else do + conn <- first storeError <$> getConn db connId + conn <$ atomically (writeTVar prev $ Just conn) + prepareConn :: Set ConnId -> Either AgentErrorType (MsgReq, SomeConn) -> (Set ConnId, Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareConn s (Left e) = (s, Left e) + prepareConn s (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of DuplexConnection cData _ sqs -> prepareMsg cData sqs SndConnection cData sq -> prepareMsg cData [sq] - _ -> (acc, Left $ CONN SIMPLEX) + _ -> (s, Left $ CONN SIMPLEX) where - prepareMsg :: ConnData -> NonEmpty SndQueue -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareMsg :: ConnData -> NonEmpty SndQueue -> (Set ConnId, Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) prepareMsg cData@ConnData {connId, pqSupport} sqs - | ratchetSyncSendProhibited cData = (acc, Left $ CMD PROHIBITED "sendMessagesB: send prohibited") + | ratchetSyncSendProhibited cData = (s, Left $ CMD PROHIBITED "sendMessagesB: send prohibited") -- connection is only updated if PQ encryption was disabled, and now it has to be enabled. -- support for PQ encryption (small message envelopes) will not be disabled when message is sent. | pqEnc == PQEncOn && pqSupport == PQSupportOff = let cData' = cData {pqSupport = PQSupportOn} :: ConnData - in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg)) - | otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg)) + in (S.insert connId s, mkReq cData') + | otherwise = (s, mkReq cData) + where + mkReq cData' = Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg) -- / async command processing v v v @@ -1125,12 +1136,16 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do forever $ do atomically $ endAgentOperation c AOSndNetwork lift $ waitForWork doWork - atomically $ throwWhenInactive c + liftIO $ throwWhenInactive c atomically $ beginAgentOperation c AOSndNetwork - withWork c doWork (`getPendingServerCommand` server_) $ processCmd (riFast ri) + withWork c doWork (`getPendingServerCommand` server_) $ runProcessCmd (riFast ri) where - processCmd :: RetryInterval -> PendingCommand -> AM () - processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} = case command of + runProcessCmd ri cmd = do + pending <- newTVarIO [] + processCmd ri cmd pending + mapM_ (atomically . writeTBQueue subQ) . reverse =<< readTVarIO pending + processCmd :: RetryInterval -> PendingCommand -> TVar [ATransmission] -> AM () + processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} pendingCmds = case command of AClientCommand cmd -> case cmd of NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) @@ -1141,12 +1156,12 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do let initUsed = [qServer q] usedSrvs <- newTVarIO initUsed tryCommand . withNextSrv c userId usedSrvs initUsed $ \srv -> do - joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv - notify OK + sqSecured <- joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv + notify $ JOINED sqSecured LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK SWCH -> - noServer . tryCommand . withConnLock c connId "switchConnection" $ + noServer . tryWithLock "switchConnection" $ withStore c (`getConn` connId) >>= \case SomeConn _ conn@(DuplexConnection _ (replaced :| _rqs) _) -> switchDuplexConnection c conn replaced >>= notify . SWITCH QDRcv SPStarted @@ -1175,7 +1190,6 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do ICDeleteRcvQueue rId -> withServer $ \srv -> tryWithLock "ICDeleteRcvQueue" $ do rq <- withStore c (\db -> getDeletedRcvQueue db connId srv rId) deleteQueue c rq - atomically $ incSMPServerStat c userId srv connDeleted withStore' c (`deleteConnRcvQueue` rq) ICQSecure rId senderKey -> withServer $ \srv -> tryWithLock "ICQSecure" . withDuplexConn $ \(DuplexConnection cData rqs sqs) -> @@ -1239,7 +1253,9 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do withStore c (`getConn` connId) >>= \case SomeConn _ conn@DuplexConnection {} -> a conn _ -> internalErr "command requires duplex connection" - tryCommand action = withRetryInterval ri $ \_ loop -> + tryCommand action = withRetryInterval ri $ \_ loop -> do + liftIO $ waitWhileSuspended c + liftIO $ waitForUserNetwork c tryError action >>= \case Left e | temporaryOrHostError e -> retrySndOp c loop @@ -1249,7 +1265,9 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command) cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId) notify :: forall e. AEntityI e => AEvent e -> AM () - notify cmd = atomically $ writeTBQueue subQ (corrId, connId, AEvt (sAEntity @e) cmd) + notify cmd = + let t = (corrId, connId, AEvt (sAEntity @e) cmd) + in atomically $ ifM (isFullTBQueue subQ) (modifyTVar' pendingCmds (t :)) (writeTBQueue subQ t) -- ^ ^ ^ async command processing / enqueueMessages :: AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> AM (AgentMsgId, PQEncryption) @@ -1345,8 +1363,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI forever $ do atomically $ endAgentOperation c AOSndNetwork lift $ waitForWork doWork - atomically $ throwWhenInactive c - atomically $ throwWhenNoDelivery c sq + liftIO $ throwWhenInactive c + liftIO $ throwWhenNoDelivery c sq atomically $ beginAgentOperation c AOSndNetwork withWork c doWork (\db -> getPendingQueueMsg db connId sq) $ \(rq_, PendingMsgData {msgId, msgType, msgBody, pqEncryption, msgFlags, msgRetryState, internalTs}) -> do @@ -1354,6 +1372,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI let mId = unId msgId ri' = maybe id updateRetryInterval2 msgRetryState ri withRetryLock2 ri' qLock $ \riState loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c resp <- tryError $ case msgType of AM_CONN_INFO -> sendConfirmation c sq msgBody @@ -1425,7 +1444,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI withStore' c $ \db -> setSndQueueStatus db sq Active case rq_ of -- party initiating connection (in v1) - Just RcvQueue {status} -> + Just rq@RcvQueue {status} -> -- it is unclear why subscribeQueue was needed here, -- message delivery can only be enabled for queues that were created in the current session or subscribed -- subscribeQueue c rq connId @@ -1435,7 +1454,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI -- because it can be sent before HELLO is received -- With `status == Active` condition, CON is sent here only by the accepting party, that previously received HELLO when (status == Active) $ do - atomically $ incSMPServerStat c userId server connCompleted + atomically $ incSMPServerStat c userId (qServer rq) connCompleted notify $ CON pqEncryption -- this branch should never be reached as receive queue is created before the confirmation, _ -> logError "HELLO sent without receive queue" @@ -1506,7 +1525,7 @@ retrySndOp :: AgentClient -> AM () -> AM () retrySndOp c loop = do -- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent atomically $ endAgentOperation c AOSndNetwork - atomically $ throwWhenInactive c + liftIO $ throwWhenInactive c atomically $ beginAgentOperation c AOSndNetwork loop @@ -1543,7 +1562,7 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do withStore' c $ \db -> deleteDeliveredSndMsg db connId $ InternalId sndMsgId _ -> pure () -getConnectionQueueInfo' :: AgentClient -> ConnId -> AM QueueInfo +getConnectionQueueInfo' :: AgentClient -> ConnId -> AM ServerQueueInfo getConnectionQueueInfo' c connId = do SomeConn _ conn <- withStore c (`getConn` connId) case conn of @@ -1627,10 +1646,14 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni _ -> throwE $ CMD PROHIBITED "synchronizeRatchet: not duplex" ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM () -ackQueueMessage c rq srvMsgId = - sendAck c rq srvMsgId `catchAgentError` \case - SMP _ SMP.NO_MSG -> pure () - e -> throwE e +ackQueueMessage c rq@RcvQueue {userId, server} srvMsgId = do + atomically $ incSMPServerStat c userId server ackAttempts + tryAgentError (sendAck c rq srvMsgId) >>= \case + Right _ -> atomically $ incSMPServerStat c userId server ackMsgs + Left (SMP _ SMP.NO_MSG) -> atomically $ incSMPServerStat c userId server ackNoMsgErrs + Left e -> do + unless (temporaryOrHostError e) $ atomically $ incSMPServerStat c userId server ackOtherErrs + throwE e -- | Suspend SMP agent connection (OFF command) in Reader monad suspendConnection' :: AgentClient -> ConnId -> AM () @@ -1727,11 +1750,15 @@ deleteConnQueues c waitDelivery ntf rqs = do Int -> (RcvQueue, Either AgentErrorType ()) -> IO ((RcvQueue, Either AgentErrorType ()), Maybe (AM' ())) - deleteQueueRec db maxErrs (rq, r) = case r of + deleteQueueRec db maxErrs (rq@RcvQueue {userId, server}, r) = case r of Right _ -> deleteConnRcvQueue db rq $> ((rq, r), Just (notifyRQ rq Nothing)) Left e | temporaryOrHostError e && deleteErrors rq + 1 < maxErrs -> incRcvDeleteErrors db rq $> ((rq, r), Nothing) - | otherwise -> deleteConnRcvQueue db rq $> ((rq, Right ()), Just (notifyRQ rq (Just e))) + | otherwise -> do + deleteConnRcvQueue db rq + -- attempts and successes are counted in deleteQueues function + atomically $ incSMPServerStat c userId server connDeleted + pure ((rq, Right ()), Just (notifyRQ rq (Just e))) notifyRQ rq e_ = notify ("", qConnId rq, AEvt SAEConn $ DEL_RCVQ (qServer rq) (queueId rq) e_) notify = when ntf . atomically . writeTBQueue (subQ c) connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ()) @@ -1803,10 +1830,17 @@ connectionStats = \case ratchetSyncSupported = connAgentVersion >= ratchetSyncSMPAgentVersion } --- | Change servers to be used for creating new queues, in Reader monad -setProtocolServers :: (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> IO () -setProtocolServers c userId srvs = atomically $ TM.insert userId srvs (userServers c) -{-# INLINE setProtocolServers #-} +-- | Change servers to be used for creating new queues. +-- This function will set all servers as enabled in case all passed servers are disabled. +setProtocolServers :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> NonEmpty (ServerCfg p) -> IO () +setProtocolServers c userId srvs = do + checkUserServers "setProtocolServers" srvs + atomically $ TM.insert userId (mkUserServers srvs) (userServers c) + +checkUserServers :: Text -> NonEmpty (ServerCfg p) -> IO () +checkUserServers name srvs = + unless (any (\ServerCfg {enabled} -> enabled) srvs) $ + logWarn (name <> ": all passed servers are disabled, using all servers.") registerNtfToken' :: AgentClient -> DeviceToken -> NotificationsMode -> AM NtfTknStatus registerNtfToken' c suppliedDeviceToken suppliedNtfMode = @@ -1996,7 +2030,7 @@ deleteNtfSubs c deleteCmd = do sendNtfConnCommands :: AgentClient -> NtfSupervisorCommand -> AM () sendNtfConnCommands c cmd = do ns <- asks ntfSupervisor - connIds <- atomically $ getSubscriptions c + connIds <- liftIO $ getSubscriptions c forM_ connIds $ \connId -> do withStore' c (`getConnData` connId) >>= \case Just (ConnData {enableNtfs}, _) -> @@ -2009,10 +2043,12 @@ setNtfServers c = atomically . writeTVar (ntfServers c) {-# INLINE setNtfServers #-} resetAgentServersStats' :: AgentClient -> AM () -resetAgentServersStats' c@AgentClient {smpServersStats, xftpServersStats} = do +resetAgentServersStats' c@AgentClient {smpServersStats, xftpServersStats, srvStatsStartedAt} = do + startedAt <- liftIO getCurrentTime + atomically $ writeTVar srvStatsStartedAt startedAt atomically $ TM.clear smpServersStats atomically $ TM.clear xftpServersStats - withStore' c resetServersStats + withStore' c (`resetServersStats` startedAt) -- | Activate operations foregroundAgent :: AgentClient -> IO () @@ -2076,7 +2112,7 @@ cleanupManager c@AgentClient {subQ} = do liftIO $ threadDelay' delay int <- asks (cleanupInterval . config) ttl <- asks $ storedMsgDataTTL . config - forever $ do + forever $ waitActive $ do run ERR deleteConns run ERR $ withStore' c (`deleteRcvMsgHashesExpired` ttl) run ERR $ withStore' c (`deleteSndMsgsExpired` ttl) @@ -2096,7 +2132,8 @@ cleanupManager c@AgentClient {subQ} = do step <- asks $ cleanupStepInterval . config liftIO $ threadDelay step -- we are catching it to avoid CRITICAL errors in tests when this is the only remaining handle to active - waitActive a = liftIO (E.tryAny . atomically $ waitUntilActive c) >>= either (\_ -> pure ()) (\_ -> void a) + waitActive :: ReaderT Env IO a -> AM' () + waitActive a = liftIO (E.tryAny $ waitUntilActive c) >>= either (\_ -> pure ()) (\_ -> void a) deleteConns = withLock (deleteLock c) "cleanupManager" $ do void $ withStore' c getDeletedConnIds >>= deleteDeletedConns c @@ -2146,12 +2183,12 @@ data ACKd = ACKd | ACKPending -- It cannot be finally, as sometimes it needs to be ACK+DEL, -- and sometimes ACK has to be sent from the consumer. processSMPTransmissions :: AgentClient -> ServerTransmissionBatch SMPVersion ErrorType BrokerMsg -> AM' () -processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts) = do +processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId, ts) = do upConnIds <- newTVarIO [] forM_ ts $ \(entId, t) -> case t of STEvent msgOrErr -> withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of - Right msg -> processSMP rq conn (toConnData conn) msg + Right msg -> runProcessSMP rq conn (toConnData conn) msg Left e -> lift $ notifyErr connId e STResponse (Cmd SRecipient cmd) respOrErr -> withRcvConn entId $ \rq conn -> case cmd of @@ -2159,11 +2196,11 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts) Right SMP.OK -> processSubOk rq upConnIds Right msg@SMP.MSG {} -> do processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails - processSMP rq conn (toConnData conn) msg + runProcessSMP rq conn (toConnData conn) msg Right r -> processSubErr rq $ unexpectedResponse r Left e -> unless (temporaryClientError e) $ processSubErr rq e -- timeout/network was already reported SMP.ACK _ -> case respOrErr of - Right msg@SMP.MSG {} -> processSMP rq conn (toConnData conn) msg + Right msg@SMP.MSG {} -> runProcessSMP rq conn (toConnData conn) msg _ -> pure () -- TODO process OK response to ACK _ -> pure () -- TODO process expired response to DEL STResponse {} -> pure () -- TODO process expired responses to sent messages @@ -2171,7 +2208,9 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts) logServer "<--" c srv entId $ "error: " <> bshow e notifyErr "" e connIds <- readTVarIO upConnIds - unless (null connIds) $ notify' "" $ UP srv connIds + unless (null connIds) $ do + notify' "" $ UP srv connIds + atomically $ incSMPServerStat' c userId srv connSubscribed $ length connIds where withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' () withRcvConn rId a = do @@ -2182,27 +2221,35 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts) Left e -> notify' connId (ERR e) Right () -> pure () processSubOk :: RcvQueue -> TVar [ConnId] -> AM () - processSubOk rq@RcvQueue {userId, connId} upConnIds = do + processSubOk rq@RcvQueue {connId} upConnIds = atomically . whenM (isPendingSub connId) $ do - addSubscription c rq + addSubscription c sessId rq modifyTVar' upConnIds (connId :) - atomically $ incSMPServerStat c userId srv connSubscribed processSubErr :: RcvQueue -> SMPClientError -> AM () - processSubErr rq@RcvQueue {userId, connId} e = do - atomically . whenM (isPendingSub connId) $ failSubscription c rq e - atomically $ incSMPServerStat c userId srv connSubErrs + processSubErr rq@RcvQueue {connId} e = do + atomically . whenM (isPendingSub connId) $ + failSubscription c rq e >> incSMPServerStat c userId srv connSubErrs lift $ notifyErr connId e - isPendingSub connId = (&&) <$> hasPendingSubscription c connId <*> activeClientSession c tSess sessId + isPendingSub connId = do + pending <- (&&) <$> hasPendingSubscription c connId <*> activeClientSession c tSess sessId + unless pending $ incSMPServerStat c userId srv connSubIgnored + pure pending notify' :: forall e m. (AEntityI e, MonadIO m) => ConnId -> AEvent e -> m () notify' connId msg = atomically $ writeTBQueue subQ ("", connId, AEvt (sAEntity @e) msg) notifyErr :: ConnId -> SMPClientError -> AM' () notifyErr connId = notify' connId . ERR . protocolClientError SMP (B.unpack $ strEncode srv) - processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> BrokerMsg -> AM () + runProcessSMP :: RcvQueue -> Connection c -> ConnData -> BrokerMsg -> AM () + runProcessSMP rq conn cData msg = do + pending <- newTVarIO [] + processSMP rq conn cData msg pending + mapM_ (atomically . writeTBQueue subQ) . reverse =<< readTVarIO pending + processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> BrokerMsg -> TVar [ATransmission] -> AM () processSMP rq@RcvQueue {rcvId = rId, sndSecure, e2ePrivKey, e2eDhSecret, status} conn - cData@ConnData {userId, connId, connAgentVersion, ratchetSyncState = rss} - smpMsg = + cData@ConnData {connId, connAgentVersion, ratchetSyncState = rss} + smpMsg + pendingMsgs = withConnLock c connId "processSMP" $ case smpMsg of SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> do atomically $ incSMPServerStat c userId srv recvMsgs @@ -2211,7 +2258,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts) ack' <- handleNotifyAck $ case msg' of SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody SMP.ClientRcvMsgQuota {} -> queueDrained >> ack - whenM (atomically $ hasGetLock c rq) $ + whenM (liftIO $ hasGetLock c rq) $ notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') pure ack' where @@ -2383,7 +2430,9 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts) r -> unexpected r where notify :: forall e m. (AEntityI e, MonadIO m) => AEvent e -> m () - notify = notify' connId + notify msg = + let t = ("", connId, AEvt (sAEntity @e) msg) + in atomically $ ifM (isFullTBQueue subQ) (modifyTVar' pendingMsgs (t :)) (writeTBQueue subQ t) prohibited :: Text -> AM () prohibited s = do @@ -2448,6 +2497,18 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts) confId <- withStore c $ \db -> do setConnAgentVersion db connId agentVersion when (pqSupport /= pqSupport') $ setConnPQSupport db connId pqSupport' + -- / + -- Starting with agent version 7 (ratchetOnConfSMPAgentVersion), + -- initiating party initializes ratchet on processing confirmation; + -- previously, it initialized ratchet on allowConnection; + -- this is to support decryption of messages that may be received before allowConnection + liftIO $ do + createRatchet db connId rc' + let RcvQueue {smpClientVersion = v, e2ePrivKey = e2ePrivKey'} = rq + SMPConfirmation {smpClientVersion = v', e2ePubKey = e2ePubKey'} = senderConf + dhSecret = C.dh' e2ePubKey' e2ePrivKey' + setRcvQueueConfirmedE2E db rq dhSecret $ min v v' + -- / createConfirmation db g newConfirmation let srvs = map qServer $ smpReplyQueues senderConf notify $ CONF confId pqSupport' srvs connInfo @@ -2731,25 +2792,27 @@ connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo sq_ (qInfo :| _ Just qInfo' -> do -- in case of SKEY retry the connection is already duplex sq' <- maybe upgradeConn pure sq_ - agentSecureSndQueue c sq' + void $ agentSecureSndQueue c cData sq' enqueueConfirmation c cData sq' ownConnInfo Nothing where upgradeConn = do (sq, _) <- lift $ newSndQueue userId connId qInfo' withStore c $ \db -> upgradeRcvConnToDuplex db connId sq -secureConfirmQueueAsync :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () +secureConfirmQueueAsync :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured secureConfirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do - agentSecureSndQueue c sq + sqSecured <- agentSecureSndQueue c cData sq storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode lift $ submitPendingMsg c cData sq + pure sqSecured -secureConfirmQueue :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () +secureConfirmQueue :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured secureConfirmQueue c cData@ConnData {connId, connAgentVersion, pqSupport} sq srv connInfo e2eEncryption_ subMode = do - agentSecureSndQueue c sq + sqSecured <- agentSecureSndQueue c cData sq msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode void $ sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed + pure sqSecured where mkConfirmation :: AgentMessage -> AM MsgBody mkConfirmation aMessage = do @@ -2762,11 +2825,17 @@ secureConfirmQueue c cData@ConnData {connId, connAgentVersion, pqSupport} sq srv (encConnInfo, _) <- agentRatchetEncrypt db cData agentMsgBody e2eEncConnInfoLength (Just pqEnc) currentE2EVersion pure . smpEncode $ AgentConfirmation {agentVersion = connAgentVersion, e2eEncryption_, encConnInfo} -agentSecureSndQueue :: AgentClient -> SndQueue -> AM () -agentSecureSndQueue c sq@SndQueue {sndSecure, status} = - when (sndSecure && status == New) $ do - secureSndQueue c sq - withStore' c $ \db -> setSndQueueStatus db sq Secured +agentSecureSndQueue :: AgentClient -> ConnData -> SndQueue -> AM SndQueueSecured +agentSecureSndQueue c ConnData {connAgentVersion} sq@SndQueue {sndSecure, status} + | sndSecure && status == New = do + secureSndQueue c sq + withStore' c $ \db -> setSndQueueStatus db sq Secured + pure initiatorRatchetOnConf + -- on repeat JOIN processing (e.g. previous attempt to create reply queue failed) + | sndSecure && status == Secured = pure initiatorRatchetOnConf + | otherwise = pure False + where + initiatorRatchetOnConf = connAgentVersion >= ratchetOnConfSMPAgentVersion mkAgentConfirmation :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM AgentMessage mkAgentConfirmation c cData sq srv connInfo subMode = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 01d97f9ac..23f0a98d1 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -89,9 +89,11 @@ module Simplex.Messaging.Agent.Client activeClientSession, agentClientStore, agentDRG, + ServerQueueInfo (..), AgentServersSummary (..), ServerSessions (..), SMPServerSubs (..), + getAgentSubsTotal, getAgentServersSummary, getAgentSubscriptions, slowNetworkConfig, @@ -116,7 +118,7 @@ module Simplex.Messaging.Agent.Client waitUntilActive, UserNetworkInfo (..), UserNetworkType (..), - getNetworkConfig', + getFastNetworkConfig, waitForUserNetwork, isNetworkOnline, isOnline, @@ -125,6 +127,7 @@ module Simplex.Messaging.Agent.Client beginAgentOperation, endAgentOperation, waitUntilForeground, + waitWhileSuspended, suspendSendingAndDatabase, suspendOperation, notifySuspended, @@ -142,6 +145,9 @@ module Simplex.Messaging.Agent.Client incSMPServerStat, incSMPServerStat', incXFTPServerStat, + incXFTPServerStat', + incXFTPServerSizeStat, + incNtfServerStat, AgentWorkersDetails (..), getAgentWorkersDetails, AgentWorkersSummary (..), @@ -159,7 +165,7 @@ where import Control.Applicative ((<|>)) import Control.Concurrent (ThreadId, forkIO) import Control.Concurrent.Async (Async, uninterruptibleCancel) -import Control.Concurrent.STM (retry, throwSTM) +import Control.Concurrent.STM (retry) import Control.Exception (AsyncException (..), BlockedIndefinitelyOnSTM (..)) import Control.Logger.Simple import Control.Monad @@ -171,11 +177,12 @@ import Crypto.Random (ChaChaDRG) import qualified Data.Aeson as J import qualified Data.Aeson.TH as J import Data.Bifunctor (bimap, first, second) -import Data.ByteString.Base64 +import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Either (partitionEithers) +import Data.Either (isRight, partitionEithers) import Data.Functor (($>)) +import Data.Int (Int64) import Data.List (deleteFirstsBy, foldl', partition, (\\)) import Data.List.NonEmpty (NonEmpty (..), (<|)) import qualified Data.List.NonEmpty as L @@ -232,6 +239,7 @@ import Simplex.Messaging.Protocol ProtoServerWithAuth (..), Protocol (..), ProtocolServer (..), + ProtocolType (..), ProtocolTypeI (..), QueueId, QueueIdsKeys (..), @@ -285,7 +293,7 @@ data AgentClient = AgentClient active :: TVar Bool, subQ :: TBQueue ATransmission, msgQ :: TBQueue (ServerTransmissionBatch SMPVersion ErrorType BrokerMsg), - smpServers :: TMap UserId (NonEmpty SMPServerWithAuth), + smpServers :: TMap UserId (UserServers 'PSMP), smpClients :: TMap SMPTransportSession SMPClientVar, -- smpProxiedRelays: -- SMPTransportSession defines connection from proxy to relay, @@ -293,14 +301,14 @@ data AgentClient = AgentClient smpProxiedRelays :: TMap SMPTransportSession SMPServerWithAuth, ntfServers :: TVar [NtfServer], ntfClients :: TMap NtfTransportSession NtfClientVar, - xftpServers :: TMap UserId (NonEmpty XFTPServerWithAuth), + xftpServers :: TMap UserId (UserServers 'PXFTP), xftpClients :: TMap XFTPTransportSession XFTPClientVar, useNetworkConfig :: TVar (NetworkConfig, NetworkConfig), -- (slow, fast) networks userNetworkInfo :: TVar UserNetworkInfo, userNetworkUpdated :: TVar (Maybe UTCTime), subscrConns :: TVar (Set ConnId), - activeSubs :: TRcvQueues, - pendingSubs :: TRcvQueues, + activeSubs :: TRcvQueues (SessionId, RcvQueue), + pendingSubs :: TRcvQueues RcvQueue, removedSubs :: TMap (UserId, SMPServer, SMP.RecipientId) SMPClientError, workerSeq :: TVar Int, smpDeliveryWorkers :: TMap SndQAddr (Worker, TMVar ()), @@ -325,6 +333,7 @@ data AgentClient = AgentClient agentEnv :: Env, smpServersStats :: TMap (UserId, SMPServer) AgentSMPServerStats, xftpServersStats :: TMap (UserId, XFTPServer) AgentXFTPServerStats, + ntfServersStats :: TMap (UserId, NtfServer) AgentNtfServerStats, srvStatsStartedAt :: TVar UTCTime } @@ -363,13 +372,15 @@ getAgentWorker' toW fromW name hasWork c key ws work = do restart <- atomically $ getWorker >>= maybe (pure False) (shouldRestart e_ (toW w) t maxRestarts) when restart runWork shouldRestart e_ Worker {workerId = wId, doWork, action, restarts} t maxRestarts w' - | wId == workerId (toW w') = - checkRestarts . updateRestartCount t =<< readTVar restarts + | wId == workerId (toW w') = do + rc <- readTVar restarts + isActive <- readTVar $ active c + checkRestarts isActive $ updateRestartCount t rc | otherwise = pure False -- there is a new worker in the map, no action where - checkRestarts rc - | restartCount rc < maxRestarts = do + checkRestarts isActive rc + | isActive && restartCount rc < maxRestarts = do writeTVar restarts rc hasWorkToDo' doWork void $ tryPutTMVar action Nothing @@ -377,7 +388,7 @@ getAgentWorker' toW fromW name hasWork c key ws work = do pure True | otherwise = do TM.delete key ws - notifyErr $ CRITICAL True + when isActive $ notifyErr $ CRITICAL True pure False where notifyErr err = do @@ -444,46 +455,47 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther deriving (Eq, Show) -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. -newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Env -> STM AgentClient +newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Env -> IO AgentClient newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} currentTs agentEnv = do let cfg = config agentEnv qSize = tbqSize cfg - acThread <- newTVar Nothing - active <- newTVar True - subQ <- newTBQueue qSize - msgQ <- newTBQueue qSize - smpServers <- newTVar smp - smpClients <- TM.empty - smpProxiedRelays <- TM.empty - ntfServers <- newTVar ntf - ntfClients <- TM.empty - xftpServers <- newTVar xftp - xftpClients <- TM.empty - useNetworkConfig <- newTVar (slowNetworkConfig netCfg, netCfg) - userNetworkInfo <- newTVar $ UserNetworkInfo UNOther True - userNetworkUpdated <- newTVar Nothing - subscrConns <- newTVar S.empty + acThread <- newTVarIO Nothing + active <- newTVarIO True + subQ <- newTBQueueIO qSize + msgQ <- newTBQueueIO qSize + smpServers <- newTVarIO $ M.map mkUserServers smp + smpClients <- TM.emptyIO + smpProxiedRelays <- TM.emptyIO + ntfServers <- newTVarIO ntf + ntfClients <- TM.emptyIO + xftpServers <- newTVarIO $ M.map mkUserServers xftp + xftpClients <- TM.emptyIO + useNetworkConfig <- newTVarIO (slowNetworkConfig netCfg, netCfg) + userNetworkInfo <- newTVarIO $ UserNetworkInfo UNOther True + userNetworkUpdated <- newTVarIO Nothing + subscrConns <- newTVarIO S.empty activeSubs <- RQ.empty pendingSubs <- RQ.empty - removedSubs <- TM.empty - workerSeq <- newTVar 0 - smpDeliveryWorkers <- TM.empty - asyncCmdWorkers <- TM.empty - connCmdsQueued <- TM.empty - ntfNetworkOp <- newTVar $ AgentOpState False 0 - rcvNetworkOp <- newTVar $ AgentOpState False 0 - msgDeliveryOp <- newTVar $ AgentOpState False 0 - sndNetworkOp <- newTVar $ AgentOpState False 0 - databaseOp <- newTVar $ AgentOpState False 0 - agentState <- newTVar ASForeground - getMsgLocks <- TM.empty - connLocks <- TM.empty - invLocks <- TM.empty - deleteLock <- createLock - smpSubWorkers <- TM.empty - smpServersStats <- TM.empty - xftpServersStats <- TM.empty - srvStatsStartedAt <- newTVar currentTs + removedSubs <- TM.emptyIO + workerSeq <- newTVarIO 0 + smpDeliveryWorkers <- TM.emptyIO + asyncCmdWorkers <- TM.emptyIO + connCmdsQueued <- TM.emptyIO + ntfNetworkOp <- newTVarIO $ AgentOpState False 0 + rcvNetworkOp <- newTVarIO $ AgentOpState False 0 + msgDeliveryOp <- newTVarIO $ AgentOpState False 0 + sndNetworkOp <- newTVarIO $ AgentOpState False 0 + databaseOp <- newTVarIO $ AgentOpState False 0 + agentState <- newTVarIO ASForeground + getMsgLocks <- TM.emptyIO + connLocks <- TM.emptyIO + invLocks <- TM.emptyIO + deleteLock <- atomically createLock + smpSubWorkers <- TM.emptyIO + smpServersStats <- TM.emptyIO + xftpServersStats <- TM.emptyIO + ntfServersStats <- TM.emptyIO + srvStatsStartedAt <- newTVarIO currentTs return AgentClient { acThread, @@ -523,6 +535,7 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} currentTs a agentEnv, smpServersStats, xftpServersStats, + ntfServersStats, srvStatsStartedAt } @@ -589,13 +602,13 @@ getSMPServerClient c@AgentClient {active, smpClients, workerSeq} tSess = do >>= either newClient (waitForProtocolClient c tSess smpClients) where newClient v = do - prs <- atomically TM.empty + prs <- liftIO TM.emptyIO smpConnectClient c tSess prs v -getSMPProxyClient :: AgentClient -> SMPTransportSession -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) -getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq} destSess@(userId, destSrv, qId) = do +getSMPProxyClient :: AgentClient -> Maybe SMPServerWithAuth -> SMPTransportSession -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) +getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq} proxySrv_ destSess@(userId, destSrv, qId) = do unlessM (readTVarIO active) $ throwE INACTIVE - proxySrv <- getNextServer c userId [destSrv] + proxySrv <- maybe (getNextServer c userId [destSrv]) pure proxySrv_ ts <- liftIO getCurrentTime atomically (getClientVar proxySrv ts) >>= \(tSess, auth, v) -> either (newProxyClient tSess auth ts) (waitForProxyClient tSess auth) v @@ -607,11 +620,10 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq (tSess,auth,) <$> getSessVar workerSeq tSess smpClients ts newProxyClient :: SMPTransportSession -> Maybe SMP.BasicAuth -> UTCTime -> SMPClientVar -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) newProxyClient tSess auth ts v = do - (prs, rv) <- atomically $ do - prs <- TM.empty - -- we do not need to check if it is a new proxied relay session, - -- as the client is just created and there are no sessions yet - (prs,) . either id id <$> getSessVar workerSeq destSrv prs ts + prs <- liftIO TM.emptyIO + -- we do not need to check if it is a new proxied relay session, + -- as the client is just created and there are no sessions yet + rv <- atomically $ either id id <$> getSessVar workerSeq destSrv prs ts clnt <- smpConnectClient c tSess prs v (clnt,) <$> newProxiedRelay clnt auth rv waitForProxyClient :: SMPTransportSession -> Maybe SMP.BasicAuth -> SMPClientVar -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) @@ -637,7 +649,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq pure $ Left e waitForProxiedRelay :: SMPTransportSession -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay) waitForProxiedRelay (_, srv, _) rv = do - NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c + NetworkConfig {tcpConnectTimeout} <- getNetworkConfig c sess_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar rv) pure $ case sess_ of Just (Right sess) -> Right sess @@ -667,11 +679,13 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess -- because we can have a race condition when a new current client could have already -- made subscriptions active, and the old client would be processing diconnection later. removeClientAndSubs :: IO ([RcvQueue], [ConnId]) - removeClientAndSubs = atomically $ ifM currentActiveClient removeSubs $ pure ([], []) + removeClientAndSubs = atomically $ do + removeSessVar v tSess smpClients + ifM (readTVar active) removeSubs (pure ([], [])) where - currentActiveClient = (&&) <$> removeSessVar' v tSess smpClients <*> readTVar active + sessId = sessionId $ thParams client removeSubs = do - (qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c + (qs, cs) <- RQ.getDelSessQueues tSess sessId $ activeSubs c RQ.batchAddQueues (pendingSubs c) qs -- this removes proxied relays that this client created sessions to destSrvs <- M.keys <$> readTVar prs @@ -696,7 +710,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do where getWorkerVar ts = ifM - (null <$> getPending) + (not <$> RQ.hasSessQueues tSess (pendingSubs c)) (pure Nothing) -- prevent race with cleanup and adding pending queues in another call (Just <$> getSessVar workerSeq tSess smpSubWorkers ts) newSubWorker v = do @@ -704,13 +718,14 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do atomically $ putTMVar (sessionVar v) a runSubWorker = do ri <- asks $ reconnectInterval . config - withRetryInterval ri $ \_ loop -> do - pending <- atomically getPending + withRetryForeground ri isForeground (isNetworkOnline c) $ \_ loop -> do + pending <- liftIO $ RQ.getSessQueues tSess $ pendingSubs c forM_ (L.nonEmpty pending) $ \qs -> do + liftIO $ waitUntilForeground c liftIO $ waitForUserNetwork c reconnectSMPClient c tSess qs loop - getPending = RQ.getSessQueues tSess $ pendingSubs c + isForeground = (ASForeground ==) <$> readTVar (agentState c) cleanup :: SessionVar (Async ()) -> STM () cleanup v = do -- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar. @@ -775,7 +790,7 @@ getXFTPServerClient c@AgentClient {active, xftpClients, workerSeq} tSess@(_, srv connectClient :: XFTPClientVar -> AM XFTPClient connectClient v = do cfg <- asks $ xftpCfg . config - xftpNetworkConfig <- atomically $ getNetworkConfig c + xftpNetworkConfig <- getNetworkConfig c liftError' (protocolClientError XFTP $ B.unpack $ strEncode srv) $ X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v @@ -794,7 +809,7 @@ waitForProtocolClient :: ClientVar msg -> AM (Client msg) waitForProtocolClient c tSess@(_, srv, _) clients v = do - NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c + NetworkConfig {tcpConnectTimeout} <- getNetworkConfig c client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) case client_ of Just (Right smpClient) -> pure smpClient @@ -845,26 +860,26 @@ hostEvent' event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clie getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v) getClientConfig c cfgSel = do cfg <- asks $ cfgSel . config - networkConfig <- atomically $ getNetworkConfig c + networkConfig <- getNetworkConfig c pure cfg {networkConfig} -getNetworkConfig :: AgentClient -> STM NetworkConfig +getNetworkConfig :: MonadIO m => AgentClient -> m NetworkConfig getNetworkConfig c = do - (slowCfg, fastCfg) <- readTVar (useNetworkConfig c) - UserNetworkInfo {networkType} <- readTVar $ userNetworkInfo c + (slowCfg, fastCfg) <- readTVarIO $ useNetworkConfig c + UserNetworkInfo {networkType} <- readTVarIO $ userNetworkInfo c pure $ case networkType of UNCellular -> slowCfg UNNone -> slowCfg _ -> fastCfg -- returns fast network config -getNetworkConfig' :: AgentClient -> IO NetworkConfig -getNetworkConfig' = fmap snd . readTVarIO . useNetworkConfig -{-# INLINE getNetworkConfig' #-} +getFastNetworkConfig :: AgentClient -> IO NetworkConfig +getFastNetworkConfig = fmap snd . readTVarIO . useNetworkConfig +{-# INLINE getFastNetworkConfig #-} waitForUserNetwork :: AgentClient -> IO () waitForUserNetwork c = - unlessM (atomically $ isNetworkOnline c) $ do + unlessM (isOnline <$> readTVarIO (userNetworkInfo c)) $ do delay <- registerDelay $ userNetworkInterval $ config $ agentEnv c atomically $ unlessM (isNetworkOnline c) $ unlessM (readTVar delay) retry @@ -896,19 +911,18 @@ cancelWorker Worker {doWork, action} = do noWorkToDo doWork atomically (tryTakeTMVar action) >>= mapM_ (mapM_ uninterruptibleCancel) -waitUntilActive :: AgentClient -> STM () -waitUntilActive c = unlessM (readTVar $ active c) retry -{-# INLINE waitUntilActive #-} +waitUntilActive :: AgentClient -> IO () +waitUntilActive AgentClient {active} = unlessM (readTVarIO active) $ atomically $ unlessM (readTVar active) retry -throwWhenInactive :: AgentClient -> STM () -throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled +throwWhenInactive :: AgentClient -> IO () +throwWhenInactive c = unlessM (readTVarIO $ active c) $ E.throwIO ThreadKilled {-# INLINE throwWhenInactive #-} -- this function is used to remove workers once delivery is complete, not when it is removed from the map -throwWhenNoDelivery :: AgentClient -> SndQueue -> STM () +throwWhenNoDelivery :: AgentClient -> SndQueue -> IO () throwWhenNoDelivery c sq = - unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $ - throwSTM ThreadKilled + unlessM (TM.memberIO (qAddress sq) $ smpDeliveryWorkers c) $ + E.throwIO ThreadKilled closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = @@ -934,7 +948,7 @@ closeClient c clientSel tSess = closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO () closeClient_ c v = do - NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c + NetworkConfig {tcpConnectTimeout} <- getNetworkConfig c E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case Just (Right client) -> closeProtocolServerClient (protocolClient client) `catchAll_` pure () @@ -988,9 +1002,9 @@ withClient_ c tSess@(_, srv, _) action = do logServer "<--" c srv "" $ bshow e throwE e -withProxySession :: AgentClient -> SMPTransportSession -> SMP.SenderId -> ByteString -> ((SMPConnectedClient, ProxiedRelay) -> AM a) -> AM a -withProxySession c destSess@(_, destSrv, _) entId cmdStr action = do - (cl, sess_) <- getSMPProxyClient c destSess +withProxySession :: AgentClient -> Maybe SMPServerWithAuth -> SMPTransportSession -> SMP.SenderId -> ByteString -> ((SMPConnectedClient, ProxiedRelay) -> AM a) -> AM a +withProxySession c proxySrv_ destSess@(_, destSrv, _) entId cmdStr action = do + (cl, sess_) <- getSMPProxyClient c proxySrv_ destSess logServer ("--> " <> proxySrv cl <> " >") c destSrv entId cmdStr case sess_ of Right sess -> do @@ -1022,7 +1036,7 @@ withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr withSMPClient :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> AM a withSMPClient c q cmdStr action = do - tSess <- liftIO $ mkSMPTransportSession c q + tSess <- mkSMPTransportSession c q withLogClient c tSess (queueId q) cmdStr $ action . connectedClient sendOrProxySMPMessage :: AgentClient -> UserId -> SMPServer -> ByteString -> Maybe SMP.SndPrivateAuthKey -> SMP.SenderId -> MsgFlags -> SMP.MsgBody -> AM (Maybe SMPServer) @@ -1047,8 +1061,8 @@ sendOrProxySMPCommand :: (SMPClient -> ExceptT SMPClientError IO ()) -> AM (Maybe SMPServer) sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDirectly = do - sess <- liftIO $ mkTransportSession c userId destSrv senderId - ifM (atomically shouldUseProxy) (sendViaProxy sess) (sendDirectly sess $> Nothing) + sess <- mkTransportSession c userId destSrv senderId + ifM shouldUseProxy (sendViaProxy Nothing sess) (sendDirectly sess $> Nothing) where shouldUseProxy = do cfg <- getNetworkConfig c @@ -1065,23 +1079,32 @@ sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDi SPFAllow -> True SPFAllowProtected -> ipAddressProtected cfg destSrv SPFProhibit -> False - unknownServer = maybe True (all ((destSrv /=) . protoServer)) <$> TM.lookup userId (userServers c) - sendViaProxy destSess@(_, _, qId) = do - r <- tryAgentError . withProxySession c destSess senderId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess) -> do + unknownServer = liftIO $ maybe True (notElem destSrv . knownSrvs) <$> TM.lookupIO userId (smpServers c) + sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer) + sendViaProxy proxySrv_ destSess@(_, _, qId) = do + r <- tryAgentError . withProxySession c proxySrv_ destSess senderId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do r' <- liftClient SMP (clientServer smp) $ sendCmdViaProxy smp proxySess + let proxySrv = protocolClientServer' smp case r' of - Right () -> pure . Just $ protocolClientServer' smp + Right () -> pure $ Just proxySrv Left proxyErr -> do case proxyErr of - (ProxyProtocolError (SMP.PROXY SMP.NO_SESSION)) -> atomically deleteRelaySession - _ -> pure () - throwE - PROXY - { proxyServer = protocolClientServer smp, - relayServer = B.unpack $ strEncode destSrv, - proxyErr - } + ProxyProtocolError (SMP.PROXY SMP.NO_SESSION) -> do + atomically deleteRelaySession + case proxySrv_ of + Just _ -> proxyError + -- sendViaProxy is called recursively here to re-create the session via the same server + -- to avoid failure in interactive calls that don't retry after the session disconnection. + Nothing -> sendViaProxy (Just $ ProtoServerWithAuth proxySrv prBasicAuth) destSess + _ -> proxyError where + proxyError = + throwE + PROXY + { proxyServer = protocolClientServer smp, + relayServer = B.unpack $ strEncode destSrv, + proxyErr + } -- checks that the current proxied relay session is the same one that was used to send the message and removes it deleteRelaySession = ( TM.lookup destSess (smpProxiedRelays c) @@ -1102,7 +1125,7 @@ sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDi forM_ r' $ \proxySrv -> atomically $ incSMPServerStat c userId proxySrv sentProxied pure r' Left e - | serverHostError e -> ifM (atomically directAllowed) (sendDirectly destSess $> Nothing) (throwE e) + | serverHostError e -> ifM directAllowed (sendDirectly destSess $> Nothing) (throwE e) | otherwise -> throwE e sendDirectly tSess = withLogClient_ c tSess senderId ("SEND " <> cmdStr) $ \(SMPConnectedClient smp _) -> do @@ -1128,7 +1151,7 @@ withXFTPClient :: (Client msg -> ExceptT (ProtocolClientError err) IO b) -> AM b withXFTPClient c (userId, srv, entityId) cmdStr action = do - tSess <- liftIO $ mkTransportSession c userId srv entityId + tSess <- mkTransportSession c userId srv entityId withLogClient c tSess entityId cmdStr action liftClient :: (Show err, Encoding err) => (HostName -> err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> AM a @@ -1200,7 +1223,7 @@ runXFTPServerTest :: AgentClient -> UserId -> XFTPServerWithAuth -> AM' (Maybe P runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do cfg <- asks $ xftpCfg . config g <- asks random - xftpNetworkConfig <- atomically $ getNetworkConfig c + xftpNetworkConfig <- getNetworkConfig c workDir <- getXFTPWorkPath filePath <- getTempFilePath workDir rcvPath <- getTempFilePath workDir @@ -1271,7 +1294,7 @@ getXFTPWorkPath = do workDir <- readTVarIO =<< asks (xftpWorkDir . xftpAgent) maybe getTemporaryDirectory pure workDir -mkTransportSession :: AgentClient -> UserId -> ProtoServer msg -> EntityId -> IO (TransportSession msg) +mkTransportSession :: MonadIO m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg) mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c {-# INLINE mkTransportSession #-} @@ -1279,7 +1302,7 @@ mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> T mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing) {-# INLINE mkTSession #-} -mkSMPTransportSession :: SMPQueueRec q => AgentClient -> q -> IO SMPTransportSession +mkSMPTransportSession :: (SMPQueueRec q, MonadIO m) => AgentClient -> q -> m SMPTransportSession mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c {-# INLINE mkSMPTransportSession #-} @@ -1287,8 +1310,8 @@ mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSessi mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) {-# INLINE mkSMPTSession #-} -getSessionMode :: AgentClient -> IO TransportSessionMode -getSessionMode = atomically . fmap sessionMode . getNetworkConfig +getSessionMode :: MonadIO m => AgentClient -> m TransportSessionMode +getSessionMode = fmap sessionMode . getNetworkConfig {-# INLINE getSessionMode #-} newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> SenderCanSecure -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId) @@ -1299,7 +1322,7 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode sender (dhKey, privDhKey) <- atomically $ C.generateKeyPair g (e2eDhKey, e2ePrivKey) <- atomically $ C.generateKeyPair g logServer "-->" c srv "" "NEW" - tSess <- liftIO $ mkTransportSession c userId srv connId + tSess <- mkTransportSession c userId srv connId (sessId, QIK {rcvId, sndId, rcvPublicDhKey, sndSecure}) <- withClient c tSess $ \(SMPConnectedClient smp _) -> (sessionId $ thParams smp,) <$> createSMPQueue smp rKeys dhKey auth subMode senderCanSecure @@ -1328,14 +1351,17 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode sender qUri = SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey sndSecure pure (rq, qUri, tSess, sessId) -processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> STM () -processSubResult c rq@RcvQueue {connId} = \case +processSubResult :: AgentClient -> SessionId -> RcvQueue -> Either SMPClientError () -> STM () +processSubResult c sessId rq@RcvQueue {userId, server, connId} = \case Left e -> - unless (temporaryClientError e) $ + unless (temporaryClientError e) $ do + incSMPServerStat c userId server connSubErrs failSubscription c rq e Right () -> - whenM (hasPendingSubscription c connId) $ - addSubscription c rq + ifM + (hasPendingSubscription c connId) + (incSMPServerStat c userId server connSubscribed >> addSubscription c sessId rq) + (incSMPServerStat c userId server connSubIgnored) temporaryAgentError :: AgentErrorType -> Bool temporaryAgentError = \case @@ -1382,19 +1408,19 @@ subscribeQueues c qs = do (errs <> rs,) <$> readTVarIO session where checkQueue rq = do - prohibited <- atomically $ hasGetLock c rq + prohibited <- liftIO $ hasGetLock c rq pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED "subscribeQueues") else Right rq subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ()) subscribeQueues_ env session smp qs' = do let (userId, srv, _) = transportSession' smp - atomically $ incSMPServerStat' c userId srv connSubAttempts (length qs') + atomically $ incSMPServerStat' c userId srv connSubAttempts $ length qs' rs <- sendBatch subscribeSMPQueues smp qs' active <- atomically $ ifM (activeClientSession c tSess sessId) (writeTVar session (Just sessId) >> processSubResults rs $> True) - (pure False) + (incSMPServerStat' c userId srv connSubIgnored (length rs) $> False) if active then when (hasTempErrors rs) resubscribe $> rs else do @@ -1405,7 +1431,7 @@ subscribeQueues c qs = do sessId = sessionId $ thParams smp hasTempErrors = any (either temporaryClientError (const False) . snd) processSubResults :: NonEmpty (RcvQueue, Either SMPClientError ()) -> STM () - processSubResults = mapM_ $ uncurry $ processSubResult c + processSubResults = mapM_ $ uncurry $ processSubResult c sessId resubscribe = resubscribeSMPSession c tSess `runReaderT` env activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool @@ -1423,7 +1449,7 @@ sendTSessionBatches statCmd toRQ action c qs = where batchQueues :: AM' [(SMPTransportSession, NonEmpty q)] batchQueues = do - mode <- atomically $ sessionMode <$> getNetworkConfig c + mode <- getSessionMode c pure . M.assocs $ foldl' (batch mode) M.empty qs where batch mode m q = @@ -1444,10 +1470,10 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) -addSubscription :: AgentClient -> RcvQueue -> STM () -addSubscription c rq@RcvQueue {connId} = do +addSubscription :: AgentClient -> SessionId -> RcvQueue -> STM () +addSubscription c sessId rq@RcvQueue {connId} = do modifyTVar' (subscrConns c) $ S.insert connId - RQ.addQueue rq $ activeSubs c + RQ.addQueue (sessId, rq) $ activeSubs c RQ.deleteQueue rq $ pendingSubs c failSubscription :: AgentClient -> RcvQueue -> SMPClientError -> STM () @@ -1466,7 +1492,7 @@ addNewQueueSubscription c rq tSess sessId = do atomically $ ifM (activeClientSession c tSess sessId) - (True <$ addSubscription c rq) + (True <$ addSubscription c sessId rq) (False <$ addPendingSubscription c rq) unless same $ resubscribeSMPSession c tSess @@ -1484,8 +1510,8 @@ removeSubscription c connId = do RQ.deleteConn connId $ activeSubs c RQ.deleteConn connId $ pendingSubs c -getSubscriptions :: AgentClient -> STM (Set ConnId) -getSubscriptions = readTVar . subscrConns +getSubscriptions :: AgentClient -> IO (Set ConnId) +getSubscriptions = readTVarIO . subscrConns {-# INLINE getSubscriptions #-} logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m () @@ -1499,7 +1525,7 @@ showServer ProtocolServer {host, port} = {-# INLINE showServer #-} logSecret :: ByteString -> ByteString -logSecret bs = encode $ B.take 3 bs +logSecret bs = B64.encode $ B.take 3 bs {-# INLINE logSecret #-} sendConfirmation :: AgentClient -> SndQueue -> ByteString -> AM (Maybe SMPServer) @@ -1584,9 +1610,9 @@ sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do ackSMPMessage smp rcvPrivateKey rcvId msgId atomically $ releaseGetLock c rq -hasGetLock :: AgentClient -> RcvQueue -> STM Bool +hasGetLock :: AgentClient -> RcvQueue -> IO Bool hasGetLock c RcvQueue {server, rcvId} = - TM.member (server, rcvId) $ getMsgLocks c + TM.memberIO (server, rcvId) $ getMsgLocks c releaseGetLock :: AgentClient -> RcvQueue -> STM () releaseGetLock c RcvQueue {server, rcvId} = @@ -1603,7 +1629,15 @@ deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = do deleteSMPQueue smp rcvPrivateKey rcvId deleteQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())] -deleteQueues = sendTSessionBatches "DEL" id $ sendBatch deleteSMPQueues +deleteQueues c = sendTSessionBatches "DEL" id deleteQueues_ c + where + deleteQueues_ smp rqs = do + let (userId, srv, _) = transportSession' smp + atomically $ incSMPServerStat' c userId srv connDelAttempts $ length rqs + rs <- sendBatch deleteSMPQueues smp rqs + let successes = foldl' (\n (_, r) -> if isRight r then n + 1 else n) 0 rs + atomically $ incSMPServerStat' c userId srv connDeleted successes + pure rs sendAgentMessage :: AgentClient -> SndQueue -> MsgFlags -> ByteString -> AM (Maybe SMPServer) sendAgentMessage c sq@SndQueue {userId, server, sndId, sndPrivateKey} msgFlags agentMsg = do @@ -1611,10 +1645,24 @@ sendAgentMessage c sq@SndQueue {userId, server, sndId, sndPrivateKey} msgFlags a msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg sendOrProxySMPMessage c userId server "" (Just sndPrivateKey) sndId msgFlags msg -getQueueInfo :: AgentClient -> RcvQueue -> AM QueueInfo -getQueueInfo c rq@RcvQueue {rcvId, rcvPrivateKey} = - withSMPClient c rq "QUE" $ \smp -> - getSMPQueueInfo smp rcvPrivateKey rcvId +data ServerQueueInfo = ServerQueueInfo + { server :: SMPServer, + rcvId :: Text, + sndId :: Text, + ntfId :: Maybe Text, + status :: Text, + info :: QueueInfo + } + deriving (Show) + +getQueueInfo :: AgentClient -> RcvQueue -> AM ServerQueueInfo +getQueueInfo c rq@RcvQueue {server, rcvId, rcvPrivateKey, sndId, status, clientNtfCreds} = + withSMPClient c rq "QUE" $ \smp -> do + info <- getSMPQueueInfo smp rcvPrivateKey rcvId + let ntfId = enc . (\ClientNtfCreds {notifierId} -> notifierId) <$> clientNtfCreds + pure ServerQueueInfo {server, rcvId = enc rcvId, sndId = enc sndId, ntfId, status = serializeQueueStatus status, info} + where + enc = decodeLatin1 . B64.encode agentNtfRegisterToken :: AgentClient -> NtfToken -> NtfPublicAuthKey -> C.PublicKeyX25519 -> AM (NtfTokenId, C.PublicKeyX25519) agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey = @@ -1663,7 +1711,7 @@ agentXFTPNewChunk c SndFileChunk {userId, chunkSpec = XFTPChunkSpec {chunkSize}, (sndKey, replicaKey) <- atomically . C.generateAuthKeyPair C.SEd25519 =<< asks random let fileInfo = FileInfo {sndKey, size = chunkSize, digest = chunkDigest} logServer "-->" c srv "" "FNEW" - tSess <- liftIO $ mkTransportSession c userId srv chunkDigest + tSess <- mkTransportSession c userId srv chunkDigest (sndId, rIds) <- withClient c tSess $ \xftp -> X.createXFTPChunk xftp replicaKey fileInfo (L.map fst rKeys) auth logServer "<--" c srv "" $ B.unwords ["SIDS", logSecret sndId] pure NewSndChunkReplica {server = srv, replicaId = ChunkReplicaId sndId, replicaKey, rcvIdsKeys = L.toList $ xftpRcvIdsKeys rIds rKeys} @@ -1816,16 +1864,28 @@ beginAgentOperation c op = do -- unsafeIOToSTM $ putStrLn $ "beginOperation! " <> show op <> " " <> show (opsInProgress s + 1) writeTVar opVar $! s {opsInProgress = opsInProgress s + 1} -agentOperationBracket :: MonadUnliftIO m => AgentClient -> AgentOperation -> (AgentClient -> STM ()) -> m a -> m a +agentOperationBracket :: MonadUnliftIO m => AgentClient -> AgentOperation -> (AgentClient -> IO ()) -> m a -> m a agentOperationBracket c op check action = E.bracket - (atomically (check c) >> atomically (beginAgentOperation c op)) + (liftIO (check c) >> atomically (beginAgentOperation c op)) (\_ -> atomically $ endAgentOperation c op) (const action) -waitUntilForeground :: AgentClient -> STM () -waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry -{-# INLINE waitUntilForeground #-} +waitUntilForeground :: AgentClient -> IO () +waitUntilForeground c = + unlessM (foreground readTVarIO) $ atomically $ unlessM (foreground readTVar) retry + where + foreground :: Monad m => (TVar AgentState -> m AgentState) -> m Bool + foreground rd = (ASForeground ==) <$> rd (agentState c) + +-- This function waits while agent is suspended, but will proceed while it is suspending, +-- to allow completing in-flight operations. +waitWhileSuspended :: AgentClient -> IO () +waitWhileSuspended c = + whenM (suspended readTVarIO) $ atomically $ whenM (suspended readTVar) retry + where + suspended :: Monad m => (TVar AgentState -> m AgentState) -> m Bool + suspended rd = (ASSuspended ==) <$> rd (agentState c) withStore' :: AgentClient -> (DB.Connection -> IO a) -> AM a withStore' c action = withStore c $ fmap Right . action @@ -1875,7 +1935,7 @@ storeError = \case SEDatabaseBusy e -> CRITICAL True $ B.unpack e e -> INTERNAL $ show e -userServers :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> TMap UserId (NonEmpty (ProtoServerWithAuth p)) +userServers :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> TMap UserId (UserServers p) userServers c = case protocolTypeI @p of SPSMP -> smpServers c SPXFTP -> xftpServers c @@ -1896,52 +1956,67 @@ getNextServer c userId usedSrvs = withUserServers c userId $ \srvs -> withUserServers :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> AM a) -> AM a withUserServers c userId action = - atomically (TM.lookup userId $ userServers c) >>= \case - Just srvs -> action srvs + liftIO (TM.lookupIO userId $ userServers c) >>= \case + Just srvs -> action $ enabledSrvs srvs _ -> throwE $ INTERNAL "unknown userId - no user servers" withNextSrv :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> AM a) -> AM a withNextSrv c userId usedSrvs initUsed action = do used <- readTVarIO usedSrvs srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used - atomically $ do - srvs_ <- TM.lookup userId $ userServers c - let unused = maybe [] ((\\ used) . map protoServer . L.toList) srvs_ - used' = if null unused then initUsed else srv : used - writeTVar usedSrvs $! used' + srvs_ <- liftIO $ TM.lookupIO userId $ userServers c + let unused = maybe [] ((\\ used) . map protoServer . L.toList . enabledSrvs) srvs_ + used' = if null unused then initUsed else srv : used + atomically $ writeTVar usedSrvs $! used' action srvAuth incSMPServerStat :: AgentClient -> UserId -> SMPServer -> (AgentSMPServerStats -> TVar Int) -> STM () incSMPServerStat c userId srv sel = incSMPServerStat' c userId srv sel 1 incSMPServerStat' :: AgentClient -> UserId -> SMPServer -> (AgentSMPServerStats -> TVar Int) -> Int -> STM () -incSMPServerStat' AgentClient {smpServersStats} userId srv sel n = do - TM.lookup (userId, srv) smpServersStats >>= \case - Just v -> modifyTVar' (sel v) (+ n) - Nothing -> do - newStats <- newAgentSMPServerStats - modifyTVar' (sel newStats) (+ n) - TM.insert (userId, srv) newStats smpServersStats +incSMPServerStat' = incServerStat (\AgentClient {smpServersStats = s} -> s) newAgentSMPServerStats incXFTPServerStat :: AgentClient -> UserId -> XFTPServer -> (AgentXFTPServerStats -> TVar Int) -> STM () -incXFTPServerStat AgentClient {xftpServersStats} userId srv sel = do - TM.lookup (userId, srv) xftpServersStats >>= \case - Just v -> modifyTVar' (sel v) (+ 1) +incXFTPServerStat c userId srv sel = incXFTPServerStat_ c userId srv sel 1 +{-# INLINE incXFTPServerStat #-} + +incXFTPServerStat' :: AgentClient -> UserId -> XFTPServer -> (AgentXFTPServerStats -> TVar Int) -> Int -> STM () +incXFTPServerStat' = incXFTPServerStat_ +{-# INLINE incXFTPServerStat' #-} + +incXFTPServerSizeStat :: AgentClient -> UserId -> XFTPServer -> (AgentXFTPServerStats -> TVar Int64) -> Int64 -> STM () +incXFTPServerSizeStat = incXFTPServerStat_ +{-# INLINE incXFTPServerSizeStat #-} + +incXFTPServerStat_ :: Num n => AgentClient -> UserId -> XFTPServer -> (AgentXFTPServerStats -> TVar n) -> n -> STM () +incXFTPServerStat_ = incServerStat (\AgentClient {xftpServersStats = s} -> s) newAgentXFTPServerStats +{-# INLINE incXFTPServerStat_ #-} + +incNtfServerStat :: AgentClient -> UserId -> NtfServer -> (AgentNtfServerStats -> TVar Int) -> STM () +incNtfServerStat c userId srv sel = incServerStat (\AgentClient {ntfServersStats = s} -> s) newAgentNtfServerStats c userId srv sel 1 +{-# INLINE incNtfServerStat #-} + +incServerStat :: Num n => (AgentClient -> TMap (UserId, ProtocolServer p) s) -> STM s -> AgentClient -> UserId -> ProtocolServer p -> (s -> TVar n) -> n -> STM () +incServerStat statsSel mkNewStats c userId srv sel n = do + TM.lookup (userId, srv) (statsSel c) >>= \case + Just v -> modifyTVar' (sel v) (+ n) Nothing -> do - newStats <- newAgentXFTPServerStats - modifyTVar' (sel newStats) (+ 1) - TM.insert (userId, srv) newStats xftpServersStats + newStats <- mkNewStats + modifyTVar' (sel newStats) (+ n) + TM.insert (userId, srv) newStats (statsSel c) data AgentServersSummary = AgentServersSummary { smpServersStats :: Map (UserId, SMPServer) AgentSMPServerStatsData, xftpServersStats :: Map (UserId, XFTPServer) AgentXFTPServerStatsData, + ntfServersStats :: Map (UserId, NtfServer) AgentNtfServerStatsData, statsStartedAt :: UTCTime, smpServersSessions :: Map (UserId, SMPServer) ServerSessions, smpServersSubs :: Map (UserId, SMPServer) SMPServerSubs, xftpServersSessions :: Map (UserId, XFTPServer) ServerSessions, xftpRcvInProgress :: [XFTPServer], xftpSndInProgress :: [XFTPServer], - xftpDelInProgress :: [XFTPServer] + xftpDelInProgress :: [XFTPServer], + ntfServersSessions :: Map (UserId, NtfServer) ServerSessions } deriving (Show) @@ -1958,10 +2033,30 @@ data ServerSessions = ServerSessions } deriving (Show) +getAgentSubsTotal :: AgentClient -> [UserId] -> IO (SMPServerSubs, Bool) +getAgentSubsTotal c userIds = do + ssActive <- getSubsCount activeSubs + ssPending <- getSubsCount pendingSubs + sess <- hasSession . M.toList =<< readTVarIO (smpClients c) + pure (SMPServerSubs {ssActive, ssPending}, sess) + where + getSubsCount :: (AgentClient -> TRcvQueues q) -> IO Int + getSubsCount subs = M.foldrWithKey' addSub 0 <$> readTVarIO (getRcvQueues $ subs c) + addSub :: (UserId, SMPServer, SMP.RecipientId) -> q -> Int -> Int + addSub (userId, _, _) _ cnt = if userId `elem` userIds then cnt + 1 else cnt + hasSession :: [(SMPTransportSession, SMPClientVar)] -> IO Bool + hasSession = \case + [] -> pure False + (s : ss) -> ifM (isConnected s) (pure True) (hasSession ss) + isConnected ((userId, _, _), SessionVar {sessionVar}) + | userId `elem` userIds = atomically $ maybe False isRight <$> tryReadTMVar sessionVar + | otherwise = pure False + getAgentServersSummary :: AgentClient -> IO AgentServersSummary -getAgentServersSummary c@AgentClient {smpServersStats, xftpServersStats, srvStatsStartedAt, agentEnv} = do +getAgentServersSummary c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt, agentEnv} = do sss <- mapM getAgentSMPServerStats =<< readTVarIO smpServersStats xss <- mapM getAgentXFTPServerStats =<< readTVarIO xftpServersStats + nss <- mapM getAgentNtfServerStats =<< readTVarIO ntfServersStats statsStartedAt <- readTVarIO srvStatsStartedAt smpServersSessions <- countSessions =<< readTVarIO (smpClients c) smpServersSubs <- getServerSubs @@ -1969,17 +2064,20 @@ getAgentServersSummary c@AgentClient {smpServersStats, xftpServersStats, srvStat xftpRcvInProgress <- catMaybes <$> getXFTPWorkerSrvs xftpRcvWorkers xftpSndInProgress <- catMaybes <$> getXFTPWorkerSrvs xftpSndWorkers xftpDelInProgress <- getXFTPWorkerSrvs xftpDelWorkers + ntfServersSessions <- countSessions =<< readTVarIO (ntfClients c) pure AgentServersSummary { smpServersStats = sss, xftpServersStats = xss, + ntfServersStats = nss, statsStartedAt, smpServersSessions, smpServersSubs, xftpServersSessions, xftpRcvInProgress, xftpSndInProgress, - xftpDelInProgress + xftpDelInProgress, + ntfServersSessions } where getServerSubs = do @@ -2025,6 +2123,7 @@ getAgentSubscriptions c = do removedSubscriptions <- getRemovedSubs pure $ SubscriptionsInfo {activeSubscriptions, pendingSubscriptions, removedSubscriptions} where + getSubs :: (AgentClient -> TRcvQueues q) -> IO [SubInfo] getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c) getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c) subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo @@ -2223,3 +2322,5 @@ $(J.deriveJSON defaultJSON ''AgentQueuesInfo) $(J.deriveJSON (enumJSON $ dropPrefix "UN") ''UserNetworkType) $(J.deriveJSON defaultJSON ''UserNetworkInfo) + +$(J.deriveJSON defaultJSON ''ServerQueueInfo) diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 2ae2ad5c0..f57cf91e9 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -7,6 +7,7 @@ {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} @@ -15,7 +16,12 @@ module Simplex.Messaging.Agent.Env.SQLite AM, AgentConfig (..), InitialAgentServers (..), + ServerCfg (..), + UserServers (..), NetworkConfig (..), + presetServerCfg, + enabledServerCfg, + mkUserServers, defaultAgentConfig, defaultReconnectInterval, tryAgentError, @@ -39,10 +45,14 @@ import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random +import Data.Aeson (FromJSON (..), ToJSON (..)) +import qualified Data.Aeson.TH as JQ import Data.ByteArray (ScrubbedBytes) import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) -import Data.Map (Map) +import qualified Data.List.NonEmpty as L +import Data.Map.Strict (Map) +import Data.Maybe (fromMaybe) import Data.Time.Clock (NominalDiffTime, nominalDay) import Data.Time.Clock.System (SystemTime (..)) import Data.Word (Word16) @@ -59,7 +69,8 @@ import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRa import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig) import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) +import Simplex.Messaging.Parsers (defaultJSON) +import Simplex.Messaging.Protocol (NtfServer, ProtoServerWithAuth, ProtocolServer, ProtocolType (..), ProtocolTypeI, VersionRangeSMPC, XFTPServer, supportedSMPClientVRange) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..)) @@ -74,12 +85,38 @@ type AM' a = ReaderT Env IO a type AM a = ExceptT AgentErrorType (ReaderT Env IO) a data InitialAgentServers = InitialAgentServers - { smp :: Map UserId (NonEmpty SMPServerWithAuth), + { smp :: Map UserId (NonEmpty (ServerCfg 'PSMP)), ntf :: [NtfServer], - xftp :: Map UserId (NonEmpty XFTPServerWithAuth), + xftp :: Map UserId (NonEmpty (ServerCfg 'PXFTP)), netCfg :: NetworkConfig } +data ServerCfg p = ServerCfg + { server :: ProtoServerWithAuth p, + preset :: Bool, + tested :: Maybe Bool, + enabled :: Bool + } + deriving (Show) + +enabledServerCfg :: ProtoServerWithAuth p -> ServerCfg p +enabledServerCfg server = ServerCfg {server, preset = False, tested = Nothing, enabled = True} + +presetServerCfg :: Bool -> ProtoServerWithAuth p -> ServerCfg p +presetServerCfg enabled server = ServerCfg {server, preset = True, tested = Nothing, enabled} + +data UserServers p = UserServers + { enabledSrvs :: NonEmpty (ProtoServerWithAuth p), + knownSrvs :: NonEmpty (ProtocolServer p) + } + +-- This function sets all servers as enabled in case all passed servers are disabled. +mkUserServers :: NonEmpty (ServerCfg p) -> UserServers p +mkUserServers srvs = UserServers {enabledSrvs, knownSrvs} + where + enabledSrvs = L.map (\ServerCfg {server} -> server) $ fromMaybe srvs $ L.nonEmpty $ L.filter (\ServerCfg {enabled} -> enabled) srvs + knownSrvs = L.map (\ServerCfg {server = ProtoServerWithAuth srv _} -> srv) srvs + data AgentConfig = AgentConfig { tcpPort :: Maybe ServiceName, rcvAuthAlg :: C.AuthAlg, @@ -111,10 +148,7 @@ data AgentConfig = AgentConfig xftpMaxRecipientsPerRequest :: Int, deleteErrorCount :: Int, ntfCron :: Word16, - ntfWorkerDelay :: Int, - ntfSMPWorkerDelay :: Int, ntfSubCheckInterval :: NominalDiffTime, - ntfMaxMessages :: Int, caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, @@ -128,7 +162,7 @@ defaultReconnectInterval = RetryInterval { initialInterval = 2_000000, increaseAfter = 10_000000, - maxInterval = 60_000000 + maxInterval = 180_000000 } defaultMessageRetryInterval :: RetryInterval2 @@ -138,7 +172,7 @@ defaultMessageRetryInterval = RetryInterval { initialInterval = 2_000000, increaseAfter = 10_000000, - maxInterval = 60_000000 + maxInterval = 120_000000 }, riSlow = RetryInterval @@ -183,10 +217,7 @@ defaultAgentConfig = xftpMaxRecipientsPerRequest = 200, deleteErrorCount = 10, ntfCron = 20, -- minutes - ntfWorkerDelay = 100000, -- microseconds - ntfSMPWorkerDelay = 500000, -- microseconds ntfSubCheckInterval = nominalDay, - ntfMaxMessages = 3, -- CA certificate private key is not needed for initialization -- ! we do not generate these caCertificateFile = "/etc/opt/simplex-agent/ca.crt", @@ -211,8 +242,8 @@ newSMPAgentEnv :: AgentConfig -> SQLiteStore -> IO Env newSMPAgentEnv config store = do random <- C.newRandom randomServer <- newTVarIO =<< liftIO newStdGen - ntfSupervisor <- atomically . newNtfSubSupervisor $ tbqSize config - xftpAgent <- atomically newXFTPAgent + ntfSupervisor <- newNtfSubSupervisor $ tbqSize config + xftpAgent <- newXFTPAgent multicastSubscribers <- newTMVarIO 0 pure Env {config, store, random, randomServer, ntfSupervisor, xftpAgent, multicastSubscribers} @@ -229,12 +260,12 @@ data NtfSupervisor = NtfSupervisor data NtfSupervisorCommand = NSCCreate | NSCDelete | NSCSmpDelete | NSCNtfWorker NtfServer | NSCNtfSMPWorker SMPServer deriving (Show) -newNtfSubSupervisor :: Natural -> STM NtfSupervisor +newNtfSubSupervisor :: Natural -> IO NtfSupervisor newNtfSubSupervisor qSize = do - ntfTkn <- newTVar Nothing - ntfSubQ <- newTBQueue qSize - ntfWorkers <- TM.empty - ntfSMPWorkers <- TM.empty + ntfTkn <- newTVarIO Nothing + ntfSubQ <- newTBQueueIO qSize + ntfWorkers <- TM.emptyIO + ntfSMPWorkers <- TM.emptyIO pure NtfSupervisor {ntfTkn, ntfSubQ, ntfWorkers, ntfSMPWorkers} data XFTPAgent = XFTPAgent @@ -245,12 +276,12 @@ data XFTPAgent = XFTPAgent xftpDelWorkers :: TMap XFTPServer Worker } -newXFTPAgent :: STM XFTPAgent +newXFTPAgent :: IO XFTPAgent newXFTPAgent = do - xftpWorkDir <- newTVar Nothing - xftpRcvWorkers <- TM.empty - xftpSndWorkers <- TM.empty - xftpDelWorkers <- TM.empty + xftpWorkDir <- newTVarIO Nothing + xftpRcvWorkers <- TM.emptyIO + xftpSndWorkers <- TM.emptyIO + xftpDelWorkers <- TM.emptyIO pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} tryAgentError :: AM a -> AM (Either AgentErrorType a) @@ -294,3 +325,12 @@ updateRestartCount :: SystemTime -> RestartCount -> RestartCount updateRestartCount t (RestartCount minute count) = do let min' = systemSeconds t `div` 60 in RestartCount min' $ if minute == min' then count + 1 else 1 + +$(pure []) + +instance ProtocolTypeI p => ToJSON (ServerCfg p) where + toEncoding = $(JQ.mkToEncoding defaultJSON ''ServerCfg) + toJSON = $(JQ.mkToJSON defaultJSON ''ServerCfg) + +instance ProtocolTypeI p => FromJSON (ServerCfg p) where + parseJSON = $(JQ.mkParseJSON defaultJSON ''ServerCfg) diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index a239768b0..23a88ea70 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -20,8 +20,8 @@ where import Control.Logger.Simple (logError, logInfo) import Control.Monad -import Control.Monad.Except import Control.Monad.Reader +import Control.Monad.Trans.Except import Data.Bifunctor (first) import qualified Data.Map.Strict as M import Data.Text (Text) @@ -31,6 +31,7 @@ import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol (AEvent (..), AEvt (..), AgentErrorType (..), BrokerErrorType (..), ConnId, NotificationsMode (..), SAEntity (..)) import Simplex.Messaging.Agent.RetryInterval +import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite import qualified Simplex.Messaging.Crypto as C @@ -40,7 +41,7 @@ import Simplex.Messaging.Protocol (NtfServer, SMPServer, sameSrvAddr) import Simplex.Messaging.Util (diffToMicroseconds, threadDelay', tshow, unlessM) import System.Random (randomR) import UnliftIO -import UnliftIO.Concurrent (forkIO, threadDelay) +import UnliftIO.Concurrent (forkIO) import qualified UnliftIO.Exception as E runNtfSupervisor :: AgentClient -> AM' () @@ -64,7 +65,7 @@ processNtfSub c (connId, cmd) = do logInfo $ "processNtfSub - connId = " <> tshow connId <> " - cmd = " <> tshow cmd case cmd of NSCCreate -> do - (a, RcvQueue {server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do + (a, RcvQueue {userId, server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do a <- liftIO $ getNtfSubscription db connId q <- ExceptT $ getPrimaryRcvQueue db connId pure (a, q) @@ -74,12 +75,12 @@ processNtfSub c (connId, cmd) = do withTokenServer $ \ntfServer -> do case clientNtfCreds of Just ClientNtfCreds {notifierId} -> do - let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey - withStore c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate + let newSub = newNtfSubscription userId connId smpServer (Just notifierId) ntfServer NASKey + withStore c $ \db -> createNtfSubscription db newSub $ NSANtf NSACreate lift . void $ getNtfNTFWorker True c ntfServer Nothing -> do - let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew - withStore c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey + let newSub = newNtfSubscription userId connId smpServer Nothing ntfServer NASNew + withStore c $ \db -> createNtfSubscription db newSub $ NSASMP NSASmpKey lift . void $ getNtfSMPWorker True c smpServer (Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer, smpServer = smpServer', ntfQueueId}, action_)) -> do case (clientNtfCreds, ntfQueueId) of @@ -99,24 +100,24 @@ processNtfSub c (connId, cmd) = do if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted then resetSubscription else withTokenServer $ \ntfServer -> do - withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate) + withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NSANtf NSACreate) lift . void $ getNtfNTFWorker True c ntfServer | otherwise -> case action of - NtfSubNTFAction _ -> lift . void $ getNtfNTFWorker True c subNtfServer - NtfSubSMPAction _ -> lift . void $ getNtfSMPWorker True c smpServer + NSANtf _ -> lift . void $ getNtfNTFWorker True c subNtfServer + NSASMP _ -> lift . void $ getNtfSMPWorker True c smpServer rotate :: AM () rotate = do - withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate) + withStore' c $ \db -> supervisorUpdateNtfSub db sub (NSANtf NSARotate) lift . void $ getNtfNTFWorker True c subNtfServer resetSubscription :: AM () resetSubscription = withTokenServer $ \ntfServer -> do let sub' = sub {ntfQueueId = Nothing, ntfServer, ntfSubId = Nothing, ntfSubStatus = NASNew} - withStore' c $ \db -> supervisorUpdateNtfSub db sub' (NtfSubSMPAction NSASmpKey) + withStore' c $ \db -> supervisorUpdateNtfSub db sub' (NSASMP NSASmpKey) lift . void $ getNtfSMPWorker True c smpServer NSCDelete -> do sub_ <- withStore' c $ \db -> do - supervisorUpdateNtfAction db connId (NtfSubNTFAction NSADelete) + supervisorUpdateNtfAction db connId (NSANtf NSADelete) getNtfSubscription db connId logInfo $ "processNtfSub, NSCDelete - sub_ = " <> tshow sub_ case sub_ of @@ -126,7 +127,7 @@ processNtfSub c (connId, cmd) = do withStore' c (`getPrimaryRcvQueue` connId) >>= \case Right rq@RcvQueue {server = smpServer} -> do logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq - withStore' c $ \db -> supervisorUpdateNtfAction db connId (NtfSubSMPAction NSASmpDelete) + withStore' c $ \db -> supervisorUpdateNtfAction db connId (NSASMP NSASmpDelete) lift . void $ getNtfSMPWorker True c smpServer _ -> notifyInternalError c connId "NSCSmpDelete - no rcv queue" NSCNtfWorker ntfServer -> lift . void $ getNtfNTFWorker True c ntfServer @@ -146,12 +147,10 @@ withTokenServer :: (NtfServer -> AM ()) -> AM () withTokenServer action = lift getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer) runNtfWorker :: AgentClient -> NtfServer -> Worker -> AM () -runNtfWorker c srv Worker {doWork} = do - delay <- asks $ ntfWorkerDelay . config +runNtfWorker c srv Worker {doWork} = forever $ do waitForWork doWork ExceptT $ agentOperationBracket c AONtfNetwork throwWhenInactive $ runExceptT runNtfOperation - threadDelay delay where runNtfOperation :: AM () runNtfOperation = @@ -160,70 +159,73 @@ runNtfWorker c srv Worker {doWork} = do logInfo $ "runNtfWorker, nextSub " <> tshow nextSub ri <- asks $ reconnectInterval . config withRetryInterval ri $ \_ loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c processSub nextSub `catchAgentError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show) processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> AM () - processSub (sub@NtfSubscription {connId, smpServer, ntfSubId}, action, actionTs) = do + processSub (sub@NtfSubscription {userId, connId, smpServer, ntfSubId}, action, actionTs) = do ts <- liftIO getCurrentTime unlessM (lift $ rescheduleAction doWork ts actionTs) $ case action of NSACreate -> lift getNtfToken >>= \case - Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do + Just tkn@NtfToken {ntfServer, ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId) case clientNtfCreds of Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do + atomically $ incNtfServerStat c userId ntfServer ntfCreateAttempts nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey + atomically $ incNtfServerStat c userId ntfServer ntfCreated -- possible improvement: smaller retry until Active, less frequently (daily?) once Active let actionTs' = addUTCTime 30 ts withStore' c $ \db -> - updateNtfSubscription db sub {ntfSubId = Just nSubId, ntfSubStatus = NASCreated NSNew} (NtfSubNTFAction NSACheck) actionTs' + updateNtfSubscription db sub {ntfSubId = Just nSubId, ntfSubStatus = NASCreated NSNew} (NSANtf NSACheck) actionTs' _ -> workerInternalError c connId "NSACreate - no notifier queue credentials" _ -> workerInternalError c connId "NSACreate - no active token" NSACheck -> lift getNtfToken >>= \case - Just tkn -> + Just tkn@NtfToken {ntfServer} -> case ntfSubId of - Just nSubId -> + Just nSubId -> do + atomically $ incNtfServerStat c userId ntfServer ntfCheckAttempts agentNtfCheckSubscription c nSubId tkn >>= \case NSAuth -> do - lift (getNtfServer c) >>= \case - Just ntfServer -> do - withStore' c $ \db -> - updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NtfSubSMPAction NSASmpKey) ts - ns <- asks ntfSupervisor - atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) - _ -> workerInternalError c connId "NSACheck - failed to reset subscription, notification server not configured" + withStore' c $ \db -> + updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NSASMP NSASmpKey) ts + ns <- asks ntfSupervisor + atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) status -> updateSubNextCheck ts status + atomically $ incNtfServerStat c userId ntfServer ntfChecked Nothing -> workerInternalError c connId "NSACheck - no subscription ID" _ -> workerInternalError c connId "NSACheck - no active token" - NSADelete -> case ntfSubId of - Just nSubId -> - (lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) - `agentFinally` continueDeletion - _ -> continueDeletion - where - continueDeletion = do - let sub' = sub {ntfSubId = Nothing, ntfSubStatus = NASOff} - withStore' c $ \db -> updateNtfSubscription db sub' (NtfSubSMPAction NSASmpDelete) ts - ns <- asks ntfSupervisor - atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) - NSARotate -> case ntfSubId of - Just nSubId -> - (lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) - `agentFinally` deleteCreate - _ -> deleteCreate - where - deleteCreate = do - withStore' c $ \db -> deleteNtfSubscription db connId - ns <- asks ntfSupervisor - atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCCreate) + NSADelete -> + deleteNtfSub $ do + let sub' = sub {ntfSubId = Nothing, ntfSubStatus = NASOff} + withStore' c $ \db -> updateNtfSubscription db sub' (NSASMP NSASmpDelete) ts + ns <- asks ntfSupervisor + atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) + NSARotate -> + deleteNtfSub $ do + withStore' c $ \db -> deleteNtfSubscription db connId + ns <- asks ntfSupervisor + atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCCreate) where + deleteNtfSub continue = case ntfSubId of + Just nSubId -> + lift getNtfToken >>= \case + Just tkn@NtfToken {ntfServer} -> do + atomically $ incNtfServerStat c userId ntfServer ntfDelAttempts + tryAgentError (agentNtfDeleteSubscription c nSubId tkn) >>= \case + Left e | temporaryOrHostError e -> throwE e + _ -> continue + atomically $ incNtfServerStat c userId ntfServer ntfDeleted + Nothing -> continue + _ -> continue updateSubNextCheck ts toStatus = do checkInterval <- asks $ ntfSubCheckInterval . config let nextCheckTs = addUTCTime checkInterval ts - updateSub (NASCreated toStatus) (NtfSubNTFAction NSACheck) nextCheckTs + updateSub (NASCreated toStatus) (NSANtf NSACheck) nextCheckTs updateSub toStatus toAction actionTs' = withStore' c $ \db -> updateNtfSubscription db sub {ntfSubStatus = toStatus} toAction actionTs' @@ -231,12 +233,10 @@ runNtfWorker c srv Worker {doWork} = do runNtfSMPWorker :: AgentClient -> SMPServer -> Worker -> AM () runNtfSMPWorker c srv Worker {doWork} = do env <- ask - delay <- asks $ ntfSMPWorkerDelay . config forever $ do waitForWork doWork ExceptT . liftIO . agentOperationBracket c AONtfNetwork throwWhenInactive $ runReaderT (runExceptT runNtfSMPOperation) env - threadDelay delay where runNtfSMPOperation = withWork c doWork (`getNextNtfSubSMPAction` srv) $ @@ -244,6 +244,7 @@ runNtfSMPWorker c srv Worker {doWork} = do logInfo $ "runNtfSMPWorker, nextSub " <> tshow nextSub ri <- asks $ reconnectInterval . config withRetryInterval ri $ \_ loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c processSub nextSub `catchAgentError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show) @@ -264,11 +265,12 @@ runNtfSMPWorker c srv Worker {doWork} = do let rcvNtfDhSecret = C.dh' rcvNtfSrvPubDhKey rcvNtfPrivDhKey withStore' c $ \db -> do setRcvQueueNtfCreds db connId $ Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} - updateNtfSubscription db sub {ntfQueueId = Just notifierId, ntfSubStatus = NASKey} (NtfSubNTFAction NSACreate) ts + updateNtfSubscription db sub {ntfQueueId = Just notifierId, ntfSubStatus = NASKey} (NSANtf NSACreate) ts ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCNtfWorker ntfServer) _ -> workerInternalError c connId "NSASmpKey - no active token" NSASmpDelete -> do + -- TODO should we remove it after successful removal from the server? rq_ <- withStore' c $ \db -> do setRcvQueueNtfCreds db connId Nothing getPrimaryRcvQueue db connId @@ -295,7 +297,7 @@ retryOnError c name loop done e = do where retryLoop = do atomically $ endAgentOperation c AONtfNetwork - atomically $ throwWhenInactive c + liftIO $ throwWhenInactive c atomically $ beginAgentOperation c AONtfNetwork loop diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index b123fc1ec..ea1d51a7d 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -42,6 +42,7 @@ module Simplex.Messaging.Agent.Protocol deliveryRcptsSMPAgentVersion, pqdrSMPAgentVersion, sndAuthKeySMPAgentVersion, + ratchetOnConfSMPAgentVersion, currentSMPAgentVersion, supportedSMPAgentVRange, e2eEncConnInfoLength, @@ -49,6 +50,7 @@ module Simplex.Messaging.Agent.Protocol -- * SMP agent protocol types ConnInfo, + SndQueueSecured, ACommand (..), AEvent (..), AEvt (..), @@ -153,8 +155,8 @@ import Data.Int (Int64) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Map (Map) -import qualified Data.Map as M +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust) import Data.Text (Text) import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -257,11 +259,14 @@ pqdrSMPAgentVersion = VersionSMPA 5 sndAuthKeySMPAgentVersion :: VersionSMPA sndAuthKeySMPAgentVersion = VersionSMPA 6 +ratchetOnConfSMPAgentVersion :: VersionSMPA +ratchetOnConfSMPAgentVersion = VersionSMPA 7 + minSupportedSMPAgentVersion :: VersionSMPA minSupportedSMPAgentVersion = duplexHandshakeSMPAgentVersion currentSMPAgentVersion :: VersionSMPA -currentSMPAgentVersion = VersionSMPA 6 +currentSMPAgentVersion = VersionSMPA 7 supportedSMPAgentVRange :: VersionRangeSMPA supportedSMPAgentVRange = mkVersionRange minSupportedSMPAgentVersion currentSMPAgentVersion @@ -327,6 +332,8 @@ deriving instance Show AEvt type ConnInfo = ByteString +type SndQueueSecured = Bool + -- | Parameterized type for SMP agent events data AEvent (e :: AEntity) where INV :: AConnectionRequestUri -> AEvent AEConn @@ -354,6 +361,7 @@ data AEvent (e :: AEntity) where DEL_USER :: Int64 -> AEvent AENone STAT :: ConnectionStats -> AEvent AEConn OK :: AEvent AEConn + JOINED :: SndQueueSecured -> AEvent AEConn ERR :: AgentErrorType -> AEvent AEConn SUSPENDED :: AEvent AENone RFPROG :: Int64 -> Int64 -> AEvent AERcvFile @@ -422,6 +430,7 @@ data AEventTag (e :: AEntity) where DEL_USER_ :: AEventTag AENone STAT_ :: AEventTag AEConn OK_ :: AEventTag AEConn + JOINED_ :: AEventTag AEConn ERR_ :: AEventTag AEConn SUSPENDED_ :: AEventTag AENone -- XFTP commands and responses @@ -474,6 +483,7 @@ aEventTag = \case DEL_USER _ -> DEL_USER_ STAT _ -> STAT_ OK -> OK_ + JOINED _ -> JOINED_ ERR _ -> ERR_ SUSPENDED -> SUSPENDED_ RFPROG {} -> RFPROG_ diff --git a/src/Simplex/Messaging/Agent/RetryInterval.hs b/src/Simplex/Messaging/Agent/RetryInterval.hs index 00fe4039e..35fa7c5c6 100644 --- a/src/Simplex/Messaging/Agent/RetryInterval.hs +++ b/src/Simplex/Messaging/Agent/RetryInterval.hs @@ -9,6 +9,7 @@ module Simplex.Messaging.Agent.RetryInterval RI2State (..), withRetryInterval, withRetryIntervalCount, + withRetryForeground, withRetryLock2, updateRetryInterval2, nextRetryDelay, @@ -16,10 +17,11 @@ module Simplex.Messaging.Agent.RetryInterval where import Control.Concurrent (forkIO) +import Control.Concurrent.STM (retry) import Control.Monad (void) import Control.Monad.IO.Class (MonadIO, liftIO) import Data.Int (Int64) -import Simplex.Messaging.Util (threadDelay', whenM) +import Simplex.Messaging.Util (threadDelay', unlessM, whenM) import UnliftIO.STM data RetryInterval = RetryInterval @@ -63,6 +65,27 @@ withRetryIntervalCount ri action = callAction 0 0 $ initialInterval ri let elapsed' = elapsed + delay callAction (n + 1) elapsed' $ nextRetryDelay elapsed' delay ri +withRetryForeground :: forall m a. MonadIO m => RetryInterval -> STM Bool -> STM Bool -> (Int64 -> m a -> m a) -> m a +withRetryForeground ri isForeground isOnline action = callAction 0 $ initialInterval ri + where + callAction :: Int64 -> Int64 -> m a + callAction elapsed delay = action delay loop + where + loop = do + -- limit delay to max Int value (~36 minutes on for 32 bit architectures) + d <- registerDelay $ fromIntegral $ min delay (fromIntegral (maxBound :: Int)) + (wasForeground, wasOnline) <- atomically $ (,) <$> isForeground <*> isOnline + reset <- atomically $ do + foreground <- isForeground + online <- isOnline + let reset = (not wasForeground && foreground) || (not wasOnline && online) + unlessM ((reset ||) <$> readTVar d) retry + pure reset + let (elapsed', delay') + | reset = (0, initialInterval ri) + | otherwise = (elapsed + delay, nextRetryDelay elapsed' delay ri) + callAction elapsed' delay' + -- This function allows action to toggle between slow and fast retry intervals. withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> (RI2State -> (RetryIntervalMode -> m ()) -> m ()) -> m () withRetryLock2 RetryInterval2 {riSlow, riFast} lock action = diff --git a/src/Simplex/Messaging/Agent/Stats.hs b/src/Simplex/Messaging/Agent/Stats.hs index c8f81a6aa..d4663bfb1 100644 --- a/src/Simplex/Messaging/Agent/Stats.hs +++ b/src/Simplex/Messaging/Agent/Stats.hs @@ -1,16 +1,20 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TemplateHaskell #-} module Simplex.Messaging.Agent.Stats where +import Data.Aeson (FromJSON (..), FromJSONKey, ToJSON (..)) import qualified Data.Aeson.TH as J -import Data.Map (Map) +import Data.Int (Int64) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.Protocol (UserId) import Simplex.Messaging.Parsers (defaultJSON, fromTextField_) -import Simplex.Messaging.Protocol (SMPServer, XFTPServer) +import Simplex.Messaging.Protocol (SMPServer, XFTPServer, NtfServer) import Simplex.Messaging.Util (decodeJSON, encodeJSON) import UnliftIO.STM @@ -29,13 +33,26 @@ data AgentSMPServerStats = AgentSMPServerStats recvDuplicates :: TVar Int, -- duplicate messages received recvCryptoErrs :: TVar Int, -- message decryption errors recvErrs :: TVar Int, -- receive errors - connCreated :: TVar Int, - connSecured :: TVar Int, - connCompleted :: TVar Int, - connDeleted :: TVar Int, + ackMsgs :: TVar Int, -- total messages acknowledged + ackAttempts :: TVar Int, -- acknowledgement attempts + ackNoMsgErrs :: TVar Int, -- NO_MSG ack errors + ackOtherErrs :: TVar Int, -- other permanent ack errors (temporary accounted for in attempts) + -- conn stats are accounted for rcv queue server + connCreated :: TVar Int, -- total connections created + connSecured :: TVar Int, -- connections secured + connCompleted :: TVar Int, -- connections completed + connDeleted :: TVar Int, -- total connections deleted + connDelAttempts :: TVar Int, -- total connection deletion attempts + connDelErrs :: TVar Int, -- permanent connection deletion errors (temporary accounted for in attempts) connSubscribed :: TVar Int, -- total successful subscription connSubAttempts :: TVar Int, -- subscription attempts - connSubErrs :: TVar Int -- permanent subscription errors (temporary accounted for in attempts) + connSubIgnored :: TVar Int, -- subscription results ignored (client switched to different session or it was not pending) + connSubErrs :: TVar Int, -- permanent subscription errors (temporary accounted for in attempts) + -- notifications stats + ntfKey :: TVar Int, + ntfKeyAttempts :: TVar Int, + ntfKeyDeleted :: TVar Int, + ntfKeyDeleteAttempts :: TVar Int } data AgentSMPServerStatsData = AgentSMPServerStatsData @@ -53,16 +70,30 @@ data AgentSMPServerStatsData = AgentSMPServerStatsData _recvDuplicates :: Int, _recvCryptoErrs :: Int, _recvErrs :: Int, + _ackMsgs :: Int, + _ackAttempts :: Int, + _ackNoMsgErrs :: Int, + _ackOtherErrs :: Int, _connCreated :: Int, _connSecured :: Int, _connCompleted :: Int, _connDeleted :: Int, + _connDelAttempts :: Int, + _connDelErrs :: Int, _connSubscribed :: Int, _connSubAttempts :: Int, - _connSubErrs :: Int + _connSubIgnored :: Int, + _connSubErrs :: Int, + _ntfKey :: OptionalInt, + _ntfKeyAttempts :: OptionalInt, + _ntfKeyDeleted :: OptionalInt, + _ntfKeyDeleteAttempts :: OptionalInt } deriving (Show) +newtype OptionalInt = OInt {toInt :: Int} + deriving (Num, Show, ToJSON) + newAgentSMPServerStats :: STM AgentSMPServerStats newAgentSMPServerStats = do sentDirect <- newTVar 0 @@ -79,13 +110,24 @@ newAgentSMPServerStats = do recvDuplicates <- newTVar 0 recvCryptoErrs <- newTVar 0 recvErrs <- newTVar 0 + ackMsgs <- newTVar 0 + ackAttempts <- newTVar 0 + ackNoMsgErrs <- newTVar 0 + ackOtherErrs <- newTVar 0 connCreated <- newTVar 0 connSecured <- newTVar 0 connCompleted <- newTVar 0 connDeleted <- newTVar 0 + connDelAttempts <- newTVar 0 + connDelErrs <- newTVar 0 connSubscribed <- newTVar 0 connSubAttempts <- newTVar 0 + connSubIgnored <- newTVar 0 connSubErrs <- newTVar 0 + ntfKey <- newTVar 0 + ntfKeyAttempts <- newTVar 0 + ntfKeyDeleted <- newTVar 0 + ntfKeyDeleteAttempts <- newTVar 0 pure AgentSMPServerStats { sentDirect, @@ -102,15 +144,63 @@ newAgentSMPServerStats = do recvDuplicates, recvCryptoErrs, recvErrs, + ackMsgs, + ackAttempts, + ackNoMsgErrs, + ackOtherErrs, connCreated, connSecured, connCompleted, connDeleted, + connDelAttempts, + connDelErrs, connSubscribed, connSubAttempts, - connSubErrs + connSubIgnored, + connSubErrs, + ntfKey, + ntfKeyAttempts, + ntfKeyDeleted, + ntfKeyDeleteAttempts } +newAgentSMPServerStatsData :: AgentSMPServerStatsData +newAgentSMPServerStatsData = + AgentSMPServerStatsData + { _sentDirect = 0, + _sentViaProxy = 0, + _sentProxied = 0, + _sentDirectAttempts = 0, + _sentViaProxyAttempts = 0, + _sentProxiedAttempts = 0, + _sentAuthErrs = 0, + _sentQuotaErrs = 0, + _sentExpiredErrs = 0, + _sentOtherErrs = 0, + _recvMsgs = 0, + _recvDuplicates = 0, + _recvCryptoErrs = 0, + _recvErrs = 0, + _ackMsgs = 0, + _ackAttempts = 0, + _ackNoMsgErrs = 0, + _ackOtherErrs = 0, + _connCreated = 0, + _connSecured = 0, + _connCompleted = 0, + _connDeleted = 0, + _connDelAttempts = 0, + _connDelErrs = 0, + _connSubscribed = 0, + _connSubAttempts = 0, + _connSubIgnored = 0, + _connSubErrs = 0, + _ntfKey = 0, + _ntfKeyAttempts = 0, + _ntfKeyDeleted = 0, + _ntfKeyDeleteAttempts = 0 + } + newAgentSMPServerStats' :: AgentSMPServerStatsData -> STM AgentSMPServerStats newAgentSMPServerStats' s = do sentDirect <- newTVar $ _sentDirect s @@ -127,13 +217,24 @@ newAgentSMPServerStats' s = do recvDuplicates <- newTVar $ _recvDuplicates s recvCryptoErrs <- newTVar $ _recvCryptoErrs s recvErrs <- newTVar $ _recvErrs s + ackMsgs <- newTVar $ _ackMsgs s + ackAttempts <- newTVar $ _ackAttempts s + ackNoMsgErrs <- newTVar $ _ackNoMsgErrs s + ackOtherErrs <- newTVar $ _ackOtherErrs s connCreated <- newTVar $ _connCreated s connSecured <- newTVar $ _connSecured s connCompleted <- newTVar $ _connCompleted s connDeleted <- newTVar $ _connDeleted s + connDelAttempts <- newTVar $ _connDelAttempts s + connDelErrs <- newTVar $ _connDelErrs s connSubscribed <- newTVar $ _connSubscribed s connSubAttempts <- newTVar $ _connSubAttempts s + connSubIgnored <- newTVar $ _connSubIgnored s connSubErrs <- newTVar $ _connSubErrs s + ntfKey <- newTVar $ toInt $ _ntfKey s + ntfKeyAttempts <- newTVar $ toInt $ _ntfKeyAttempts s + ntfKeyDeleted <- newTVar $ toInt $ _ntfKeyDeleted s + ntfKeyDeleteAttempts <- newTVar $ toInt $ _ntfKeyDeleteAttempts s pure AgentSMPServerStats { sentDirect, @@ -150,13 +251,24 @@ newAgentSMPServerStats' s = do recvDuplicates, recvCryptoErrs, recvErrs, + ackMsgs, + ackAttempts, + ackNoMsgErrs, + ackOtherErrs, connCreated, connSecured, connCompleted, connDeleted, + connDelAttempts, + connDelErrs, connSubscribed, connSubAttempts, - connSubErrs + connSubIgnored, + connSubErrs, + ntfKey, + ntfKeyAttempts, + ntfKeyDeleted, + ntfKeyDeleteAttempts } -- as this is used to periodically update stats in db, @@ -177,13 +289,24 @@ getAgentSMPServerStats s = do _recvDuplicates <- readTVarIO $ recvDuplicates s _recvCryptoErrs <- readTVarIO $ recvCryptoErrs s _recvErrs <- readTVarIO $ recvErrs s + _ackMsgs <- readTVarIO $ ackMsgs s + _ackAttempts <- readTVarIO $ ackAttempts s + _ackNoMsgErrs <- readTVarIO $ ackNoMsgErrs s + _ackOtherErrs <- readTVarIO $ ackOtherErrs s _connCreated <- readTVarIO $ connCreated s _connSecured <- readTVarIO $ connSecured s _connCompleted <- readTVarIO $ connCompleted s _connDeleted <- readTVarIO $ connDeleted s + _connDelAttempts <- readTVarIO $ connDelAttempts s + _connDelErrs <- readTVarIO $ connDelErrs s _connSubscribed <- readTVarIO $ connSubscribed s _connSubAttempts <- readTVarIO $ connSubAttempts s + _connSubIgnored <- readTVarIO $ connSubIgnored s _connSubErrs <- readTVarIO $ connSubErrs s + _ntfKey <- OInt <$> readTVarIO (ntfKey s) + _ntfKeyAttempts <- OInt <$> readTVarIO (ntfKeyAttempts s) + _ntfKeyDeleted <- OInt <$> readTVarIO (ntfKeyDeleted s) + _ntfKeyDeleteAttempts <- OInt <$> readTVarIO (ntfKeyDeleteAttempts s) pure AgentSMPServerStatsData { _sentDirect, @@ -200,20 +323,70 @@ getAgentSMPServerStats s = do _recvDuplicates, _recvCryptoErrs, _recvErrs, + _ackMsgs, + _ackAttempts, + _ackNoMsgErrs, + _ackOtherErrs, _connCreated, _connSecured, _connCompleted, _connDeleted, + _connDelAttempts, + _connDelErrs, _connSubscribed, _connSubAttempts, - _connSubErrs + _connSubIgnored, + _connSubErrs, + _ntfKey, + _ntfKeyAttempts, + _ntfKeyDeleted, + _ntfKeyDeleteAttempts } +addSMPStatsData :: AgentSMPServerStatsData -> AgentSMPServerStatsData -> AgentSMPServerStatsData +addSMPStatsData sd1 sd2 = + AgentSMPServerStatsData + { _sentDirect = _sentDirect sd1 + _sentDirect sd2, + _sentViaProxy = _sentViaProxy sd1 + _sentViaProxy sd2, + _sentProxied = _sentProxied sd1 + _sentProxied sd2, + _sentDirectAttempts = _sentDirectAttempts sd1 + _sentDirectAttempts sd2, + _sentViaProxyAttempts = _sentViaProxyAttempts sd1 + _sentViaProxyAttempts sd2, + _sentProxiedAttempts = _sentProxiedAttempts sd1 + _sentProxiedAttempts sd2, + _sentAuthErrs = _sentAuthErrs sd1 + _sentAuthErrs sd2, + _sentQuotaErrs = _sentQuotaErrs sd1 + _sentQuotaErrs sd2, + _sentExpiredErrs = _sentExpiredErrs sd1 + _sentExpiredErrs sd2, + _sentOtherErrs = _sentOtherErrs sd1 + _sentOtherErrs sd2, + _recvMsgs = _recvMsgs sd1 + _recvMsgs sd2, + _recvDuplicates = _recvDuplicates sd1 + _recvDuplicates sd2, + _recvCryptoErrs = _recvCryptoErrs sd1 + _recvCryptoErrs sd2, + _recvErrs = _recvErrs sd1 + _recvErrs sd2, + _ackMsgs = _ackMsgs sd1 + _ackMsgs sd2, + _ackAttempts = _ackAttempts sd1 + _ackAttempts sd2, + _ackNoMsgErrs = _ackNoMsgErrs sd1 + _ackNoMsgErrs sd2, + _ackOtherErrs = _ackOtherErrs sd1 + _ackOtherErrs sd2, + _connCreated = _connCreated sd1 + _connCreated sd2, + _connSecured = _connSecured sd1 + _connSecured sd2, + _connCompleted = _connCompleted sd1 + _connCompleted sd2, + _connDeleted = _connDeleted sd1 + _connDeleted sd2, + _connDelAttempts = _connDelAttempts sd1 + _connDelAttempts sd2, + _connDelErrs = _connDelErrs sd1 + _connDelErrs sd2, + _connSubscribed = _connSubscribed sd1 + _connSubscribed sd2, + _connSubAttempts = _connSubAttempts sd1 + _connSubAttempts sd2, + _connSubIgnored = _connSubIgnored sd1 + _connSubIgnored sd2, + _connSubErrs = _connSubErrs sd1 + _connSubErrs sd2, + _ntfKey = _ntfKey sd1 + _ntfKey sd2, + _ntfKeyAttempts = _ntfKeyAttempts sd1 + _ntfKeyAttempts sd2, + _ntfKeyDeleted = _ntfKeyDeleted sd1 + _ntfKeyDeleted sd2, + _ntfKeyDeleteAttempts = _ntfKeyDeleteAttempts sd1 + _ntfKeyDeleteAttempts sd2 + } + data AgentXFTPServerStats = AgentXFTPServerStats { uploads :: TVar Int, -- total replicas uploaded to server + uploadsSize :: TVar Int64, -- total size of uploaded replicas in KB uploadAttempts :: TVar Int, -- upload attempts uploadErrs :: TVar Int, -- upload errors downloads :: TVar Int, -- total replicas downloaded from server + downloadsSize :: TVar Int64, -- total size of downloaded replicas in KB downloadAttempts :: TVar Int, -- download attempts downloadAuthErrs :: TVar Int, -- download AUTH errors downloadErrs :: TVar Int, -- other download errors (excluding above) @@ -224,9 +397,11 @@ data AgentXFTPServerStats = AgentXFTPServerStats data AgentXFTPServerStatsData = AgentXFTPServerStatsData { _uploads :: Int, + _uploadsSize :: Int64, _uploadAttempts :: Int, _uploadErrs :: Int, _downloads :: Int, + _downloadsSize :: Int64, _downloadAttempts :: Int, _downloadAuthErrs :: Int, _downloadErrs :: Int, @@ -239,9 +414,11 @@ data AgentXFTPServerStatsData = AgentXFTPServerStatsData newAgentXFTPServerStats :: STM AgentXFTPServerStats newAgentXFTPServerStats = do uploads <- newTVar 0 + uploadsSize <- newTVar 0 uploadAttempts <- newTVar 0 uploadErrs <- newTVar 0 downloads <- newTVar 0 + downloadsSize <- newTVar 0 downloadAttempts <- newTVar 0 downloadAuthErrs <- newTVar 0 downloadErrs <- newTVar 0 @@ -251,9 +428,11 @@ newAgentXFTPServerStats = do pure AgentXFTPServerStats { uploads, + uploadsSize, uploadAttempts, uploadErrs, downloads, + downloadsSize, downloadAttempts, downloadAuthErrs, downloadErrs, @@ -262,12 +441,31 @@ newAgentXFTPServerStats = do deleteErrs } +newAgentXFTPServerStatsData :: AgentXFTPServerStatsData +newAgentXFTPServerStatsData = + AgentXFTPServerStatsData + { _uploads = 0, + _uploadsSize = 0, + _uploadAttempts = 0, + _uploadErrs = 0, + _downloads = 0, + _downloadsSize = 0, + _downloadAttempts = 0, + _downloadAuthErrs = 0, + _downloadErrs = 0, + _deletions = 0, + _deleteAttempts = 0, + _deleteErrs = 0 + } + newAgentXFTPServerStats' :: AgentXFTPServerStatsData -> STM AgentXFTPServerStats newAgentXFTPServerStats' s = do uploads <- newTVar $ _uploads s + uploadsSize <- newTVar $ _uploadsSize s uploadAttempts <- newTVar $ _uploadAttempts s uploadErrs <- newTVar $ _uploadErrs s downloads <- newTVar $ _downloads s + downloadsSize <- newTVar $ _downloadsSize s downloadAttempts <- newTVar $ _downloadAttempts s downloadAuthErrs <- newTVar $ _downloadAuthErrs s downloadErrs <- newTVar $ _downloadErrs s @@ -277,9 +475,11 @@ newAgentXFTPServerStats' s = do pure AgentXFTPServerStats { uploads, + uploadsSize, uploadAttempts, uploadErrs, downloads, + downloadsSize, downloadAttempts, downloadAuthErrs, downloadErrs, @@ -293,9 +493,11 @@ newAgentXFTPServerStats' s = do getAgentXFTPServerStats :: AgentXFTPServerStats -> IO AgentXFTPServerStatsData getAgentXFTPServerStats s = do _uploads <- readTVarIO $ uploads s + _uploadsSize <- readTVarIO $ uploadsSize s _uploadAttempts <- readTVarIO $ uploadAttempts s _uploadErrs <- readTVarIO $ uploadErrs s _downloads <- readTVarIO $ downloads s + _downloadsSize <- readTVarIO $ downloadsSize s _downloadAttempts <- readTVarIO $ downloadAttempts s _downloadAuthErrs <- readTVarIO $ downloadAuthErrs s _downloadErrs <- readTVarIO $ downloadErrs s @@ -305,9 +507,11 @@ getAgentXFTPServerStats s = do pure AgentXFTPServerStatsData { _uploads, + _uploadsSize, _uploadAttempts, _uploadErrs, _downloads, + _downloadsSize, _downloadAttempts, _downloadAuthErrs, _downloadErrs, @@ -316,18 +520,144 @@ getAgentXFTPServerStats s = do _deleteErrs } +addXFTPStatsData :: AgentXFTPServerStatsData -> AgentXFTPServerStatsData -> AgentXFTPServerStatsData +addXFTPStatsData sd1 sd2 = + AgentXFTPServerStatsData + { _uploads = _uploads sd1 + _uploads sd2, + _uploadsSize = _uploadsSize sd1 + _uploadsSize sd2, + _uploadAttempts = _uploadAttempts sd1 + _uploadAttempts sd2, + _uploadErrs = _uploadErrs sd1 + _uploadErrs sd2, + _downloads = _downloads sd1 + _downloads sd2, + _downloadsSize = _downloadsSize sd1 + _downloadsSize sd2, + _downloadAttempts = _downloadAttempts sd1 + _downloadAttempts sd2, + _downloadAuthErrs = _downloadAuthErrs sd1 + _downloadAuthErrs sd2, + _downloadErrs = _downloadErrs sd1 + _downloadErrs sd2, + _deletions = _deletions sd1 + _deletions sd2, + _deleteAttempts = _deleteAttempts sd1 + _deleteAttempts sd2, + _deleteErrs = _deleteErrs sd1 + _deleteErrs sd2 + } + +data AgentNtfServerStats = AgentNtfServerStats + { ntfCreated :: TVar Int, + ntfCreateAttempts :: TVar Int, + ntfChecked :: TVar Int, + ntfCheckAttempts :: TVar Int, + ntfDeleted :: TVar Int, + ntfDelAttempts :: TVar Int + } + +data AgentNtfServerStatsData = AgentNtfServerStatsData + { _ntfCreated :: Int, + _ntfCreateAttempts :: Int, + _ntfChecked :: Int, + _ntfCheckAttempts :: Int, + _ntfDeleted :: Int, + _ntfDelAttempts :: Int + } + deriving (Show) + +newAgentNtfServerStats :: STM AgentNtfServerStats +newAgentNtfServerStats = do + ntfCreated <- newTVar 0 + ntfCreateAttempts <- newTVar 0 + ntfChecked <- newTVar 0 + ntfCheckAttempts <- newTVar 0 + ntfDeleted <- newTVar 0 + ntfDelAttempts <- newTVar 0 + pure + AgentNtfServerStats + { ntfCreated, + ntfCreateAttempts, + ntfChecked, + ntfCheckAttempts, + ntfDeleted, + ntfDelAttempts + } + +newAgentNtfServerStatsData :: AgentNtfServerStatsData +newAgentNtfServerStatsData = + AgentNtfServerStatsData + { _ntfCreated = 0, + _ntfCreateAttempts = 0, + _ntfChecked = 0, + _ntfCheckAttempts = 0, + _ntfDeleted = 0, + _ntfDelAttempts = 0 + } + +newAgentNtfServerStats' :: AgentNtfServerStatsData -> STM AgentNtfServerStats +newAgentNtfServerStats' s = do + ntfCreated <- newTVar $ _ntfCreated s + ntfCreateAttempts <- newTVar $ _ntfCreateAttempts s + ntfChecked <- newTVar $ _ntfChecked s + ntfCheckAttempts <- newTVar $ _ntfCheckAttempts s + ntfDeleted <- newTVar $ _ntfDeleted s + ntfDelAttempts <- newTVar $ _ntfDelAttempts s + pure + AgentNtfServerStats + { ntfCreated, + ntfCreateAttempts, + ntfChecked, + ntfCheckAttempts, + ntfDeleted, + ntfDelAttempts + } + +getAgentNtfServerStats :: AgentNtfServerStats -> IO AgentNtfServerStatsData +getAgentNtfServerStats s = do + _ntfCreated <- readTVarIO $ ntfCreated s + _ntfCreateAttempts <- readTVarIO $ ntfCreateAttempts s + _ntfChecked <- readTVarIO $ ntfChecked s + _ntfCheckAttempts <- readTVarIO $ ntfCheckAttempts s + _ntfDeleted <- readTVarIO $ ntfDeleted s + _ntfDelAttempts <- readTVarIO $ ntfDelAttempts s + pure + AgentNtfServerStatsData + { _ntfCreated, + _ntfCreateAttempts, + _ntfChecked, + _ntfCheckAttempts, + _ntfDeleted, + _ntfDelAttempts + } + +addNtfStatsData :: AgentNtfServerStatsData -> AgentNtfServerStatsData -> AgentNtfServerStatsData +addNtfStatsData sd1 sd2 = + AgentNtfServerStatsData + { _ntfCreated = _ntfCreated sd1 + _ntfCreated sd2, + _ntfCreateAttempts = _ntfCreateAttempts sd1 + _ntfCreateAttempts sd2, + _ntfChecked = _ntfChecked sd1 + _ntfChecked sd2, + _ntfCheckAttempts = _ntfCheckAttempts sd1 + _ntfCheckAttempts sd2, + _ntfDeleted = _ntfDeleted sd1 + _ntfDeleted sd2, + _ntfDelAttempts = _ntfDelAttempts sd1 + _ntfDelAttempts sd2 + } + -- Type for gathering both smp and xftp stats across all users and servers, -- to then be persisted to db as a single json. data AgentPersistedServerStats = AgentPersistedServerStats { smpServersStats :: Map (UserId, SMPServer) AgentSMPServerStatsData, - xftpServersStats :: Map (UserId, XFTPServer) AgentXFTPServerStatsData + xftpServersStats :: Map (UserId, XFTPServer) AgentXFTPServerStatsData, + ntfServersStats :: OptionalMap (UserId, NtfServer) AgentNtfServerStatsData } deriving (Show) +instance FromJSON OptionalInt where + parseJSON v = OInt <$> parseJSON v + omittedField = Just (OInt 0) + +newtype OptionalMap k v = OptionalMap (Map k v) + deriving (Show, ToJSON) + +instance (FromJSONKey k, Ord k, FromJSON v) => FromJSON (OptionalMap k v) where + parseJSON v = OptionalMap <$> parseJSON v + omittedField = Just (OptionalMap M.empty) + $(J.deriveJSON defaultJSON ''AgentSMPServerStatsData) $(J.deriveJSON defaultJSON ''AgentXFTPServerStatsData) +$(J.deriveJSON defaultJSON ''AgentNtfServerStatsData) + $(J.deriveJSON defaultJSON ''AgentPersistedServerStats) instance ToField AgentPersistedServerStats where diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index baec2ef93..ae010a884 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -175,6 +175,12 @@ instance SMPQueue RcvQueue where queueId RcvQueue {rcvId} = rcvId {-# INLINE queueId #-} +instance SMPQueue NewRcvQueue where + qServer RcvQueue {server} = server + {-# INLINE qServer #-} + queueId RcvQueue {rcvId} = rcvId + {-# INLINE queueId #-} + instance SMPQueue SndQueue where qServer SndQueue {server} = server {-# INLINE qServer #-} diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index d4cd99b39..20f382d40 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -220,7 +220,7 @@ module Simplex.Messaging.Agent.Store.SQLite -- * utilities withConnection, withTransaction, - withTransactionCtx, + withTransactionPriority, firstRow, firstRow', maybeFirstRow, @@ -392,10 +392,10 @@ connectSQLiteStore dbFilePath key keepKey = do dbNew <- not <$> doesFileExist dbFilePath dbConn <- dbBusyLoop (connectDB dbFilePath key) dbConnection <- newMVar dbConn - atomically $ do - dbKey <- newTVar $! storeKey key keepKey - dbClosed <- newTVar False - pure SQLiteStore {dbFilePath, dbKey, dbConnection, dbNew, dbClosed} + dbKey <- newTVarIO $! storeKey key keepKey + dbClosed <- newTVarIO False + dbSem <- newTVarIO 0 + pure SQLiteStore {dbFilePath, dbKey, dbSem, dbConnection, dbNew, dbClosed} connectDB :: FilePath -> ScrubbedBytes -> IO DB.Connection connectDB path key = do @@ -1457,23 +1457,24 @@ getNtfSubscription db connId = DB.query db [sql| - SELECT s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash, + SELECT c.user_id, s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash, nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts FROM ntf_subscriptions nsb + JOIN connections c USING (conn_id) JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port JOIN ntf_servers ns USING (ntf_host, ntf_port) WHERE nsb.conn_id = ? |] (Only connId) where - ntfSubscription (smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_) = + ntfSubscription ((userId, smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash ) :. (ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_)) = let smpServer = SMPServer smpHost smpPort smpKeyHash ntfServer = NtfServer ntfHost ntfPort ntfKeyHash action = case (ntfAction_, smpAction_, actionTs_) of - (Just ntfAction, Nothing, Just actionTs) -> Just (NtfSubNTFAction ntfAction, actionTs) - (Nothing, Just smpAction, Just actionTs) -> Just (NtfSubSMPAction smpAction, actionTs) + (Just ntfAction, Nothing, Just actionTs) -> Just (NSANtf ntfAction, actionTs) + (Nothing, Just smpAction, Just actionTs) -> Just (NSASMP smpAction, actionTs) _ -> Nothing - in (NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action) + in (NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action) createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ()) createNtfSubscription db ntfSubscription action = runExceptT $ do @@ -1607,18 +1608,19 @@ getNextNtfSubNTFAction db ntfServer@(NtfServer ntfHost ntfPort _) = DB.query db [sql| - SELECT s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash), + SELECT c.user_id, s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash), ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action FROM ntf_subscriptions ns + JOIN connections c USING (conn_id) JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port WHERE ns.conn_id = ? |] (Only connId) where err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" - ntfSubAction (smpHost, smpPort, smpKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = + ntfSubAction (userId, smpHost, smpPort, smpKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = let smpServer = SMPServer smpHost smpPort smpKeyHash - ntfSubscription = NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} + ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} in (ntfSubscription, action, actionTs) markNtfSubActionNtfFailed_ :: DB.Connection -> ConnId -> IO () @@ -1650,18 +1652,19 @@ getNextNtfSubSMPAction db smpServer@(SMPServer smpHost smpPort _) = DB.query db [sql| - SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, + SELECT c.user_id, s.ntf_host, s.ntf_port, s.ntf_key_hash, ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_smp_action FROM ntf_subscriptions ns + JOIN connections c USING (conn_id) JOIN ntf_servers s USING (ntf_host, ntf_port) WHERE ns.conn_id = ? |] (Only connId) where err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" - ntfSubAction (ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = + ntfSubAction (userId, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = let ntfServer = NtfServer ntfHost ntfPort ntfKeyHash - ntfSubscription = NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} + ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} in (ntfSubscription, action, actionTs) markNtfSubActionSMPFailed_ :: DB.Connection -> ConnId -> IO () @@ -2272,8 +2275,8 @@ randomId :: TVar ChaChaDRG -> Int -> IO ByteString randomId gVar n = atomically $ U.encode <$> C.randomBytes n gVar ntfSubAndSMPAction :: NtfSubAction -> (Maybe NtfSubNTFAction, Maybe NtfSubSMPAction) -ntfSubAndSMPAction (NtfSubNTFAction action) = (Just action, Nothing) -ntfSubAndSMPAction (NtfSubSMPAction action) = (Nothing, Just action) +ntfSubAndSMPAction (NSANtf action) = (Just action, Nothing) +ntfSubAndSMPAction (NSASMP action) = (Nothing, Just action) createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64 createXFTPServer_ db newSrv@ProtocolServer {host, port, keyHash} = @@ -3041,10 +3044,9 @@ getServersStats db = firstRow id SEServersStatsNotFound $ DB.query_ db "SELECT started_at, servers_stats FROM servers_stats WHERE servers_stats_id = 1" -resetServersStats :: DB.Connection -> IO () -resetServersStats db = do - currentTs <- getCurrentTime - DB.execute db "UPDATE servers_stats SET servers_stats = NULL, started_at = ?, updated_at = ? WHERE servers_stats_id = 1" (currentTs, currentTs) +resetServersStats :: DB.Connection -> UTCTime -> IO () +resetServersStats db startedAt = + DB.execute db "UPDATE servers_stats SET servers_stats = NULL, started_at = ?, updated_at = ? WHERE servers_stats_id = 1" (startedAt, startedAt) $(J.deriveJSON defaultJSON ''UpMigration) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index b9a9bd501..a7ad47f37 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -8,20 +9,20 @@ module Simplex.Messaging.Agent.Store.SQLite.Common withConnection', withTransaction, withTransaction', - withTransactionCtx, + withTransactionPriority, dbBusyLoop, storeKey, ) where import Control.Concurrent (threadDelay) +import Control.Concurrent.STM (retry) import Data.ByteArray (ScrubbedBytes) import qualified Data.ByteArray as BA -import Data.Time.Clock (diffUTCTime, getCurrentTime) import Database.SQLite.Simple (SQLError) import qualified Database.SQLite.Simple as SQL import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Util (diffToMilliseconds) +import Simplex.Messaging.Util (ifM, unlessM) import qualified UnliftIO.Exception as E import UnliftIO.MVar import UnliftIO.STM @@ -32,35 +33,40 @@ storeKey key keepKey = if keepKey || BA.null key then Just key else Nothing data SQLiteStore = SQLiteStore { dbFilePath :: FilePath, dbKey :: TVar (Maybe ScrubbedBytes), + dbSem :: TVar Int, dbConnection :: MVar DB.Connection, dbClosed :: TVar Bool, dbNew :: Bool } +withConnectionPriority :: SQLiteStore -> Bool -> (DB.Connection -> IO a) -> IO a +withConnectionPriority SQLiteStore {dbSem, dbConnection} priority action + | priority = E.bracket_ signal release $ withMVar dbConnection action + | otherwise = lowPriority + where + lowPriority = wait >> withMVar dbConnection (\db -> ifM free (Just <$> action db) (pure Nothing)) >>= maybe lowPriority pure + signal = atomically $ modifyTVar' dbSem (+ 1) + release = atomically $ modifyTVar' dbSem $ \sem -> if sem > 0 then sem - 1 else 0 + wait = unlessM free $ atomically $ unlessM ((0 ==) <$> readTVar dbSem) retry + free = (0 ==) <$> readTVarIO dbSem + withConnection :: SQLiteStore -> (DB.Connection -> IO a) -> IO a -withConnection SQLiteStore {dbConnection} = withMVar dbConnection +withConnection st = withConnectionPriority st False withConnection' :: SQLiteStore -> (SQL.Connection -> IO a) -> IO a withConnection' st action = withConnection st $ action . DB.conn -withTransaction :: SQLiteStore -> (DB.Connection -> IO a) -> IO a -withTransaction = withTransactionCtx Nothing - withTransaction' :: SQLiteStore -> (SQL.Connection -> IO a) -> IO a withTransaction' st action = withTransaction st $ action . DB.conn -withTransactionCtx :: Maybe String -> SQLiteStore -> (DB.Connection -> IO a) -> IO a -withTransactionCtx ctx_ st action = withConnection st $ dbBusyLoop . transactionWithCtx +withTransaction :: SQLiteStore -> (DB.Connection -> IO a) -> IO a +withTransaction st = withTransactionPriority st False +{-# INLINE withTransaction #-} + +withTransactionPriority :: SQLiteStore -> Bool -> (DB.Connection -> IO a) -> IO a +withTransactionPriority st priority action = withConnectionPriority st priority $ dbBusyLoop . transaction where - transactionWithCtx db@DB.Connection {conn} = case ctx_ of - Nothing -> SQL.withImmediateTransaction conn $ action db - Just ctx -> do - t1 <- getCurrentTime - r <- SQL.withImmediateTransaction conn $ action db - t2 <- getCurrentTime - putStrLn $ "withTransactionCtx start :: " <> show t1 <> " :: " <> ctx - putStrLn $ "withTransactionCtx end :: " <> show t2 <> " :: " <> ctx <> " :: duration=" <> show (diffToMilliseconds $ diffUTCTime t2 t1) - pure r + transaction db@DB.Connection {conn} = SQL.withImmediateTransaction conn $ action db dbBusyLoop :: forall a. IO a -> IO a dbBusyLoop action = loop 500 3000000 diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs index 2ae4eb731..b356b3f87 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs @@ -64,7 +64,7 @@ timeIt slow sql a = do open :: String -> IO Connection open f = do conn <- SQL.open f - slow <- atomically $ TM.empty + slow <- TM.emptyIO pure Connection {conn, slow} close :: Connection -> IO () diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 2279d7ea5..131561f4d 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -73,6 +73,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wai import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240417_rcv_files_approved_relays import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240624_snd_secure +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240702_servers_stats import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport.Client (TransportHost) @@ -114,7 +115,8 @@ schemaMigrations = ("m20240223_connections_wait_delivery", m20240223_connections_wait_delivery, Just down_m20240223_connections_wait_delivery), ("m20240225_ratchet_kem", m20240225_ratchet_kem, Just down_m20240225_ratchet_kem), ("m20240417_rcv_files_approved_relays", m20240417_rcv_files_approved_relays, Just down_m20240417_rcv_files_approved_relays), - ("m20240624_snd_secure", m20240624_snd_secure, Just down_m20240624_snd_secure) + ("m20240624_snd_secure", m20240624_snd_secure, Just down_m20240624_snd_secure), + ("m20240702_servers_stats", m20240702_servers_stats, Just down_m20240702_servers_stats) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240518_servers_stats.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240702_servers_stats.hs similarity index 79% rename from src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240518_servers_stats.hs rename to src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240702_servers_stats.hs index fe017e233..5e283d8b1 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240518_servers_stats.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240702_servers_stats.hs @@ -1,6 +1,6 @@ {-# LANGUAGE QuasiQuotes #-} -module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240518_servers_stats where +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240702_servers_stats where import Database.SQLite.Simple (Query) import Database.SQLite.Simple.QQ (sql) @@ -8,8 +8,8 @@ import Database.SQLite.Simple.QQ (sql) -- servers_stats_id: dummy id, there should always only be one record with servers_stats_id = 1 -- servers_stats: overall accumulated stats, past and session, reset to null on stats reset -- started_at: starting point of tracking stats, reset on stats reset -m20240518_servers_stats :: Query -m20240518_servers_stats = +m20240702_servers_stats :: Query +m20240702_servers_stats = [sql| CREATE TABLE servers_stats( servers_stats_id INTEGER PRIMARY KEY, @@ -22,8 +22,8 @@ CREATE TABLE servers_stats( INSERT INTO servers_stats (servers_stats_id) VALUES (1); |] -down_m20240518_servers_stats :: Query -down_m20240518_servers_stats = +down_m20240702_servers_stats :: Query +down_m20240702_servers_stats = [sql| DROP TABLE servers_stats; |] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index b9d2d945f..80af08989 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -396,6 +396,13 @@ CREATE TABLE processed_ratchet_key_hashes( created_at TEXT NOT NULL DEFAULT(datetime('now')), updated_at TEXT NOT NULL DEFAULT(datetime('now')) ); +CREATE TABLE servers_stats( + servers_stats_id INTEGER PRIMARY KEY, + servers_stats TEXT, + started_at TEXT NOT NULL DEFAULT(datetime('now')), + created_at TEXT NOT NULL DEFAULT(datetime('now')), + updated_at TEXT NOT NULL DEFAULT(datetime('now')) +); CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id); CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id); CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues(conn_id, snd_queue_id); diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index 9ffe325b2..3b02f64ae 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -1,7 +1,9 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} module Simplex.Messaging.Agent.TRcvQueues ( TRcvQueues (getRcvQueues, getConnections), + Queue (..), empty, clear, deleteConn, @@ -9,9 +11,9 @@ module Simplex.Messaging.Agent.TRcvQueues addQueue, batchAddQueues, deleteQueue, + hasSessQueues, getSessQueues, getDelSessQueues, - qKey, ) where @@ -25,46 +27,51 @@ import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..)) import Simplex.Messaging.Protocol (RecipientId, SMPServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Transport + +class Queue q where + connId' :: q -> ConnId + qKey :: q -> (UserId, SMPServer, ConnId) -- the fields in this record have the same data with swapped keys for lookup efficiency, -- and all methods must maintain this invariant. -data TRcvQueues = TRcvQueues - { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue, +data TRcvQueues q = TRcvQueues + { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) q, getConnections :: TMap ConnId (NonEmpty (UserId, SMPServer, RecipientId)) } -empty :: STM TRcvQueues -empty = TRcvQueues <$> TM.empty <*> TM.empty +empty :: IO (TRcvQueues q) +empty = TRcvQueues <$> TM.emptyIO <*> TM.emptyIO -clear :: TRcvQueues -> STM () +clear :: TRcvQueues q -> STM () clear (TRcvQueues qs cs) = TM.clear qs >> TM.clear cs -deleteConn :: ConnId -> TRcvQueues -> STM () +deleteConn :: ConnId -> TRcvQueues q -> STM () deleteConn cId (TRcvQueues qs cs) = TM.lookupDelete cId cs >>= \case Just ks -> modifyTVar' qs $ \qs' -> foldl' (flip M.delete) qs' ks Nothing -> pure () -hasConn :: ConnId -> TRcvQueues -> STM Bool +hasConn :: ConnId -> TRcvQueues q -> STM Bool hasConn cId (TRcvQueues _ cs) = TM.member cId cs -addQueue :: RcvQueue -> TRcvQueues -> STM () +addQueue :: Queue q => q -> TRcvQueues q -> STM () addQueue rq (TRcvQueues qs cs) = do TM.insert k rq qs - TM.alter addQ (connId rq) cs + TM.alter addQ (connId' rq) cs where addQ = Just . maybe (k :| []) (k <|) k = qKey rq -- Save time by aggregating modifyTVar -batchAddQueues :: Foldable t => TRcvQueues -> t RcvQueue -> STM () +batchAddQueues :: (Foldable t, Queue q) => TRcvQueues q -> t q -> STM () batchAddQueues (TRcvQueues qs cs) rqs = do modifyTVar' qs $ \now -> foldl' (\rqs' rq -> M.insert (qKey rq) rq rqs') now rqs - modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId rq) cs') now rqs + modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId' rq) cs') now rqs where addQ k = Just . maybe (k :| []) (k <|) -deleteQueue :: RcvQueue -> TRcvQueues -> STM () +deleteQueue :: RcvQueue -> TRcvQueues RcvQueue -> STM () deleteQueue rq (TRcvQueues qs cs) = do TM.delete k qs TM.update delQ (connId rq) cs @@ -72,21 +79,25 @@ deleteQueue rq (TRcvQueues qs cs) = do delQ = L.nonEmpty . L.filter (/= k) k = qKey rq -getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue] -getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVar qs +hasSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> STM Bool +hasSessQueues tSess (TRcvQueues qs _) = any (`isSession` tSess) <$> readTVar qs + +getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> IO [RcvQueue] +getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVarIO qs where addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs' -getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], [ConnId]) -getDelSessQueues tSess (TRcvQueues qs cs) = do +getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> SessionId -> TRcvQueues (SessionId, RcvQueue) -> STM ([RcvQueue], [ConnId]) +getDelSessQueues tSess sessId' (TRcvQueues qs cs) = do (removedQs, qs'') <- (\qs' -> M.foldl' delQ ([], qs') qs') <$> readTVar qs writeTVar qs $! qs'' removedConns <- stateTVar cs $ \cs' -> foldl' delConn ([], cs') removedQs pure (removedQs, removedConns) where - delQ acc@(removed, qs') rq - | rq `isSession` tSess = (rq : removed, M.delete (qKey rq) qs') + delQ acc@(removed, qs') (sessId, rq) + | rq `isSession` tSess && sessId == sessId' = (rq : removed, M.delete (qKey rq) qs') | otherwise = acc + delConn :: ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId))) -> RcvQueue -> ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId))) delConn (removed, cs') rq = M.alterF f cId cs' where cId = connId rq @@ -100,5 +111,10 @@ isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool isSession rq (uId, srv, connId_) = userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_ -qKey :: RcvQueue -> (UserId, SMPServer, ConnId) -qKey rq = (userId rq, server rq, connId rq) +instance Queue RcvQueue where + connId' = connId + qKey rq = (userId rq, server rq, connId rq) + +instance Queue (SessionId, RcvQueue) where + connId' = connId . snd + qKey = qKey . snd diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index e20b00039..b4567c62e 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -130,7 +130,7 @@ import Numeric.Natural import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON) +import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, sumTypeJSON) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore.QueueInfo import Simplex.Messaging.TMap (TMap) @@ -170,17 +170,17 @@ data PClient v err msg = PClient msgQ :: Maybe (TBQueue (ServerTransmissionBatch v err msg)) } -smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe (THandleAuth 'TClient) -> STM SMPClient +smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe (THandleAuth 'TClient) -> IO SMPClient smpClientStub g sessionId thVersion thAuth = do let ts = UTCTime (read "2024-03-31") 0 - connected <- newTVar False - clientCorrId <- C.newRandomDRG g - sentCommands <- TM.empty - sendPings <- newTVar False - lastReceived <- newTVar ts - timeoutErrorCount <- newTVar 0 - sndQ <- newTBQueue 100 - rcvQ <- newTBQueue 100 + connected <- newTVarIO False + clientCorrId <- atomically $ C.newRandomDRG g + sentCommands <- TM.emptyIO + sendPings <- newTVarIO False + lastReceived <- newTVarIO ts + timeoutErrorCount <- newTVarIO 0 + sndQ <- newTBQueueIO 100 + rcvQ <- newTBQueueIO 100 return ProtocolClient { action = Nothing, @@ -240,10 +240,20 @@ data SocksMode = -- | always use SOCKS proxy when enabled SMAlways | -- | use SOCKS proxy only for .onion hosts when no public host is available - -- This mode is used in SMP proxy to minimize SOCKS proxy usage. + -- This mode is used in SMP proxy and in notifications server to minimize SOCKS proxy usage. SMOnion deriving (Eq, Show) +instance StrEncoding SocksMode where + strEncode = \case + SMAlways -> "always" + SMOnion -> "onion" + strP = + A.takeTill (== ' ') >>= \case + "always" -> pure SMAlways + "onion" -> pure SMOnion + _ -> fail "Invalid Socks mode" + -- | network configuration for the client data NetworkConfig = NetworkConfig { -- | use SOCKS5 proxy @@ -442,21 +452,21 @@ getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> T getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, clientALPN, serverVRange, agreeSecret} msgQ disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> - (getCurrentTime >>= atomically . mkProtocolClient useHost >>= runClient useTransport useHost) + (getCurrentTime >>= mkProtocolClient useHost >>= runClient useTransport useHost) `catch` \(e :: IOException) -> pure . Left $ PCEIOError e Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig - mkProtocolClient :: TransportHost -> UTCTime -> STM (PClient v err msg) + mkProtocolClient :: TransportHost -> UTCTime -> IO (PClient v err msg) mkProtocolClient transportHost ts = do - connected <- newTVar False - sendPings <- newTVar False - lastReceived <- newTVar ts - timeoutErrorCount <- newTVar 0 - clientCorrId <- C.newRandomDRG g - sentCommands <- TM.empty - sndQ <- newTBQueue qSize - rcvQ <- newTBQueue qSize + connected <- newTVarIO False + sendPings <- newTVarIO False + lastReceived <- newTVarIO ts + timeoutErrorCount <- newTVarIO 0 + clientCorrId <- atomically $ C.newRandomDRG g + sentCommands <- TM.emptyIO + sndQ <- newTBQueueIO qSize + rcvQ <- newTBQueueIO qSize return PClient { connected, @@ -555,7 +565,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize processMsg ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) | B.null $ bs corrId = sendMsg $ STEvent clientResp | otherwise = - atomically (TM.lookup corrId sentCommands) >>= \case + TM.lookupIO corrId sentCommands >>= \case Nothing -> sendMsg $ STUnexpectedError unexpected Just Request {entityId, command, pending, responseVar} -> do wasPending <- @@ -823,7 +833,7 @@ connectSMPProxiedRelay c@ProtocolClient {client_ = PClient {tcpConnectTimeout, t PKEY sId vr (chain, key) -> case supportedClientSMPRelayVRange `compatibleVersion` vr of Nothing -> throwE $ transportErr TEVersion - Just (Compatible v) -> liftEitherWith (const $ transportErr $ TEHandshake IDENTITY) $ ProxiedRelay sId v <$> validateRelay chain key + Just (Compatible v) -> liftEitherWith (const $ transportErr $ TEHandshake IDENTITY) $ ProxiedRelay sId v proxyAuth <$> validateRelay chain key r -> throwE $ unexpectedResponse r | otherwise = throwE $ PCETransportError TEVersion where @@ -842,16 +852,17 @@ connectSMPProxiedRelay c@ProtocolClient {client_ = PClient {tcpConnectTimeout, t data ProxiedRelay = ProxiedRelay { prSessionId :: SessionId, prVersion :: VersionSMP, + prBasicAuth :: Maybe BasicAuth, -- auth is included here to allow reconnecting via the same proxy after NO_SESSION error prServerKey :: C.PublicKeyX25519 } data ProxyClientError = -- | protocol error response from proxy - ProxyProtocolError ErrorType + ProxyProtocolError {protocolErr :: ErrorType} | -- | unexpexted response - ProxyUnexpectedResponse String + ProxyUnexpectedResponse {responseStr :: String} | -- | error between proxy and server - ProxyResponseError ErrorType + ProxyResponseError {responseErr :: ErrorType} deriving (Eq, Show, Exception) instance StrEncoding ProxyClientError where @@ -902,7 +913,7 @@ proxySMPCommand :: SenderId -> Command 'Sender -> ExceptT SMPClientError IO (Either ProxyClientError ()) -proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {clientCorrId = g, tcpTimeout}} (ProxiedRelay sessionId v serverKey) spKey sId command = do +proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {clientCorrId = g, tcpTimeout}} (ProxiedRelay sessionId v _ serverKey) spKey sId command = do -- prepare params let serverThAuth = (\ta -> ta {serverPeerPubKey = serverKey}) <$> thAuth proxyThParams serverThParams = smpTHParamsSetVersion v proxyThParams {sessionId, thAuth = serverThAuth} @@ -1078,13 +1089,13 @@ mkTransmission_ ProtocolClient {thParams, client_ = PClient {clientCorrId, sentC nonce@(C.CbNonce corrId) <- maybe (atomically $ C.randomCbNonce clientCorrId) pure nonce_ let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, entityId, command) auth = authTransmission (thAuth thParams) pKey_ nonce tForAuth - r <- atomically $ mkRequest (CorrId corrId) + r <- mkRequest (CorrId corrId) pure ((,tToSend) <$> auth, r) where - mkRequest :: CorrId -> STM (Request err msg) + mkRequest :: CorrId -> IO (Request err msg) mkRequest corrId = do - pending <- newTVar True - responseVar <- newEmptyTMVar + pending <- newTVarIO True + responseVar <- newEmptyTMVarIO let r = Request { corrId, @@ -1093,7 +1104,7 @@ mkTransmission_ ProtocolClient {thParams, client_ = PClient {clientCorrId, sentC pending, responseVar } - TM.insert corrId r sentCommands + atomically $ TM.insert corrId r sentCommands pure r authTransmission :: Maybe (THandleAuth 'TClient) -> Maybe C.APrivateAuthKey -> C.CbNonce -> ByteString -> Either TransportError (Maybe TransmissionAuth) @@ -1139,6 +1150,6 @@ $(J.deriveJSON (enumJSON $ dropPrefix "SPF") ''SMPProxyFallback) $(J.deriveJSON defaultJSON ''NetworkConfig) -$(J.deriveJSON (enumJSON $ dropPrefix "Proxy") ''ProxyClientError) +$(J.deriveJSON (sumTypeJSON $ dropPrefix "Proxy") ''ProxyClientError) $(J.deriveJSON defaultJSON ''TBQueueInfo) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index e7c22eec2..8073f1d48 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -20,17 +20,15 @@ import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Trans.Except import Crypto.Random (ChaChaDRG) -import Data.Bifunctor (bimap, first) +import Data.Bifunctor (first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Either (partitionEithers) -import Data.List (partition) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (listToMaybe) import Data.Set (Set) +import qualified Data.Set as S import Data.Text.Encoding import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) import Data.Tuple (swap) @@ -55,8 +53,8 @@ type SMPClientVar = SessionVar (Either (SMPClientError, Maybe UTCTime) (OwnServe data SMPClientAgentEvent = CAConnected SMPServer | CADisconnected SMPServer (Set SMPSub) - | CAResubscribed SMPServer (NonEmpty SMPSub) - | CASubError SMPServer (NonEmpty (SMPSub, SMPClientError)) + | CASubscribed SMPServer SMPSubParty (NonEmpty QueueId) + | CASubError SMPServer SMPSubParty (NonEmpty (QueueId, SMPClientError)) data SMPSubParty = SPRecipient | SPNotifier deriving (Eq, Ord, Show) @@ -86,9 +84,9 @@ defaultSMPClientAgentConfig = maxInterval = 10 * second }, persistErrorInterval = 30, -- seconds - msgQSize = 256, - agentQSize = 256, - agentSubsBatchSize = 900, + msgQSize = 1024, + agentQSize = 1024, + agentSubsBatchSize = 1360, ownServerDomains = [] } where @@ -102,7 +100,7 @@ data SMPClientAgent = SMPClientAgent randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, smpSessions :: TMap SessionId (OwnServer, SMPClient), - srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), + srvSubs :: TMap SMPServer (TMap SMPSub (SessionId, C.APrivateAuthKey)), pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), smpSubWorkers :: TMap SMPServer (SessionVar (Async ())), workerSeq :: TVar Int @@ -110,17 +108,17 @@ data SMPClientAgent = SMPClientAgent type OwnServer = Bool -newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM SMPClientAgent +newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO SMPClientAgent newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do - active <- newTVar True - msgQ <- newTBQueue msgQSize - agentQ <- newTBQueue agentQSize - smpClients <- TM.empty - smpSessions <- TM.empty - srvSubs <- TM.empty - pendingSrvSubs <- TM.empty - smpSubWorkers <- TM.empty - workerSeq <- newTVar 0 + active <- newTVarIO True + msgQ <- newTBQueueIO msgQSize + agentQ <- newTBQueueIO agentQSize + smpClients <- TM.emptyIO + smpSessions <- TM.emptyIO + srvSubs <- TM.emptyIO + pendingSrvSubs <- TM.emptyIO + smpSubWorkers <- TM.emptyIO + workerSeq <- newTVarIO 0 pure SMPClientAgent { agentCfg, @@ -192,7 +190,7 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke isOwnServer :: SMPClientAgent -> SMPServer -> OwnServer isOwnServer SMPClientAgent {agentCfg} ProtocolServer {host} = let srv = strEncode $ L.head host - in any (\s -> s == srv || (B.cons '.' s) `B.isSuffixOf` srv) (ownServerDomains agentCfg) + in any (\s -> s == srv || B.cons '.' s `B.isSuffixOf` srv) (ownServerDomains agentCfg) -- | Run an SMP client for SMPClientVar connectClient :: SMPClientAgent -> SMPServer -> SMPClientVar -> IO (Either SMPClientError SMPClient) @@ -206,20 +204,17 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random removeClientAndSubs :: SMPClient -> IO (Maybe (Map SMPSub C.APrivateAuthKey)) removeClientAndSubs smp = atomically $ do + TM.delete sessId smpSessions removeSessVar v srv smpClients - TM.delete (sessionId $ thParams smp) smpSessions - TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs + TM.lookup srv (srvSubs ca) >>= mapM updateSubs where + sessId = sessionId $ thParams smp updateSubs sVar = do - ss <- readTVar sVar - addPendingSubs sVar ss - pure ss - - addPendingSubs sVar ss = do - let ps = pendingSrvSubs ca - TM.lookup srv ps >>= \case - Just ss' -> TM.union ss ss' - _ -> TM.insert srv sVar ps + -- removing subscriptions that have matching sessionId to disconnected client + -- and keep the other ones (they can be made by the new client) + pending <- M.map snd <$> stateTVar sVar (M.partition ((sessId ==) . fst)) + addSubs_ (pendingSrvSubs ca) srv pending + pure pending serverDown :: Map SMPSub C.APrivateAuthKey -> IO () serverDown ss = unless (M.null ss) $ do @@ -234,7 +229,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s where getWorkerVar ts = ifM - (null <$> getPending) + (noPending) (pure Nothing) -- prevent race with cleanup and adding pending queues in another call (Just <$> getSessVar workerSeq srv smpSubWorkers ts) newSubWorker :: SessionVar (Async ()) -> IO () @@ -243,12 +238,13 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s atomically $ putTMVar (sessionVar v) a runSubWorker = withRetryInterval (reconnectInterval agentCfg) $ \_ loop -> do - pending <- atomically getPending - forM_ pending $ \cs -> whenM (readTVarIO active) $ do - void $ tcpConnectTimeout `timeout` runExceptT (reconnectSMPClient ca srv cs) + pending <- liftIO getPending + unless (null pending) $ whenM (readTVarIO active) $ do + void $ tcpConnectTimeout `timeout` runExceptT (reconnectSMPClient ca srv pending) loop ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg - getPending = mapM readTVar =<< TM.lookup srv (pendingSrvSubs ca) + noPending = maybe (pure True) (fmap M.null . readTVar) =<< TM.lookup srv (pendingSrvSubs ca) + getPending = maybe (pure M.empty) readTVarIO =<< TM.lookupIO srv (pendingSrvSubs ca) cleanup :: SessionVar (Async ()) -> STM () cleanup v = do -- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar. @@ -258,32 +254,22 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s reconnectSMPClient :: SMPClientAgent -> SMPServer -> Map SMPSub C.APrivateAuthKey -> ExceptT SMPClientError IO () reconnectSMPClient ca@SMPClientAgent {agentCfg} srv cs = - withSMP ca srv $ \smp -> do - subs' <- filterM (fmap not . atomically . hasSub (srvSubs ca) srv . fst) $ M.assocs cs - let (nSubs, rSubs) = partition (isNotifier . fst . fst) subs' + withSMP ca srv $ \smp -> liftIO $ do + currSubs <- maybe (pure M.empty) readTVarIO =<< TM.lookupIO srv (srvSubs ca) + let (nSubs, rSubs) = foldr (groupSub currSubs) ([], []) $ M.assocs cs subscribe_ smp SPNotifier nSubs subscribe_ smp SPRecipient rSubs where - isNotifier = \case - SPNotifier -> True - SPRecipient -> False - subscribe_ :: SMPClient -> SMPSubParty -> [(SMPSub, C.APrivateAuthKey)] -> ExceptT SMPClientError IO () - subscribe_ smp party = mapM_ subscribeBatch . toChunks (agentSubsBatchSize agentCfg) + groupSub :: Map SMPSub (SessionId, C.APrivateAuthKey) -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) + groupSub currSubs (s@(party, qId), k) acc@(nSubs, rSubs) + | M.member s currSubs = acc + | otherwise = case party of + SPNotifier -> (s' : nSubs, rSubs) + SPRecipient -> (nSubs, s' : rSubs) where - subscribeBatch subs' = do - let subs'' :: (NonEmpty (QueueId, C.APrivateAuthKey)) = L.map (first snd) subs' - rs <- liftIO $ smpSubscribeQueues party ca smp srv subs'' - let rs' :: (NonEmpty ((SMPSub, C.APrivateAuthKey), Either SMPClientError ())) = - L.zipWith (first . const) subs' rs - rs'' :: [Either (SMPSub, SMPClientError) (SMPSub, C.APrivateAuthKey)] = - map (\(sub, r) -> bimap (fst sub,) (const sub) r) $ L.toList rs' - (errs, oks) = partitionEithers rs'' - (tempErrs, finalErrs) = partition (temporaryClientError . snd) errs - mapM_ (atomically . addSubscription ca srv) oks - mapM_ (notify ca . CAResubscribed srv) $ L.nonEmpty $ map fst oks - mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs - mapM_ (notify ca . CASubError srv) $ L.nonEmpty finalErrs - mapM_ (throwE . snd) $ listToMaybe tempErrs + s' = (qId, k) + subscribe_ :: SMPClient -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> IO () + subscribe_ smp party = mapM_ (smpSubscribeQueues party ca smp srv) . toChunks (agentSubsBatchSize agentCfg) notify :: MonadIO m => SMPClientAgent -> SMPClientAgentEvent -> m () notify ca evt = atomically $ writeTBQueue (agentQ ca) evt @@ -297,14 +283,15 @@ getConnectedSMPServerClient SMPClientAgent {smpClients} srv = $>>= \case (_, Right r) -> pure $ Just $ Right r (v, Left (e, ts_)) -> - pure ts_ $>>= \ts -> -- proxy will create a new connection if ts_ is Nothing + pure ts_ $>>= \ts -> + -- proxy will create a new connection if ts_ is Nothing ifM ((ts <) <$> liftIO getCurrentTime) -- error persistence interval period expired? (Nothing <$ atomically (removeSessVar v srv smpClients)) -- proxy will create a new connection (pure $ Just $ Left e) -- not expired, returning error -lookupSMPServerClient :: SMPClientAgent -> SessionId -> STM (Maybe (OwnServer, SMPClient)) -lookupSMPServerClient SMPClientAgent {smpSessions} sessId = TM.lookup sessId smpSessions +lookupSMPServerClient :: SMPClientAgent -> SessionId -> IO (Maybe (OwnServer, SMPClient)) +lookupSMPServerClient SMPClientAgent {smpSessions} sessId = TM.lookupIO sessId smpSessions closeSMPClientAgent :: SMPClientAgent -> IO () closeSMPClientAgent c = do @@ -334,86 +321,100 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e throwE e -subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO () -subscribeQueue ca srv sub = do - atomically $ addPendingSubscription ca srv sub - withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleErr - where - subscribe_ smp = do - smpSubscribe smp sub - atomically $ addSubscription ca srv sub - - handleErr e = do - atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $ - removePendingSubscription ca srv (fst sub) - throwE e - -subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO (NonEmpty (RecipientId, Either SMPClientError ())) +subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO () subscribeQueuesSMP = subscribeQueues_ SPRecipient -subscribeQueuesNtfs :: SMPClientAgent -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO (NonEmpty (NotifierId, Either SMPClientError ())) +subscribeQueuesNtfs :: SMPClientAgent -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO () subscribeQueuesNtfs = subscribeQueues_ SPNotifier -subscribeQueues_ :: SMPSubParty -> SMPClientAgent -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ())) +subscribeQueues_ :: SMPSubParty -> SMPClientAgent -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO () subscribeQueues_ party ca srv subs = do - atomically $ forM_ subs $ addPendingSubscription ca srv . first (party,) + atomically $ addPendingSubs ca srv party $ L.toList subs runExceptT (getSMPServerClient' ca srv) >>= \case - Left e -> pure $ L.map ((,Left e) . fst) subs Right smp -> smpSubscribeQueues party ca smp srv subs + Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that -smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ())) +smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO () smpSubscribeQueues party ca smp srv subs = do - rs <- L.zip subs <$> subscribe smp (L.map swap subs) - atomically $ forM rs $ \(sub, r) -> - (fst sub,) <$> case r of - Right () -> do - addSubscription ca srv $ first (party,) sub - pure $ Right () - Left e -> do - when (e /= PCENetworkError && e /= PCEResponseTimeout) $ - removePendingSubscription ca srv (party, fst sub) - pure $ Left e + rs <- subscribe smp $ L.map swap subs + rs' <- + atomically $ + ifM + (activeClientSession ca smp srv) + (Just <$> processSubscriptions rs) + (pure Nothing) + case rs' of + Just (tempErrs, finalErrs, oks, _) -> do + notify_ CASubscribed $ map fst oks + notify_ CASubError finalErrs + when tempErrs $ reconnectClient ca srv + Nothing -> reconnectClient ca srv where + processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) + processSubscriptions rs = do + pending <- maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingSrvSubs ca) + let acc@(_, _, oks, notPending) = foldr (groupSub pending) (False, [], [], []) (L.zip subs rs) + unless (null oks) $ addSubscriptions ca srv party oks + unless (null notPending) $ removePendingSubs ca srv party notPending + pure acc + sessId = sessionId $ thParams smp + groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) + groupSub pending ((qId, pk), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of + Right () + | M.member (party, qId) pending -> (tempErrs, finalErrs, (qId, (sessId, pk)) : oks, qId : notPending) + | otherwise -> acc + Left e + | temporaryClientError e -> (True, finalErrs, oks, notPending) + | otherwise -> (tempErrs, (qId, e) : finalErrs, oks, qId : notPending) subscribe = case party of SPRecipient -> subscribeSMPQueues SPNotifier -> subscribeSMPQueuesNtfs + notify_ :: (SMPServer -> SMPSubParty -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO () + notify_ evt qs = mapM_ (notify ca . evt srv party) $ L.nonEmpty qs + +activeClientSession :: SMPClientAgent -> SMPClient -> SMPServer -> STM Bool +activeClientSession ca smp srv = sameSess <$> tryReadSessVar srv (smpClients ca) + where + sessId = sessionId . thParams + sameSess = \case + Just (Right (_, smp')) -> sessId smp == sessId smp' + _ -> False showServer :: SMPServer -> ByteString showServer ProtocolServer {host, port} = strEncode host <> B.pack (if null port then "" else ':' : port) -smpSubscribe :: SMPClient -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO () -smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId +addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, (SessionId, C.APrivateAuthKey))] -> STM () +addSubscriptions = addSubsList_ . srvSubs +{-# INLINE addSubscriptions #-} + +addPendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM () +addPendingSubs = addSubsList_ . pendingSrvSubs +{-# INLINE addPendingSubs #-} + +addSubsList_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSubParty -> [(QueueId, s)] -> STM () +addSubsList_ subs srv party ss = addSubs_ subs srv ss' where - subscribe_ = case party of - SPRecipient -> subscribeSMPQueue - SPNotifier -> subscribeSMPQueueNotifications + ss' = M.fromList $ map (first (party,)) ss -addSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM () -addSubscription ca srv sub = do - addSub_ (srvSubs ca) srv sub - removePendingSubscription ca srv $ fst sub - -addPendingSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM () -addPendingSubscription = addSub_ . pendingSrvSubs - -addSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM () -addSub_ subs srv (s, key) = +addSubs_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> Map SMPSub s -> STM () +addSubs_ subs srv ss = TM.lookup srv subs >>= \case - Just m -> TM.insert s key m - _ -> TM.singleton s key >>= \v -> TM.insert srv v subs + Just m -> TM.union ss m + _ -> newTVar ss >>= \v -> TM.insert srv v subs removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM () removeSubscription = removeSub_ . srvSubs +{-# INLINE removeSubscription #-} -removePendingSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM () -removePendingSubscription = removeSub_ . pendingSrvSubs - -removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM () +removeSub_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSub -> STM () removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s) -getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateAuthKey) -getSubKey subs srv s = TM.lookup srv subs $>>= TM.lookup s +removePendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [QueueId] -> STM () +removePendingSubs = removeSubs_ . pendingSrvSubs +{-# INLINE removePendingSubs #-} -hasSub :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM Bool -hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs +removeSubs_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSubParty -> [QueueId] -> STM () +removeSubs_ subs srv party qs = TM.lookup srv subs >>= mapM_ (`modifyTVar'` (`M.withoutKeys` ss)) + where + ss = S.fromList $ map (party,) qs diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 5d3b4d806..1192148ac 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -176,10 +176,10 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge getSMPSubscriber :: SMPServer -> M SMPSubscriber getSMPSubscriber smpServer = - atomically (TM.lookup smpServer smpSubscribers) >>= maybe createSMPSubscriber pure + liftIO (TM.lookupIO smpServer smpSubscribers) >>= maybe createSMPSubscriber pure where createSMPSubscriber = do - sub@SMPSubscriber {subThreadId} <- atomically newSMPSubscriber + sub@SMPSubscriber {subThreadId} <- liftIO newSMPSubscriber atomically $ TM.insert smpServer sub smpSubscribers tId <- mkWeakThreadId =<< forkIO (runSMPSubscriber sub) atomically . writeTVar subThreadId $ Just tId @@ -188,33 +188,16 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge runSMPSubscriber :: SMPSubscriber -> M () runSMPSubscriber SMPSubscriber {newSubQ = subscriberSubQ} = forever $ do - subs <- atomically (peekTQueue subscriberSubQ) + subs <- atomically $ readTQueue subscriberSubQ let subs' = L.map (\(NtfSub sub) -> sub) subs srv = server $ L.head subs logSubStatus srv "subscribing" $ length subs mapM_ (\NtfSubData {smpQueue} -> updateSubStatus smpQueue NSPending) subs' - rs <- liftIO $ subscribeQueues srv subs' - (subs'', oks, errs) <- foldM process ([], 0, []) rs - atomically $ do - void $ readTQueue subscriberSubQ - mapM_ (writeTQueue subscriberSubQ . L.map NtfSub) $ L.nonEmpty subs'' - logSubStatus srv "retrying" $ length subs'' - logSubStatus srv "subscribed" oks - logSubErrors srv errs - where - process :: ([NtfSubData], Int, [NtfSubStatus]) -> (NtfSubData, Either SMPClientError ()) -> M ([NtfSubData], Int, [NtfSubStatus]) - process (subs, oks, errs) (sub@NtfSubData {smpQueue}, r) = case r of - Right _ -> updateSubStatus smpQueue NSActive $> (subs, oks + 1, errs) - Left e -> update <$> handleSubError smpQueue e - where - update = \case - Just err -> (subs, oks, err : errs) -- permanent error, log and don't retry subscription - Nothing -> (sub : subs, oks, errs) -- temporary error, retry subscription + liftIO $ subscribeQueues srv subs' -- \| Subscribe to queues. The list of results can have a different order. - subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO (NonEmpty (NtfSubData, Either SMPClientError ())) - subscribeQueues srv subs = - L.zipWith (\s r -> (s, snd r)) subs <$> subscribeQueuesNtfs ca srv (L.map sub subs) + subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO () + subscribeQueues srv subs = subscribeQueuesNtfs ca srv (L.map sub subs) where sub NtfSubData {smpQueue = SMPQueueNtf {notifierId}, notifierKey} = (notifierId, notifierKey) @@ -239,7 +222,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge incNtfStat ntfReceived Right SMP.END -> updateSubStatus smpQueue NSEnd Right (SMP.ERR e) -> logError $ "SMP server error: " <> tshow e - Right _ -> logError $ "SMP server unexpected response" + Right _ -> logError "SMP server unexpected response" Left e -> logError $ "SMP client error: " <> tshow e receiveAgent = @@ -252,11 +235,11 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge forM_ subs $ \(_, ntfId) -> do let smpQueue = SMPQueueNtf srv ntfId updateSubStatus smpQueue NSInactive - CAResubscribed srv subs -> do - forM_ subs $ \(_, ntfId) -> updateSubStatus (SMPQueueNtf srv ntfId) NSActive - logSubStatus srv "resubscribed" $ length subs - CASubError srv errs -> - forM errs (\((_, ntfId), err) -> handleSubError (SMPQueueNtf srv ntfId) err) + CASubscribed srv _ subs -> do + forM_ subs $ \ntfId -> updateSubStatus (SMPQueueNtf srv ntfId) NSActive + logSubStatus srv "subscribed" $ length subs + CASubError srv _ errs -> + forM errs (\(ntfId, err) -> handleSubError (SMPQueueNtf srv ntfId) err) >>= logSubErrors srv . catMaybes . L.toList logSubStatus srv event n = @@ -350,7 +333,7 @@ runNtfClientTransport :: Transport c => THandleNTF c 'TServer -> M () runNtfClientTransport th@THandle {params} = do qSize <- asks $ clientQSize . config ts <- liftIO getSystemTime - c <- atomically $ newNtfServerClient qSize params ts + c <- liftIO $ newNtfServerClient qSize params ts s <- asks subscriber ps <- asks pushServer expCfg <- asks $ inactiveClientExpiration . config @@ -524,7 +507,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu | otherwise -> do logDebug "TCRN" atomically $ writeTVar tknCronInterval int - atomically (TM.lookup tknId intervalNotifiers) >>= \case + liftIO (TM.lookupIO tknId intervalNotifiers) >>= \case Nothing -> runIntervalNotifier int Just IntervalNotifier {interval, action} -> unless (interval == int) $ do @@ -602,7 +585,7 @@ incNtfStat statSel = do saveServerStats :: M () saveServerStats = asks (serverStatsBackupFile . config) - >>= mapM_ (\f -> asks serverStats >>= atomically . getNtfServerStatsData >>= liftIO . saveStats f) + >>= mapM_ (\f -> asks serverStats >>= liftIO . getNtfServerStatsData >>= liftIO . saveStats f) where saveStats f stats = do logInfo $ "saving server stats to file " <> T.pack f diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 5ebd5230e..dc0cb0a73 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -10,7 +10,6 @@ module Simplex.Messaging.Notifications.Server.Env where import Control.Concurrent (ThreadId) import Control.Concurrent.Async (Async) import Control.Logger.Simple -import Control.Monad.IO.Unlift import Crypto.Random import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) @@ -85,16 +84,16 @@ data NtfEnv = NtfEnv newNtfServerEnv :: NtfServerConfig -> IO NtfEnv newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, storeLogFile, caCertificateFile, certificateFile, privateKeyFile, transportConfig} = do - random <- liftIO C.newRandom - store <- atomically newNtfStore + random <- C.newRandom + store <- newNtfStore logInfo "restoring subscriptions..." - storeLog <- liftIO $ mapM (`readWriteNtfStore` store) storeLogFile + storeLog <- mapM (`readWriteNtfStore` store) storeLogFile logInfo "restored subscriptions" - subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg random - pushServer <- atomically $ newNtfPushServer pushQSize apnsConfig - tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) - Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile - serverStats <- atomically . newNtfServerStats =<< liftIO getCurrentTime + subscriber <- newNtfSubscriber subQSize smpAgentCfg random + pushServer <- newNtfPushServer pushQSize apnsConfig + tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) + Fingerprint fp <- loadFingerprint caCertificateFile + serverStats <- newNtfServerStats =<< getCurrentTime pure NtfEnv {config, subscriber, pushServer, store, storeLog, random, tlsServerParams, serverIdentity = C.KeyHash fp, serverStats} data NtfSubscriber = NtfSubscriber @@ -103,10 +102,10 @@ data NtfSubscriber = NtfSubscriber smpAgent :: SMPClientAgent } -newNtfSubscriber :: Natural -> SMPClientAgentConfig -> TVar ChaChaDRG -> STM NtfSubscriber +newNtfSubscriber :: Natural -> SMPClientAgentConfig -> TVar ChaChaDRG -> IO NtfSubscriber newNtfSubscriber qSize smpAgentCfg random = do - smpSubscribers <- TM.empty - newSubQ <- newTBQueue qSize + smpSubscribers <- TM.emptyIO + newSubQ <- newTBQueueIO qSize smpAgent <- newSMPClientAgent smpAgentCfg random pure NtfSubscriber {smpSubscribers, newSubQ, smpAgent} @@ -115,10 +114,10 @@ data SMPSubscriber = SMPSubscriber subThreadId :: TVar (Maybe (Weak ThreadId)) } -newSMPSubscriber :: STM SMPSubscriber +newSMPSubscriber :: IO SMPSubscriber newSMPSubscriber = do - newSubQ <- newTQueue - subThreadId <- newTVar Nothing + newSubQ <- newTQueueIO + subThreadId <- newTVarIO Nothing pure SMPSubscriber {newSubQ, subThreadId} data NtfPushServer = NtfPushServer @@ -134,11 +133,11 @@ data IntervalNotifier = IntervalNotifier interval :: Word16 } -newNtfPushServer :: Natural -> APNSPushClientConfig -> STM NtfPushServer +newNtfPushServer :: Natural -> APNSPushClientConfig -> IO NtfPushServer newNtfPushServer qSize apnsConfig = do - pushQ <- newTBQueue qSize - pushClients <- TM.empty - intervalNotifiers <- TM.empty + pushQ <- newTBQueueIO qSize + pushClients <- TM.emptyIO + intervalNotifiers <- TM.emptyIO pure NtfPushServer {pushQ, pushClients, intervalNotifiers, apnsConfig} newPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient @@ -151,7 +150,7 @@ newPushClient NtfPushServer {apnsConfig, pushClients} pp = do getPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient getPushClient s@NtfPushServer {pushClients} pp = - atomically (TM.lookup pp pushClients) >>= maybe (newPushClient s pp) pure + TM.lookupIO pp pushClients >>= maybe (newPushClient s pp) pure data NtfRequest = NtfReqNew CorrId ANewNtfEntity @@ -167,11 +166,11 @@ data NtfServerClient = NtfServerClient sndActiveAt :: TVar SystemTime } -newNtfServerClient :: Natural -> THandleParams NTFVersion 'TServer -> SystemTime -> STM NtfServerClient +newNtfServerClient :: Natural -> THandleParams NTFVersion 'TServer -> SystemTime -> IO NtfServerClient newNtfServerClient qSize ntfThParams ts = do - rcvQ <- newTBQueue qSize - sndQ <- newTBQueue qSize - connected <- newTVar True - rcvActiveAt <- newTVar ts - sndActiveAt <- newTVar ts + rcvQ <- newTBQueueIO qSize + sndQ <- newTBQueueIO qSize + connected <- newTVarIO True + rcvActiveAt <- newTVarIO ts + sndActiveAt <- newTVarIO ts return NtfServerClient {rcvQ, sndQ, ntfThParams, connected, rcvActiveAt, sndActiveAt} diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index 351fe6d72..dadeb82fc 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -7,6 +7,7 @@ module Simplex.Messaging.Notifications.Server.Main where +import Control.Monad ((<$!>)) import Data.Functor (($>)) import Data.Ini (lookupValue, readIniFile) import Data.Maybe (fromMaybe) @@ -14,6 +15,7 @@ import qualified Data.Text as T import qualified Data.Text.IO as T import Network.Socket (HostName) import Options.Applicative +import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), SocksMode (..), defaultNetworkConfig) import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Server (runNtfServer) @@ -87,6 +89,14 @@ ntfServerCLI cfgPath logPath = <> ("port: " <> T.pack defaultServerPort <> "\n") <> "log_tls_errors: off\n" <> "websockets: off\n\n\ + \[SUBSCRIBER]\n\ + \# Network configuration for notification server client.\n\ + \# SOCKS proxy port for subscribing to SMP servers.\n\ + \# You may need a separate instance of SOCKS proxy for incoming single-hop requests.\n\ + \# socks_proxy: localhost:9050\n\n\ + \# `socks_mode` can be 'onion' for SOCKS proxy to be used for .onion destination hosts only (default)\n\ + \# or 'always' to be used for all destination hosts (can be used if it is an .onion server).\n\ + \# socks_mode: onion\n\n\ \[INACTIVE_CLIENTS]\n\ \# TTL and interval to check inactive clients\n\ \disconnect: off\n" @@ -115,7 +125,18 @@ ntfServerCLI cfgPath logPath = clientQSize = 64, subQSize = 512, pushQSize = 1048, - smpAgentCfg = defaultSMPClientAgentConfig {persistErrorInterval = 0}, + smpAgentCfg = + defaultSMPClientAgentConfig + { smpCfg = + (smpCfg defaultSMPClientAgentConfig) + { networkConfig = + defaultNetworkConfig + { socksProxy = either error id <$!> strDecodeIni "SUBSCRIBER" "socks_proxy" ini, + socksMode = maybe SMOnion (either error id) $! strDecodeIni "SUBSCRIBER" "socks_mode" ini + } + }, + persistErrorInterval = 0 -- seconds + }, apnsConfig = defaultAPNSPushClientConfig, subsBatchSize = 900, inactiveClientExpiration = diff --git a/src/Simplex/Messaging/Notifications/Server/Stats.hs b/src/Simplex/Messaging/Notifications/Server/Stats.hs index 7debc1ac9..b73e6098f 100644 --- a/src/Simplex/Messaging/Notifications/Server/Stats.hs +++ b/src/Simplex/Messaging/Notifications/Server/Stats.hs @@ -40,30 +40,30 @@ data NtfServerStatsData = NtfServerStatsData _activeSubs :: PeriodStatsData NotifierId } -newNtfServerStats :: UTCTime -> STM NtfServerStats +newNtfServerStats :: UTCTime -> IO NtfServerStats newNtfServerStats ts = do - fromTime <- newTVar ts - tknCreated <- newTVar 0 - tknVerified <- newTVar 0 - tknDeleted <- newTVar 0 - subCreated <- newTVar 0 - subDeleted <- newTVar 0 - ntfReceived <- newTVar 0 - ntfDelivered <- newTVar 0 + fromTime <- newTVarIO ts + tknCreated <- newTVarIO 0 + tknVerified <- newTVarIO 0 + tknDeleted <- newTVarIO 0 + subCreated <- newTVarIO 0 + subDeleted <- newTVarIO 0 + ntfReceived <- newTVarIO 0 + ntfDelivered <- newTVarIO 0 activeTokens <- newPeriodStats activeSubs <- newPeriodStats pure NtfServerStats {fromTime, tknCreated, tknVerified, tknDeleted, subCreated, subDeleted, ntfReceived, ntfDelivered, activeTokens, activeSubs} -getNtfServerStatsData :: NtfServerStats -> STM NtfServerStatsData +getNtfServerStatsData :: NtfServerStats -> IO NtfServerStatsData getNtfServerStatsData s@NtfServerStats {fromTime} = do - _fromTime <- readTVar fromTime - _tknCreated <- readTVar $ tknCreated s - _tknVerified <- readTVar $ tknVerified s - _tknDeleted <- readTVar $ tknDeleted s - _subCreated <- readTVar $ subCreated s - _subDeleted <- readTVar $ subDeleted s - _ntfReceived <- readTVar $ ntfReceived s - _ntfDelivered <- readTVar $ ntfDelivered s + _fromTime <- readTVarIO fromTime + _tknCreated <- readTVarIO $ tknCreated s + _tknVerified <- readTVarIO $ tknVerified s + _tknDeleted <- readTVarIO $ tknDeleted s + _subCreated <- readTVarIO $ subCreated s + _subDeleted <- readTVarIO $ subDeleted s + _ntfReceived <- readTVarIO $ ntfReceived s + _ntfDelivered <- readTVarIO $ ntfDelivered s _activeTokens <- getPeriodStatsData $ activeTokens s _activeSubs <- getPeriodStatsData $ activeSubs s pure NtfServerStatsData {_fromTime, _tknCreated, _tknVerified, _tknDeleted, _subCreated, _subDeleted, _ntfReceived, _ntfDelivered, _activeTokens, _activeSubs} diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 83dc1a4c2..b4d91dc88 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -33,13 +33,13 @@ data NtfStore = NtfStore subscriptionLookup :: TMap SMPQueueNtf NtfSubscriptionId } -newNtfStore :: STM NtfStore +newNtfStore :: IO NtfStore newNtfStore = do - tokens <- TM.empty - tokenRegistrations <- TM.empty - subscriptions <- TM.empty - tokenSubscriptions <- TM.empty - subscriptionLookup <- TM.empty + tokens <- TM.emptyIO + tokenRegistrations <- TM.emptyIO + subscriptions <- TM.emptyIO + tokenSubscriptions <- TM.emptyIO + subscriptionLookup <- TM.emptyIO pure NtfStore {tokens, tokenRegistrations, subscriptions, tokenSubscriptions, subscriptionLookup} data NtfTknData = NtfTknData @@ -77,6 +77,9 @@ data NtfEntityRec (e :: NtfEntity) where getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData) getNtfToken st tknId = TM.lookup tknId (tokens st) +getNtfTokenIO :: NtfStore -> NtfTokenId -> IO (Maybe NtfTknData) +getNtfTokenIO st tknId = TM.lookupIO tknId (tokens st) + addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM () addNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do TM.insert tknId tkn $ tokens st diff --git a/src/Simplex/Messaging/Notifications/Types.hs b/src/Simplex/Messaging/Notifications/Types.hs index 4465f8767..8fcedab53 100644 --- a/src/Simplex/Messaging/Notifications/Types.hs +++ b/src/Simplex/Messaging/Notifications/Types.hs @@ -11,7 +11,7 @@ import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time (UTCTime) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) -import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..)) +import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..), UserId) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Protocol @@ -48,6 +48,7 @@ data NtfToken = NtfToken ntfServer :: NtfServer, ntfTokenId :: Maybe NtfTokenId, -- TODO combine keys to key pair as the types should match + -- | key used by the ntf server to verify transmissions ntfPubKey :: C.APublicAuthKey, -- | key used by the ntf client to sign transmissions @@ -79,17 +80,17 @@ newNtfToken deviceToken ntfServer (ntfPubKey, ntfPrivKey) ntfDhKeys ntfMode = ntfMode } -data NtfSubAction = NtfSubNTFAction NtfSubNTFAction | NtfSubSMPAction NtfSubSMPAction +data NtfSubAction = NSANtf NtfSubNTFAction | NSASMP NtfSubSMPAction deriving (Show) isDeleteNtfSubAction :: NtfSubAction -> Bool isDeleteNtfSubAction = \case - NtfSubNTFAction a -> case a of + NSANtf a -> case a of NSACreate -> False NSACheck -> False NSADelete -> True NSARotate -> True - NtfSubSMPAction a -> case a of + NSASMP a -> case a of NSASmpKey -> False NSASmpDelete -> True @@ -177,7 +178,8 @@ instance FromField NtfAgentSubStatus where fromField = fromTextField_ $ either ( instance ToField NtfAgentSubStatus where toField = toField . decodeLatin1 . smpEncode data NtfSubscription = NtfSubscription - { connId :: ConnId, + { userId :: UserId, + connId :: ConnId, smpServer :: SMPServer, ntfQueueId :: Maybe NotifierId, ntfServer :: NtfServer, @@ -186,10 +188,11 @@ data NtfSubscription = NtfSubscription } deriving (Show) -newNtfSubscription :: ConnId -> SMPServer -> Maybe NotifierId -> NtfServer -> NtfAgentSubStatus -> NtfSubscription -newNtfSubscription connId smpServer ntfQueueId ntfServer ntfSubStatus = +newNtfSubscription :: UserId -> ConnId -> SMPServer -> Maybe NotifierId -> NtfServer -> NtfAgentSubStatus -> NtfSubscription +newNtfSubscription userId connId smpServer ntfQueueId ntfServer ntfSubStatus = NtfSubscription - { connId, + { userId, + connId, smpServer, ntfQueueId, ntfServer, diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index e96e8b582..c5d067475 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -6,7 +6,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedLists #-} @@ -38,6 +37,7 @@ module Simplex.Messaging.Server ) where +import Control.Concurrent.STM.TQueue (flushTQueue) import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -48,6 +48,7 @@ import Crypto.Random import Control.Monad.STM (retry) import Data.Bifunctor (first) import Data.ByteString.Base64 (encode) +import qualified Data.ByteString.Builder as BLD import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB @@ -55,11 +56,13 @@ import Data.Either (fromRight, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) import qualified Data.IntMap.Strict as IM +import qualified Data.IntSet as IS import Data.List (intercalate, mapAccumR) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing) +import qualified Data.Set as S import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime) @@ -69,6 +72,7 @@ import Data.Type.Equality import GHC.Stats (getRTSStats) import GHC.TypeLits (KnownNat) import Network.Socket (ServiceName, Socket, socketToHandle) +import Numeric.Natural (Natural) import Simplex.Messaging.Agent.Lock import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPTransmission, smpProxyError, temporaryClientError) import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient) @@ -158,28 +162,33 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do forall s. Server -> String -> - (Server -> TQueue (QueueId, Client)) -> + (Server -> TQueue (QueueId, Client, Subscribed)) -> (Server -> TMap QueueId Client) -> (Client -> TMap QueueId s) -> (s -> IO ()) -> M () serverThread s label subQ subs clientSubs unsub = do labelMyThread label + cls <- asks clients forever $ - atomically updateSubscribers + atomically (updateSubscribers cls) $>>= endPreviousSubscriptions >>= liftIO . mapM_ unsub where - updateSubscribers :: STM (Maybe (QueueId, Client)) - updateSubscribers = do - (qId, clnt) <- readTQueue $ subQ s - let clientToBeNotified c' = - if sameClientId clnt c' - then pure Nothing - else do + updateSubscribers :: TVar (IM.IntMap Client) -> STM (Maybe (QueueId, Client)) + updateSubscribers cls = do + (qId, clnt, subscribed) <- readTQueue $ subQ s + current <- IM.member (clientId clnt) <$> readTVar cls + let updateSub + | not subscribed = TM.lookupDelete + | not current = TM.lookup -- do not insert client if it is already disconnected, but send END to any other client + | otherwise = (`TM.lookupInsert` clnt) -- insert subscribed and current client + clientToBeNotified c' + | sameClientId clnt c' = pure Nothing + | otherwise = do yes <- readTVar $ connected c' pure $ if yes then Just (qId, c') else Nothing - TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified + updateSub qId (subs s) $>>= clientToBeNotified endPreviousSubscriptions :: (QueueId, Client) -> M (Maybe s) endPreviousSubscriptions (qId, c) = do forkClient c (label <> ".endPreviousSubscriptions") $ @@ -193,8 +202,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do CAConnected srv -> logInfo $ "SMP server connected " <> showServer' srv CADisconnected srv [] -> logInfo $ "SMP server disconnected " <> showServer' srv CADisconnected srv subs -> logError $ "SMP server disconnected " <> showServer' srv <> " / subscriptions: " <> tshow (length subs) - CAResubscribed srv subs -> logError $ "SMP server resubscribed " <> showServer' srv <> " / subscriptions: " <> tshow (length subs) - CASubError srv errs -> logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (length errs) + CASubscribed srv _ subs -> logError $ "SMP server subscribed " <> showServer' srv <> " / subscriptions: " <> tshow (length subs) + CASubError srv _ errs -> logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (length errs) where showServer' = decodeLatin1 . strEncode . host @@ -229,7 +238,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0) - ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedNew, qDeletedSecured, qSub, qSubAuth, qSubDuplicate, qSubProhibited, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgExpired, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv} <- asks serverStats + ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedNew, qDeletedSecured, qSub, qSubNoMsg, qSubAuth, qSubDuplicate, qSubProhibited, ntfCreated, ntfDeleted, ntfSub, ntfSubAuth, ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgRecvGet, msgGet, msgGetNoMsg, msgGetAuth, msgGetDuplicate, msgGetProhibited, msgExpired, activeQueues, subscribedQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv} + <- asks serverStats + QueueStore {queues, notifiers} <- asks queueStore let interval = 1000000 * logInterval forever $ do withFile statsFilePath AppendMode $ \h -> liftIO $ do @@ -242,16 +253,29 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do qDeletedNew' <- atomically $ swapTVar qDeletedNew 0 qDeletedSecured' <- atomically $ swapTVar qDeletedSecured 0 qSub' <- atomically $ swapTVar qSub 0 + qSubNoMsg' <- atomically $ swapTVar qSubNoMsg 0 qSubAuth' <- atomically $ swapTVar qSubAuth 0 qSubDuplicate' <- atomically $ swapTVar qSubDuplicate 0 qSubProhibited' <- atomically $ swapTVar qSubProhibited 0 + ntfCreated' <- atomically $ swapTVar ntfCreated 0 + ntfDeleted' <- atomically $ swapTVar ntfDeleted 0 + ntfSub' <- atomically $ swapTVar ntfSub 0 + ntfSubAuth' <- atomically $ swapTVar ntfSubAuth 0 + ntfSubDuplicate' <- atomically $ swapTVar ntfSubDuplicate 0 msgSent' <- atomically $ swapTVar msgSent 0 msgSentAuth' <- atomically $ swapTVar msgSentAuth 0 msgSentQuota' <- atomically $ swapTVar msgSentQuota 0 msgSentLarge' <- atomically $ swapTVar msgSentLarge 0 msgRecv' <- atomically $ swapTVar msgRecv 0 + msgRecvGet' <- atomically $ swapTVar msgRecvGet 0 + msgGet' <- atomically $ swapTVar msgGet 0 + msgGetNoMsg' <- atomically $ swapTVar msgGetNoMsg 0 + msgGetAuth' <- atomically $ swapTVar msgGetAuth 0 + msgGetDuplicate' <- atomically $ swapTVar msgGetDuplicate 0 + msgGetProhibited' <- atomically $ swapTVar msgGetProhibited 0 msgExpired' <- atomically $ swapTVar msgExpired 0 ps <- atomically $ periodStatCounts activeQueues ts + psSub <- atomically $ periodStatCounts subscribedQueues ts msgSentNtf' <- atomically $ swapTVar msgSentNtf 0 msgRecvNtf' <- atomically $ swapTVar msgRecvNtf 0 psNtf <- atomically $ periodStatCounts activeQueuesNtf ts @@ -264,6 +288,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do pMsgFwdsOwn' <- atomically $ getResetProxyStatsData pMsgFwdsOwn pMsgFwdsRecv' <- atomically $ swapTVar pMsgFwdsRecv 0 qCount' <- readTVarIO qCount + qCount'' <- M.size <$> readTVarIO queues + ntfCount' <- M.size <$> readTVarIO notifiers msgCount' <- readTVarIO msgCount hPutStrLn h $ intercalate @@ -302,7 +328,24 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do show msgSentLarge', show msgNtfs', show msgNtfNoSub', - show msgNtfLost' + show msgNtfLost', + show qSubNoMsg', + show msgRecvGet', + show msgGet', + show msgGetNoMsg', + show msgGetAuth', + show msgGetDuplicate', + show msgGetProhibited', + dayCount psSub, + weekCount psSub, + monthCount psSub, + show qCount'', + show ntfCreated', + show ntfDeleted', + show ntfSub', + show ntfSubAuth', + show ntfSubDuplicate', + show ntfCount' ] ) liftIO $ threadDelay' interval @@ -377,22 +420,35 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do let age = systemSeconds now - systemSeconds createdAt subscriptions' <- bshow . M.size <$> readTVarIO subscriptions hPutStrLn h . B.unpack $ B.intercalate "," [bshow cid, encode sessionId, connected', strEncode createdAt, rcvActiveAt', sndActiveAt', bshow age, subscriptions'] - CPStats -> withAdminRole $ do + CPStats -> withUserRole $ do ss <- unliftIO u $ asks serverStats - let putStat :: Show a => ByteString -> (ServerStats -> TVar a) -> IO () - putStat label var = readTVarIO (var ss) >>= \v -> B.hPutStr h $ label <> ": " <> bshow v <> "\n" - putProxyStat :: ByteString -> (ServerStats -> ProxyStats) -> IO () + let getStat :: (ServerStats -> TVar a) -> IO a + getStat var = readTVarIO (var ss) + putStat :: Show a => String -> (ServerStats -> TVar a) -> IO () + putStat label var = getStat var >>= \v -> hPutStrLn h $ label <> ": " <> show v + putProxyStat :: String -> (ServerStats -> ProxyStats) -> IO () putProxyStat label var = do - ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- atomically $ getProxyStatsData $ var ss - B.hPutStr h $ label <> ": requests=" <> bshow _pRequests <> ", successes=" <> bshow _pSuccesses <> ", errorsConnect=" <> bshow _pErrorsConnect <> ", errorsCompat=" <> bshow _pErrorsCompat <> ", errorsOther=" <> bshow _pErrorsOther <> "\n" + ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- getProxyStatsData $ var ss + hPutStrLn h $ label <> ": requests=" <> show _pRequests <> ", successes=" <> show _pSuccesses <> ", errorsConnect=" <> show _pErrorsConnect <> ", errorsCompat=" <> show _pErrorsCompat <> ", errorsOther=" <> show _pErrorsOther putStat "fromTime" fromTime putStat "qCreated" qCreated putStat "qSecured" qSecured putStat "qDeletedAll" qDeletedAll putStat "qDeletedNew" qDeletedNew putStat "qDeletedSecured" qDeletedSecured + getStat (day . activeQueues) >>= \v -> hPutStrLn h $ "daily active queues: " <> show (S.size v) + getStat (day . subscribedQueues) >>= \v -> hPutStrLn h $ "daily subscribed queues: " <> show (S.size v) + putStat "qSub" qSub + putStat "qSubNoMsg" qSubNoMsg + subs <- (,,) <$> getStat qSubAuth <*> getStat qSubDuplicate <*> getStat qSubProhibited + hPutStrLn h $ "other SUB events (auth, duplicate, prohibited): " <> show subs putStat "msgSent" msgSent putStat "msgRecv" msgRecv + putStat "msgRecvGet" msgRecvGet + putStat "msgGet" msgGet + putStat "msgGetNoMsg" msgGet + gets <- (,,) <$> getStat msgGetAuth <*> getStat msgGetDuplicate <*> getStat msgGetProhibited + hPutStrLn h $ "other GET events (auth, duplicate, prohibited): " <> show gets putStat "msgSentNtf" msgSentNtf putStat "msgRecvNtf" msgRecvNtf putStat "qCount" qCount @@ -414,9 +470,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do #else hPutStrLn h "Not available on GHC 8.10" #endif - CPSockets -> withAdminRole $ do + CPSockets -> withUserRole $ do (accepted', closed', active') <- unliftIO u $ asks sockets - (accepted, closed, active) <- atomically $ (,,) <$> readTVar accepted' <*> readTVar closed' <*> readTVar active' + (accepted, closed, active) <- (,,) <$> readTVarIO accepted' <*> readTVarIO closed' <*> readTVarIO active' hPutStrLn h "Sockets: " hPutStrLn h $ "accepted: " <> show accepted hPutStrLn h $ "closed: " <> show closed @@ -436,6 +492,92 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do #else hPutStrLn h "Not available on GHC 8.10" #endif + CPServerInfo -> readTVarIO role >>= \case + CPRNone -> do + logError "Unauthorized control port command" + hPutStrLn h "AUTH" + r -> do +#if MIN_VERSION_base(4,18,0) + threads <- liftIO listThreads + hPutStrLn h $ "Threads: " <> show (length threads) +#else + hPutStrLn h "Threads: not available on GHC 8.10" +#endif + Env {clients, server = Server {subscribers, notifiers}} <- unliftIO u ask + activeClients <- readTVarIO clients + hPutStrLn h $ "Clients: " <> show (IM.size activeClients) + when (r == CPRAdmin) $ do + clQs <- clientTBQueueLengths activeClients + hPutStrLn h $ "Client queues (rcvQ, sndQ, msgQ): " <> show clQs + (smpSubCnt, smpSubCntByGroup, smpClCnt, smpClQs) <- countClientSubs subscriptions (Just countSMPSubs) activeClients + hPutStrLn h $ "SMP subscriptions (via clients): " <> show smpSubCnt + hPutStrLn h $ "SMP subscriptions (by group: NoSub, SubPending, SubThread, ProhibitSub): " <> show smpSubCntByGroup + hPutStrLn h $ "SMP subscribed clients (via clients): " <> show smpClCnt + hPutStrLn h $ "SMP subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show smpClQs + (ntfSubCnt, _, ntfClCnt, ntfClQs) <- countClientSubs ntfSubscriptions Nothing activeClients + hPutStrLn h $ "Ntf subscriptions (via clients): " <> show ntfSubCnt + hPutStrLn h $ "Ntf subscribed clients (via clients): " <> show ntfClCnt + hPutStrLn h $ "Ntf subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show ntfClQs + putActiveClientsInfo "SMP" subscribers + putActiveClientsInfo "Ntf" notifiers + where + putActiveClientsInfo :: String -> TMap QueueId Client -> IO () + putActiveClientsInfo protoName clients = do + activeSubs <- readTVarIO clients + hPutStrLn h $ protoName <> " subscriptions: " <> show (M.size activeSubs) + clCnt <- if r == CPRAdmin then putClientQueues activeSubs else pure $ countSubClients activeSubs + hPutStrLn h $ protoName <> " subscribed clients: " <> show clCnt + where + putClientQueues :: M.Map QueueId Client -> IO Int + putClientQueues subs = do + let cls = differentClients subs + clQs <- clientTBQueueLengths cls + hPutStrLn h $ protoName <> " subscribed clients queues (rcvQ, sndQ, msgQ): " <> show clQs + pure $ length cls + differentClients :: M.Map QueueId Client -> [Client] + differentClients = fst . M.foldl' addClient ([], IS.empty) + where + addClient acc@(cls, clSet) cl@Client {clientId} + | IS.member clientId clSet = acc + | otherwise = (cl : cls, IS.insert clientId clSet) + countSubClients :: M.Map QueueId Client -> Int + countSubClients = IS.size . M.foldr' (IS.insert . clientId) IS.empty + countClientSubs :: (Client -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) + countClientSubs subSel countSubs_ = foldM addSubs (0, (0, 0, 0, 0), 0, (0, 0, 0)) + where + addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) + addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) cl = do + subs <- readTVarIO $ subSel cl + cnts' <- case countSubs_ of + Nothing -> pure cnts + Just countSubs -> do + (c1', c2', c3', c4') <- countSubs subs + pure (c1 + c1', c2 + c2', c3 + c3', c4 + c4') + let cnt = M.size subs + clCnt' = if cnt == 0 then clCnt else clCnt + 1 + qs' <- if cnt == 0 then pure qs else addQueueLengths qs cl + pure (subCnt + cnt, cnts', clCnt', qs') + clientTBQueueLengths :: Foldable t => t Client -> IO (Natural, Natural, Natural) + clientTBQueueLengths = foldM addQueueLengths (0, 0, 0) + addQueueLengths (!rl, !sl, !ml) cl = do + (rl', sl', ml') <- queueLengths cl + pure (rl + rl', sl + sl', ml + ml') + queueLengths Client {rcvQ, sndQ, msgQ} = do + rl <- atomically $ lengthTBQueue rcvQ + sl <- atomically $ lengthTBQueue sndQ + ml <- atomically $ lengthTBQueue msgQ + pure (rl, sl, ml) + countSMPSubs :: M.Map QueueId Sub -> IO (Int, Int, Int, Int) + countSMPSubs = foldM countSubs (0, 0, 0, 0) + where + countSubs (c1, c2, c3, c4) Sub {subThread} = case subThread of + ServerSub t -> do + st <- readTVarIO t + pure $ case st of + NoSub -> (c1 + 1, c2, c3, c4) + SubPending -> (c1, c2 + 1, c3, c4) + SubThread _ -> (c1, c2, c3 + 1, c4) + ProhibitSub -> pure (c1, c2, c3, c4 + 1) CPDelete queueId' -> withUserRole $ unliftIO u $ do st <- asks queueStore ms <- asks msgStore @@ -455,7 +597,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do hPutStrLn h "saving server state..." unliftIO u $ saveServer True hPutStrLn h "server state saved!" - CPHelp -> hPutStrLn h "commands: stats, stats-rts, clients, sockets, socket-threads, threads, delete, save, help, quit" + CPHelp -> hPutStrLn h "commands: stats, stats-rts, clients, sockets, socket-threads, threads, server-info, delete, save, help, quit" CPQuit -> pure () CPSkip -> pure () where @@ -477,10 +619,8 @@ runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessio ts <- liftIO getSystemTime active <- asks clients nextClientId <- asks clientSeq - c <- atomically $ do - new@Client {clientId} <- newClient nextClientId q thVersion sessionId ts - modifyTVar' active $ IM.insert clientId new - pure new + c@Client {clientId} <- liftIO $ newClient nextClientId q thVersion sessionId ts + atomically $ modifyTVar' active $ IM.insert clientId c s <- asks server expCfg <- asks $ inactiveClientExpiration . config th <- newMVar h -- put TH under a fair lock to interleave messages and command responses @@ -490,22 +630,26 @@ runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessio where disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport h (rcvActiveAt c) (sndActiveAt c) expCfg (noSubscriptions c)] disconnectThread_ _ _ = [] - noSubscriptions c = atomically $ (&&) <$> TM.null (subscriptions c) <*> TM.null (ntfSubscriptions c) + noSubscriptions c = atomically $ (&&) <$> TM.null (ntfSubscriptions c) <*> (not . hasSubs <$> readTVar (subscriptions c)) + hasSubs = any $ (\case ServerSub _ -> True; ProhibitSub -> False) . subThread clientDisconnected :: Client -> M () -clientDisconnected c@Client {clientId, subscriptions, connected, sessionId, endThreads} = do +clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, connected, sessionId, endThreads} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disc" - subs <- atomically $ do + (subs, ntfSubs) <- atomically $ do writeTVar connected False - swapTVar subscriptions M.empty + (,) <$> swapTVar subscriptions M.empty <*> swapTVar ntfSubscriptions M.empty liftIO $ mapM_ cancelSub subs - srvSubs <- asks $ subscribers . server - atomically $ modifyTVar' srvSubs $ \cs -> - M.foldrWithKey (\sub _ -> M.update deleteCurrentClient sub) cs subs + Server {subscribers, notifiers} <- asks server + updateSubscribers subs subscribers + updateSubscribers ntfSubs notifiers asks clients >>= atomically . (`modifyTVar'` IM.delete clientId) tIds <- atomically $ swapTVar endThreads IM.empty liftIO $ mapM_ (mapM_ killThread <=< deRefWeak) tIds where + updateSubscribers subs srvSubs = do + atomically $ modifyTVar' srvSubs $ \cs -> + M.foldrWithKey (\sub _ -> M.update deleteCurrentClient sub) cs subs deleteCurrentClient :: Client -> Maybe Client deleteCurrentClient c' | sameClientId c c' = Nothing @@ -514,11 +658,13 @@ clientDisconnected c@Client {clientId, subscriptions, connected, sessionId, endT sameClientId :: Client -> Client -> Bool sameClientId Client {clientId} Client {clientId = cId'} = clientId == cId' -cancelSub :: TVar Sub -> IO () -cancelSub sub = - readTVarIO sub >>= \case - Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread - _ -> return () +cancelSub :: Sub -> IO () +cancelSub s = case subThread s of + ServerSub st -> + readTVarIO st >>= \case + SubThread t -> liftIO $ deRefWeak t >>= mapM_ killThread + _ -> pure () + ProhibitSub -> pure () receive :: Transport c => THandleSMP c 'TServer -> Client -> M () receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do @@ -541,8 +687,10 @@ receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiv VRVerified qr -> pure $ Right (qr, (corrId, entId, cmd)) VRFailed -> do case cmd of - Cmd _ SEND {} -> atomically $ modifyTVar' (msgSentAuth stats) (+ 1) - Cmd _ SUB -> atomically $ modifyTVar' (qSubAuth stats) (+ 1) + Cmd _ SEND {} -> incStat $ msgSentAuth stats + Cmd _ SUB -> incStat $ qSubAuth stats + Cmd _ NSUB -> incStat $ ntfSubAuth stats + Cmd _ GET -> incStat $ msgGetAuth stats _ -> pure () pure $ Left (corrId, entId, ERR AUTH) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty @@ -683,7 +831,7 @@ forkClient Client {endThreads, endThreadSeq} label action = do mkWeakThreadId t >>= atomically . modifyTVar' endThreads . IM.insert tId client :: THandleParams SMPVersion 'TServer -> Client -> Server -> M () -client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, notifiers} = do +client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, subscribers, notifiers} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands" forever $ atomically (readTBQueue rcvQ) @@ -737,7 +885,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi ProxyAgent {smpAgent = a} <- asks proxyAgent ServerStats {pMsgFwds, pMsgFwdsOwn} <- asks serverStats let inc = mkIncProxyStats pMsgFwds pMsgFwdsOwn - atomically (lookupSMPServerClient a sessId) >>= \case + liftIO (lookupSMPServerClient a sessId) >>= \case Just (own, smp) -> do inc own pRequests if v >= sendingProxySMPVersion @@ -770,13 +918,13 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi transportErr = PROXY . BROKER . TRANSPORT mkIncProxyStats :: MonadIO m => ProxyStats -> ProxyStats -> OwnServer -> (ProxyStats -> TVar Int) -> m () mkIncProxyStats ps psOwn own sel = do - atomically $ modifyTVar' (sel ps) (+ 1) - when own $ atomically $ modifyTVar' (sel psOwn) (+ 1) + incStat $ sel ps + when own $ incStat $ sel psOwn processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Maybe (Transmission BrokerMsg)) - processCommand (qr_, (corrId, queueId, cmd)) = case cmd of - Cmd SProxiedClient command -> processProxiedCmd (corrId, queueId, command) + processCommand (qr_, (corrId, entId, cmd)) = case cmd of + Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command) Cmd SSender command -> Just <$> case command of - SKEY sKey -> (corrId,queueId,) <$> case qr_ of + SKEY sKey -> (corrId,entId,) <$> case qr_ of Just QueueRec {sndSecure, recipientId} | sndSecure -> secureQueue_ "SKEY" recipientId sKey | otherwise -> pure $ ERR AUTH @@ -792,15 +940,15 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi ifM allowNew (createQueue st rKey dhKey subMode sndSecure) - (pure (corrId, queueId, ERR AUTH)) + (pure (corrId, entId, ERR AUTH)) where allowNew = do ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth - SUB -> withQueue (`subscribeQueue` queueId) + SUB -> withQueue (`subscribeQueue` entId) GET -> withQueue getMessage ACK msgId -> withQueue (`acknowledgeMsg` msgId) - KEY sKey -> (corrId,queueId,) <$> case qr_ of + KEY sKey -> (corrId,entId,) <$> case qr_ of Just QueueRec {recipientId} -> secureQueue_ "KEY" recipientId sKey Nothing -> pure $ ERR INTERNAL NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey @@ -825,7 +973,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi status = QueueActive, sndSecure } - (corrId,queueId,) <$> addQueueRetry 3 qik qRec + (corrId,entId,) <$> addQueueRetry 3 qik qRec where addQueueRetry :: Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> M BrokerMsg @@ -840,8 +988,8 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi Right _ -> do withLog (`logCreateById` rId) stats <- asks serverStats - atomically $ modifyTVar' (qCreated stats) (+ 1) - atomically $ modifyTVar' (qCount stats) (+ 1) + incStat $ qCreated stats + incStat $ qCount stats case subMode of SMOnlyCreate -> pure () SMSubscribe -> void $ subscribeQueue qr rId @@ -863,154 +1011,178 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi withLog $ \s -> logSecureQueue s rId sKey st <- asks queueStore stats <- asks serverStats - atomically $ modifyTVar' (qSecured stats) (+ 1) + incStat $ qSecured stats atomically $ either ERR (const OK) <$> secureQueue st rId sKey addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M (Transmission BrokerMsg) addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do (rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random let rcvNtfDhSecret = C.dh' dhKey privDhKey - (corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret + (corrId,entId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret where addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M BrokerMsg addNotifierRetry 0 _ _ = pure $ ERR INTERNAL addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do notifierId <- randomId =<< asks (queueIdBytes . config) let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} - atomically (addQueueNotifier st queueId ntfCreds) >>= \case + atomically (addQueueNotifier st entId ntfCreds) >>= \case Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret Left e -> pure $ ERR e Right _ -> do - withLog $ \s -> logAddNotifier s queueId ntfCreds + withLog $ \s -> logAddNotifier s entId ntfCreds + incStat . ntfCreated =<< asks serverStats pure $ NID notifierId rcvPublicDhKey deleteQueueNotifier_ :: QueueStore -> M (Transmission BrokerMsg) deleteQueueNotifier_ st = do - withLog (`logDeleteNotifier` queueId) - okResp <$> atomically (deleteQueueNotifier st queueId) + withLog (`logDeleteNotifier` entId) + atomically (deleteQueueNotifier st entId) >>= \case + Right () -> do + -- Possibly, the same should be done if the queue is suspended, but currently we do not use it + atomically $ writeTQueue ntfSubscribedQ (entId, clnt, False) + incStat . ntfDeleted =<< asks serverStats + pure ok + Left e -> pure $ err e suspendQueue_ :: QueueStore -> M (Transmission BrokerMsg) suspendQueue_ st = do - withLog (`logSuspendQueue` queueId) - okResp <$> atomically (suspendQueue st queueId) + withLog (`logSuspendQueue` entId) + okResp <$> atomically (suspendQueue st entId) subscribeQueue :: QueueRec -> RecipientId -> M (Transmission BrokerMsg) subscribeQueue qr rId = do - stats <- asks serverStats atomically (TM.lookup rId subscriptions) >>= \case - Nothing -> do - atomically $ modifyTVar' (qSub stats) (+ 1) - newSub >>= deliver - Just sub -> - readTVarIO sub >>= \case - Sub {subThread = ProhibitSub} -> do + Nothing -> newSub >>= deliver True + Just s@Sub {subThread} -> do + stats <- asks serverStats + case subThread of + ProhibitSub -> do -- cannot use SUB in the same connection where GET was used - atomically $ modifyTVar' (qSubProhibited stats) (+ 1) + incStat $ qSubProhibited stats pure (corrId, rId, ERR $ CMD PROHIBITED) - s -> do - atomically $ modifyTVar' (qSubDuplicate stats) (+ 1) - atomically (tryTakeTMVar $ delivered s) >> deliver sub + _ -> do + incStat $ qSubDuplicate stats + atomically (tryTakeTMVar $ delivered s) >> deliver False s where - newSub :: M (TVar Sub) + newSub :: M Sub newSub = time "SUB newSub" . atomically $ do - writeTQueue subscribedQ (rId, clnt) - sub <- newTVar =<< newSubscription NoSub + writeTQueue subscribedQ (rId, clnt, True) + sub <- newSubscription NoSub TM.insert rId sub subscriptions pure sub - deliver :: TVar Sub -> M (Transmission BrokerMsg) - deliver sub = do + deliver :: Bool -> Sub -> M (Transmission BrokerMsg) + deliver inc sub = do q <- getStoreMsgQueue "SUB" rId msg_ <- atomically $ tryPeekMsg q - deliverMessage "SUB" qr rId sub q msg_ + when inc $ do + stats <- asks serverStats + incStat $ (if isJust msg_ then qSub else qSubNoMsg) stats + atomically $ updatePeriodStats (subscribedQueues stats) rId + deliverMessage "SUB" qr rId sub msg_ getMessage :: QueueRec -> M (Transmission BrokerMsg) getMessage qr = time "GET" $ do - atomically (TM.lookup queueId subscriptions) >>= \case + atomically (TM.lookup entId subscriptions) >>= \case Nothing -> - atomically newSub >>= getMessage_ - Just sub -> - readTVarIO sub >>= \case - s@Sub {subThread = ProhibitSub} -> + atomically newSub >>= (`getMessage_` Nothing) + Just s@Sub {subThread} -> + case subThread of + ProhibitSub -> atomically (tryTakeTMVar $ delivered s) - >> getMessage_ s + >>= getMessage_ s -- cannot use GET in the same connection where there is an active subscription - _ -> pure (corrId, queueId, ERR $ CMD PROHIBITED) + _ -> do + stats <- asks serverStats + incStat $ msgGetProhibited stats + pure (corrId, entId, ERR $ CMD PROHIBITED) where newSub :: STM Sub newSub = do - s <- newSubscription ProhibitSub - sub <- newTVar s - TM.insert queueId sub subscriptions + s <- newProhibitedSub + TM.insert entId s subscriptions pure s - getMessage_ :: Sub -> M (Transmission BrokerMsg) - getMessage_ s = do - q <- getStoreMsgQueue "GET" queueId - atomically $ - tryPeekMsg q >>= \case - Just msg -> - let encMsg = encryptMsg qr msg - in setDelivered s msg $> (corrId, queueId, MSG encMsg) - _ -> pure (corrId, queueId, OK) + getMessage_ :: Sub -> Maybe MsgId -> M (Transmission BrokerMsg) + getMessage_ s delivered_ = do + q <- getStoreMsgQueue "GET" entId + stats <- asks serverStats + (statCnt, r) <- + atomically $ + tryPeekMsg q >>= \case + Just msg -> + let encMsg = encryptMsg qr msg + cnt = if isJust delivered_ then msgGetDuplicate else msgGet + in setDelivered s msg $> (cnt, (corrId, entId, MSG encMsg)) + _ -> pure (msgGetNoMsg, (corrId, entId, OK)) + incStat $ statCnt stats + pure r withQueue :: (QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg) withQueue action = maybe (pure $ err AUTH) action qr_ subscribeNotifications :: M (Transmission BrokerMsg) - subscribeNotifications = time "NSUB" . atomically $ do - unlessM (TM.member queueId ntfSubscriptions) $ do - writeTQueue ntfSubscribedQ (queueId, clnt) - TM.insert queueId () ntfSubscriptions + subscribeNotifications = do + statCount <- + time "NSUB" . atomically $ do + ifM + (TM.member entId ntfSubscriptions) + (pure ntfSubDuplicate) + (newSub $> ntfSub) + incStat . statCount =<< asks serverStats pure ok + where + newSub = do + writeTQueue ntfSubscribedQ (entId, clnt, True) + TM.insert entId () ntfSubscriptions acknowledgeMsg :: QueueRec -> MsgId -> M (Transmission BrokerMsg) acknowledgeMsg qr msgId = time "ACK" $ do - atomically (TM.lookup queueId subscriptions) >>= \case + liftIO (TM.lookupIO entId subscriptions) >>= \case Nothing -> pure $ err NO_MSG Just sub -> atomically (getDelivered sub) >>= \case - Just s -> do - q <- getStoreMsgQueue "ACK" queueId - case s of - Sub {subThread = ProhibitSub} -> do + Just st -> do + q <- getStoreMsgQueue "ACK" entId + case st of + ProhibitSub -> do deletedMsg_ <- atomically $ tryDelMsg q msgId - mapM_ updateStats deletedMsg_ + mapM_ (updateStats True) deletedMsg_ pure ok _ -> do (deletedMsg_, msg_) <- atomically $ tryDelPeekMsg q msgId - mapM_ updateStats deletedMsg_ - deliverMessage "ACK" qr queueId sub q msg_ + mapM_ (updateStats False) deletedMsg_ + deliverMessage "ACK" qr entId sub msg_ _ -> pure $ err NO_MSG where - getDelivered :: TVar Sub -> STM (Maybe Sub) - getDelivered sub = do - s@Sub {delivered} <- readTVar sub + getDelivered :: Sub -> STM (Maybe ServerSub) + getDelivered Sub {delivered, subThread} = do tryTakeTMVar delivered $>>= \msgId' -> if msgId == msgId' || B.null msgId - then pure $ Just s + then pure $ Just subThread else putTMVar delivered msgId' $> Nothing - updateStats :: Message -> M () - updateStats = \case + updateStats :: Bool -> Message -> M () + updateStats isGet = \case MessageQuota {} -> pure () Message {msgFlags} -> do stats <- asks serverStats - atomically $ modifyTVar' (msgRecv stats) (+ 1) + incStat $ msgRecv stats + when isGet $ incStat $ msgRecvGet stats atomically $ modifyTVar' (msgCount stats) (subtract 1) - atomically $ updatePeriodStats (activeQueues stats) queueId + atomically $ updatePeriodStats (activeQueues stats) entId when (notification msgFlags) $ do - atomically $ modifyTVar' (msgRecvNtf stats) (+ 1) - atomically $ updatePeriodStats (activeQueuesNtf stats) queueId + incStat $ msgRecvNtf stats + atomically $ updatePeriodStats (activeQueuesNtf stats) entId sendMessage :: QueueRec -> MsgFlags -> MsgBody -> M (Transmission BrokerMsg) sendMessage qr msgFlags msgBody | B.length msgBody > maxMessageLength thVersion = do stats <- asks serverStats - atomically $ modifyTVar' (msgSentLarge stats) (+ 1) + incStat $ msgSentLarge stats pure $ err LARGE_MSG | otherwise = do stats <- asks serverStats case status qr of QueueOff -> do - atomically $ modifyTVar' (msgSentAuth stats) (+ 1) + incStat $ msgSentAuth stats pure $ err AUTH QueueActive -> case C.maxLenBS msgBody of @@ -1022,23 +1194,24 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi atomically . writeMsg q =<< mkMessage body case msg_ of Nothing -> do - atomically $ modifyTVar' (msgSentQuota stats) (+ 1) + incStat $ msgSentQuota stats pure $ err QUOTA - Just msg -> time "SEND ok" $ do + Just (msg, wasEmpty) -> time "SEND ok" $ do + when wasEmpty $ tryDeliverMessage msg when (notification msgFlags) $ do forM_ (notifier qr) $ \ntf -> do asks random >>= atomically . trySendNotification ntf msg >>= \case Nothing -> do - atomically $ modifyTVar' (msgNtfNoSub stats) (+ 1) + incStat $ msgNtfNoSub stats logWarn "No notification subscription" Just False -> do - atomically $ modifyTVar' (msgNtfLost stats) (+ 1) + incStat $ msgNtfLost stats logWarn "Dropped message notification" - Just True -> atomically $ modifyTVar' (msgNtfs stats) (+ 1) - atomically $ modifyTVar' (msgSentNtf stats) (+ 1) + Just True -> incStat $ msgNtfs stats + incStat $ msgSentNtf stats atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr) - atomically $ modifyTVar' (msgSent stats) (+ 1) - atomically $ modifyTVar' (msgCount stats) (+ 1) + incStat $ msgSent stats + incStat $ msgCount stats atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) pure ok where @@ -1058,6 +1231,54 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi stats <- asks serverStats atomically $ modifyTVar' (msgExpired stats) (+ deleted) + -- The condition for delivery of the message is: + -- - the queue was empty when the message was sent, + -- - there is subscribed recipient, + -- - no message was "delivered" that was not acknowledged. + -- If the send queue of the subscribed client is not full the message is put there in the same transaction. + -- If the queue is not full, then the thread is created where these checks are made: + -- - it is the same subscribed client (in case it was reconnected it would receive message via SUB command) + -- - nothing was delivered to this subscription (to avoid race conditions with the recipient). + tryDeliverMessage :: Message -> M () + tryDeliverMessage msg = atomically deliverToSub >>= mapM_ forkDeliver + where + rId = recipientId qr + deliverToSub = + TM.lookup rId subscribers + $>>= \rc@Client {subscriptions = subs, sndQ = q} -> TM.lookup rId subs + $>>= \s@Sub {subThread, delivered} -> case subThread of + ProhibitSub -> pure Nothing + ServerSub st -> readTVar st >>= \case + NoSub -> + tryTakeTMVar delivered >>= \case + Just _ -> pure Nothing -- if a message was already delivered, should not deliver more + Nothing -> + ifM + (isFullTBQueue q) + (writeTVar st SubPending $> Just (rc, s, st)) + (deliver q s $> Nothing) + _ -> pure Nothing + deliver q s = do + let encMsg = encryptMsg qr msg + writeTBQueue q [(CorrId "", rId, MSG encMsg)] + void $ setDelivered s msg + forkDeliver (rc@Client {sndQ = q}, s@Sub {delivered}, st) = do + t <- mkWeakThreadId =<< forkIO deliverThread + atomically . modifyTVar' st $ \case + -- this case is needed because deliverThread can exit before it + SubPending -> SubThread t + st' -> st' + where + deliverThread = do + labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " deliver/SEND" + time "deliver" . atomically $ + whenM (maybe False (sameClientId rc) <$> TM.lookup rId subscribers) $ do + tryTakeTMVar delivered >>= \case + Just _ -> pure () -- if a message was already delivered, should not deliver more + Nothing -> do + deliver q s + writeTVar st NoSub + trySendNotification :: NtfCreds -> Message -> TVar ChaChaDRG -> STM (Maybe Bool) trySendNotification NtfCreds {notifierId, rcvNtfDhSecret} msg ntfNonceDrg = mapM (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers @@ -1114,7 +1335,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi let fr = FwdResponse {fwdCorrId, fwdResponse = r2} r3 = EncFwdResponse $ C.cbEncryptNoPad sessSecret (C.reverseNonce proxyNonce) (smpEncode fr) stats <- asks serverStats - atomically $ modifyTVar' (pMsgFwdsRecv stats) (+ 1) + incStat $ pMsgFwdsRecv stats pure $ RRES r3 where rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd)) @@ -1132,38 +1353,20 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi verified = \case VRVerified qr -> Right (qr, (corrId', entId', cmd')) VRFailed -> Left (corrId', entId', ERR AUTH) - deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> M (Transmission BrokerMsg) - deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do - readTVarIO sub >>= \case - s@Sub {subThread = NoSub} -> - case msg_ of - Just msg -> - let encMsg = encryptMsg qr msg - in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg) - _ -> forkSub $> resp - _ -> pure resp + deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> M (Transmission BrokerMsg) + deliverMessage name qr rId s@Sub {subThread} msg_ = time (name <> " deliver") . atomically $ + case subThread of + ProhibitSub -> pure resp + _ -> case msg_ of + Just msg -> + let encMsg = encryptMsg qr msg + in setDelivered s msg $> (corrId, rId, MSG encMsg) + _ -> pure resp where resp = (corrId, rId, OK) - forkSub :: M () - forkSub = do - atomically . modifyTVar' sub $ \s -> s {subThread = SubPending} - t <- mkWeakThreadId =<< forkIO subscriber - atomically . modifyTVar' sub $ \case - s@Sub {subThread = SubPending} -> s {subThread = SubThread t} - s -> s - where - subscriber = do - labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " subscriber/" <> T.unpack name - msg <- atomically $ peekMsg q - time "subscriber" . atomically $ do - let encMsg = encryptMsg qr msg - writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)] - s <- readTVar sub - void $ setDelivered s msg - writeTVar sub $! s {subThread = NoSub} time :: T.Text -> M a -> M a - time name = timed name queueId + time name = timed name entId encryptMsg :: QueueRec -> Message -> RcvMessage encryptMsg qr msg = encrypt . encodeRcvMsgBody $ case msg of @@ -1186,37 +1389,44 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi delQueueAndMsgs :: QueueStore -> M (Transmission BrokerMsg) delQueueAndMsgs st = do - withLog (`logDeleteQueue` queueId) + withLog (`logDeleteQueue` entId) ms <- asks msgStore - atomically (deleteQueue st queueId $>>= \q -> delMsgQueue ms queueId $> Right q) >>= \case - Right q -> updateDeletedStats q $> ok + atomically (deleteQueue st entId $>>= \q -> delMsgQueue ms entId $> Right q) >>= \case + Right q -> do + -- Possibly, the same should be done if the queue is suspended, but currently we do not use it + atomically $ writeTQueue subscribedQ (entId, clnt, False) + atomically $ writeTQueue ntfSubscribedQ (entId, clnt, False) + updateDeletedStats q + pure ok Left e -> pure $ err e getQueueInfo :: QueueRec -> M (Transmission BrokerMsg) getQueueInfo QueueRec {senderKey, notifier} = do - q@MsgQueue {size} <- getStoreMsgQueue "getQueueInfo" queueId + q@MsgQueue {size} <- getStoreMsgQueue "getQueueInfo" entId info <- atomically $ do - qiSub <- TM.lookup queueId subscriptions >>= mapM mkQSub + qiSub <- TM.lookup entId subscriptions >>= mapM mkQSub qiSize <- readTVar size qiMsg <- toMsgInfo <$$> tryPeekMsg q pure QueueInfo {qiSnd = isJust senderKey, qiNtf = isJust notifier, qiSub, qiSize, qiMsg} - pure (corrId, queueId, INFO info) + pure (corrId, entId, INFO info) where - mkQSub sub = do - Sub {subThread, delivered} <- readTVar sub - let qSubThread = case subThread of + mkQSub Sub {subThread, delivered} = do + qSubThread <- case subThread of + ServerSub t -> do + st <- readTVar t + pure $ case st of NoSub -> QNoSub SubPending -> QSubPending SubThread _ -> QSubThread - ProhibitSub -> QProhibitSub + ProhibitSub -> pure QProhibitSub qDelivered <- decodeLatin1 . encode <$$> tryReadTMVar delivered pure QSub {qSubThread, qDelivered} ok :: Transmission BrokerMsg - ok = (corrId, queueId, OK) + ok = (corrId, entId, OK) err :: ErrorType -> Transmission BrokerMsg - err e = (corrId, queueId, ERR e) + err e = (corrId, entId, ERR e) okResp :: Either ErrorType () -> Transmission BrokerMsg okResp = either err $ const ok @@ -1225,9 +1435,13 @@ updateDeletedStats :: QueueRec -> M () updateDeletedStats q = do stats <- asks serverStats let delSel = if isNothing (senderKey q) then qDeletedNew else qDeletedSecured - atomically $ modifyTVar' (delSel stats) (+ 1) - atomically $ modifyTVar' (qDeletedAll stats) (+ 1) - atomically $ modifyTVar' (qCount stats) (subtract 1) + incStat $ delSel stats + incStat $ qDeletedAll stats + incStat $ qCount stats + +incStat :: MonadIO m => TVar Int -> m () +incStat v = atomically $ modifyTVar' v (+ 1) +{-# INLINE incStat #-} withLog :: (StoreLog 'WriteMode -> IO a) -> M () withLog action = do @@ -1256,13 +1470,16 @@ saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessag logInfo $ "saving messages to file " <> T.pack f ms <- asks msgStore liftIO . withFile f WriteMode $ \h -> - readTVarIO ms >>= mapM_ (saveQueueMsgs ms h) . M.keys + readTVarIO ms >>= mapM_ (saveQueueMsgs h) . M.assocs logInfo "messages saved" where - getMessages = if keepMsgs then snapshotMsgQueue else flushMsgQueue - saveQueueMsgs ms h rId = - atomically (getMessages ms rId) - >>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId) + saveQueueMsgs h (rId, q) = BLD.hPutBuilder h . encodeMessages rId =<< atomically (getMessages $ msgQueue q) + getMessages = if keepMsgs then snapshotTQueue else flushTQueue + snapshotTQueue q = do + msgs <- flushTQueue q + mapM_ (writeTQueue q) msgs + pure msgs + encodeMessages rId = mconcat . map (\msg -> BLD.byteString (strEncode $ MLRv3 rId msg) <> BLD.char8 '\n') restoreServerMessages :: M Int restoreServerMessages = @@ -1305,7 +1522,7 @@ restoreServerMessages = saveServerStats :: M () saveServerStats = asks (serverStatsBackupFile . config) - >>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f) + >>= mapM_ (\f -> asks serverStats >>= liftIO . getServerStatsData >>= liftIO . saveStats f) where saveStats f stats = do logInfo $ "saving server stats to file " <> T.pack f diff --git a/src/Simplex/Messaging/Server/Control.hs b/src/Simplex/Messaging/Server/Control.hs index 9463fa777..b4c74e4ac 100644 --- a/src/Simplex/Messaging/Server/Control.hs +++ b/src/Simplex/Messaging/Server/Control.hs @@ -9,6 +9,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BasicAuth) data CPClientRole = CPRNone | CPRUser | CPRAdmin + deriving (Eq) data ControlProtocol = CPAuth BasicAuth @@ -20,6 +21,7 @@ data ControlProtocol | CPThreads | CPSockets | CPSocketThreads + | CPServerInfo | CPDelete ByteString | CPSave | CPHelp @@ -37,6 +39,7 @@ instance StrEncoding ControlProtocol where CPThreads -> "threads" CPSockets -> "sockets" CPSocketThreads -> "socket-threads" + CPServerInfo -> "server-info" CPDelete bs -> "delete " <> strEncode bs CPSave -> "save" CPHelp -> "help" @@ -53,6 +56,7 @@ instance StrEncoding ControlProtocol where "threads" -> pure CPThreads "sockets" -> pure CPSockets "socket-threads" -> pure CPSocketThreads + "server-info" -> pure CPServerInfo "delete" -> CPDelete <$> (A.space *> strP) "save" -> pure CPSave "help" -> pure CPHelp diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 4217ea9b9..84e664607 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -1,13 +1,15 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StrictData #-} module Simplex.Messaging.Server.Env.STM where import Control.Concurrent (ThreadId) -import Control.Monad.IO.Unlift +import Control.Logger.Simple +import Control.Monad import Crypto.Random import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) @@ -17,6 +19,7 @@ import Data.List.NonEmpty (NonEmpty) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (isJust, isNothing) +import qualified Data.Text as T import Data.Time.Clock (getCurrentTime) import Data.Time.Clock.System (SystemTime) import Data.X509.Validation (Fingerprint (..)) @@ -47,7 +50,6 @@ data ServerConfig = ServerConfig { transports :: [(ServiceName, ATransport)], smpHandshakeTimeout :: Int, tbqSize :: Natural, - -- serverTbqSize :: Natural, msgQueueQuota :: Int, queueIdBytes :: Int, msgIdBytes :: Int, @@ -105,7 +107,7 @@ defaultMessageExpiration = defaultInactiveClientExpiration :: ExpirationConfig defaultInactiveClientExpiration = ExpirationConfig - { ttl = 43200, -- seconds, 12 hours + { ttl = 21600, -- seconds, 6 hours checkInterval = 3600 -- seconds, 1 hours } @@ -129,10 +131,12 @@ data Env = Env proxyAgent :: ProxyAgent -- senders served on this proxy } +type Subscribed = Bool + data Server = Server - { subscribedQ :: TQueue (RecipientId, Client), + { subscribedQ :: TQueue (RecipientId, Client, Subscribed), subscribers :: TMap RecipientId Client, - ntfSubscribedQ :: TQueue (NotifierId, Client), + ntfSubscribedQ :: TQueue (NotifierId, Client, Subscribed), notifiers :: TMap NotifierId Client, savingLock :: Lock } @@ -145,7 +149,7 @@ type ClientId = Int data Client = Client { clientId :: ClientId, - subscriptions :: TMap RecipientId (TVar Sub), + subscriptions :: TMap RecipientId Sub, ntfSubscriptions :: TMap NotifierId (), rcvQ :: TBQueue (NonEmpty (Maybe QueueRec, Transmission Cmd)), sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), @@ -161,67 +165,77 @@ data Client = Client sndActiveAt :: TVar SystemTime } -data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) | ProhibitSub +data ServerSub = ServerSub (TVar SubscriptionThread) | ProhibitSub + +data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) data Sub = Sub - { subThread :: SubscriptionThread, + { subThread :: ServerSub, -- Nothing value indicates that sub delivered :: TMVar MsgId } -newServer :: STM Server +newServer :: IO Server newServer = do - subscribedQ <- newTQueue - subscribers <- TM.empty - ntfSubscribedQ <- newTQueue - notifiers <- TM.empty - savingLock <- createLock + subscribedQ <- newTQueueIO + subscribers <- TM.emptyIO + ntfSubscribedQ <- newTQueueIO + notifiers <- TM.emptyIO + savingLock <- atomically createLock return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock} -newClient :: TVar ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client +newClient :: TVar ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> IO Client newClient nextClientId qSize thVersion sessionId createdAt = do - clientId <- stateTVar nextClientId $ \next -> (next, next + 1) - subscriptions <- TM.empty - ntfSubscriptions <- TM.empty - rcvQ <- newTBQueue qSize - sndQ <- newTBQueue qSize - msgQ <- newTBQueue qSize - procThreads <- newTVar 0 - endThreads <- newTVar IM.empty - endThreadSeq <- newTVar 0 - connected <- newTVar True - rcvActiveAt <- newTVar createdAt - sndActiveAt <- newTVar createdAt + clientId <- atomically $ stateTVar nextClientId $ \next -> (next, next + 1) + subscriptions <- TM.emptyIO + ntfSubscriptions <- TM.emptyIO + rcvQ <- newTBQueueIO qSize + sndQ <- newTBQueueIO qSize + msgQ <- newTBQueueIO qSize + procThreads <- newTVarIO 0 + endThreads <- newTVarIO IM.empty + endThreadSeq <- newTVarIO 0 + connected <- newTVarIO True + rcvActiveAt <- newTVarIO createdAt + sndActiveAt <- newTVarIO createdAt return Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, msgQ, procThreads, endThreads, endThreadSeq, thVersion, sessionId, connected, createdAt, rcvActiveAt, sndActiveAt} newSubscription :: SubscriptionThread -> STM Sub -newSubscription subThread = do +newSubscription st = do delivered <- newEmptyTMVar + subThread <- ServerSub <$> newTVar st return Sub {subThread, delivered} +newProhibitedSub :: STM Sub +newProhibitedSub = do + delivered <- newEmptyTMVar + return Sub {subThread = ProhibitSub, delivered} + newEnv :: ServerConfig -> IO Env newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, storeLogFile, smpAgentCfg, transportConfig, information, messageExpiration} = do - server <- atomically newServer - queueStore <- atomically newQueueStore - msgStore <- atomically newMsgStore - random <- liftIO C.newRandom - storeLog <- restoreQueues queueStore `mapM` storeLogFile + server <- newServer + queueStore <- newQueueStore + msgStore <- newMsgStore + random <- C.newRandom + storeLog <- + forM storeLogFile $ \f -> do + logInfo $ "restoring queues from file " <> T.pack f + restoreQueues queueStore f tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) Fingerprint fp <- loadFingerprint caCertificateFile let serverIdentity = KeyHash fp - serverStats <- atomically . newServerStats =<< getCurrentTime - sockets <- atomically newSocketState + serverStats <- newServerStats =<< getCurrentTime + sockets <- newSocketState clientSeq <- newTVarIO 0 clients <- newTVarIO mempty - proxyAgent <- atomically $ newSMPProxyAgent smpAgentCfg random + proxyAgent <- newSMPProxyAgent smpAgentCfg random pure Env {config, serverInfo, server, serverIdentity, queueStore, msgStore, random, storeLog, tlsServerParams, serverStats, sockets, clientSeq, clients, proxyAgent} where restoreQueues :: QueueStore -> FilePath -> IO (StoreLog 'WriteMode) restoreQueues QueueStore {queues, senders, notifiers} f = do (qs, s) <- readWriteStoreLog f - atomically $ do - writeTVar queues =<< mapM newTVar qs - writeTVar senders $! M.foldr' addSender M.empty qs - writeTVar notifiers $! M.foldr' addNotifier M.empty qs + atomically . writeTVar queues =<< mapM newTVarIO qs + atomically $ writeTVar senders $! M.foldr' addSender M.empty qs + atomically $ writeTVar notifiers $! M.foldr' addNotifier M.empty qs pure s addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) @@ -247,7 +261,7 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, | isJust (storeMsgsFile config) = SPMMessages | otherwise = SPMQueues -newSMPProxyAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM ProxyAgent +newSMPProxyAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO ProxyAgent newSMPProxyAgent smpAgentCfg random = do smpAgent <- newSMPClientAgent smpAgentCfg random pure ProxyAgent {smpAgent} diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 7af57ba25..784d0504a 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -255,7 +255,6 @@ smpServerCLI_ generateSite serveStaticFiles cfgPath logPath = { transports = iniTransports ini, smpHandshakeTimeout = 120000000, tbqSize = 64, - -- serverTbqSize = 1024, msgQueueQuota = 128, queueIdBytes = 24, msgIdBytes = 24, -- must be at least 24 bytes, it is used as 192-bit nonce for XSalsa20 @@ -306,7 +305,7 @@ smpServerCLI_ generateSite serveStaticFiles cfgPath logPath = networkConfig = defaultNetworkConfig { socksProxy = either error id <$!> strDecodeIni "PROXY" "socks_proxy" ini, - socksMode = either (const SMOnion) textToSocksMode $ lookupValue "PROXY" "socks_mode" ini, + socksMode = maybe SMOnion (either error id) $! strDecodeIni "PROXY" "socks_mode" ini, hostMode = either (const HMPublic) textToHostMode $ lookupValue "PROXY" "host_mode" ini, requiredHostMode = fromMaybe False $ iniOnOff "PROXY" "required_host_mode" ini } @@ -318,11 +317,6 @@ smpServerCLI_ generateSite serveStaticFiles cfgPath logPath = serverClientConcurrency = readIniDefault defaultProxyClientConcurrency "PROXY" "client_concurrency" ini, information = serverPublicInfo ini } - textToSocksMode :: Text -> SocksMode - textToSocksMode = \case - "always" -> SMAlways - "onion" -> SMOnion - s -> error . T.unpack $ "Invalid socks_mode: " <> s textToHostMode :: Text -> HostMode textToHostMode = \case "public" -> HMPublic diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index e315c4fe5..e0a5c8b45 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -14,8 +14,6 @@ module Simplex.Messaging.Server.MsgStore.STM getMsgQueue, delMsgQueue, delMsgQueueSize, - flushMsgQueue, - snapshotMsgQueue, writeMsg, tryPeekMsg, peekMsg, @@ -25,7 +23,6 @@ module Simplex.Messaging.Server.MsgStore.STM ) where -import Control.Concurrent.STM.TQueue (flushTQueue) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) @@ -44,8 +41,8 @@ data MsgQueue = MsgQueue type STMMsgStore = TMap RecipientId MsgQueue -newMsgStore :: STM STMMsgStore -newMsgStore = TM.empty +newMsgStore :: IO STMMsgStore +newMsgStore = TM.emptyIO getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st @@ -64,18 +61,7 @@ delMsgQueue st rId = TM.delete rId st delMsgQueueSize :: STMMsgStore -> RecipientId -> STM Int delMsgQueueSize st rId = TM.lookupDelete rId st >>= maybe (pure 0) (\MsgQueue {size} -> readTVar size) -flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] -flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue) - -snapshotMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] -snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue) - where - snapshotTQueue q = do - msgs <- flushTQueue q - mapM_ (writeTQueue q) msgs - pure msgs - -writeMsg :: MsgQueue -> Message -> STM (Maybe Message) +writeMsg :: MsgQueue -> Message -> STM (Maybe (Message, Bool)) writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} !msg = do canWrt <- readTVar canWrite empty <- isEmptyTQueue q @@ -85,7 +71,7 @@ writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} !msg = do writeTVar canWrite $! canWrt' modifyTVar' size (+ 1) if canWrt' - then writeTQueue q msg $> Just msg + then writeTQueue q msg $> Just (msg, empty) else (writeTQueue q $! msgQuota) $> Nothing else pure Nothing where diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index d6cdaf10a..50907cf9a 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -38,11 +38,11 @@ data QueueStore = QueueStore notifiers :: TMap NotifierId RecipientId } -newQueueStore :: STM QueueStore +newQueueStore :: IO QueueStore newQueueStore = do - queues <- TM.empty - senders <- TM.empty - notifiers <- TM.empty + queues <- TM.emptyIO + senders <- TM.emptyIO + notifiers <- TM.emptyIO pure QueueStore {queues, senders, notifiers} addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index f2716c9c3..f5b430bb6 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -27,16 +27,29 @@ data ServerStats = ServerStats qDeletedNew :: TVar Int, qDeletedSecured :: TVar Int, qSub :: TVar Int, + qSubNoMsg :: TVar Int, qSubAuth :: TVar Int, qSubDuplicate :: TVar Int, qSubProhibited :: TVar Int, + ntfCreated :: TVar Int, + ntfDeleted :: TVar Int, + ntfSub :: TVar Int, + ntfSubAuth :: TVar Int, + ntfSubDuplicate :: TVar Int, msgSent :: TVar Int, msgSentAuth :: TVar Int, msgSentQuota :: TVar Int, msgSentLarge :: TVar Int, msgRecv :: TVar Int, + msgRecvGet :: TVar Int, + msgGet :: TVar Int, + msgGetNoMsg :: TVar Int, + msgGetAuth :: TVar Int, + msgGetDuplicate :: TVar Int, + msgGetProhibited :: TVar Int, msgExpired :: TVar Int, activeQueues :: PeriodStats RecipientId, + subscribedQueues :: PeriodStats RecipientId, msgSentNtf :: TVar Int, -- sent messages with NTF flag msgRecvNtf :: TVar Int, -- received messages with NTF flag activeQueuesNtf :: PeriodStats RecipientId, @@ -60,16 +73,29 @@ data ServerStatsData = ServerStatsData _qDeletedNew :: Int, _qDeletedSecured :: Int, _qSub :: Int, + _qSubNoMsg :: Int, _qSubAuth :: Int, _qSubDuplicate :: Int, _qSubProhibited :: Int, + _ntfCreated :: Int, + _ntfDeleted :: Int, + _ntfSub :: Int, + _ntfSubAuth :: Int, + _ntfSubDuplicate :: Int, _msgSent :: Int, _msgSentAuth :: Int, _msgSentQuota :: Int, _msgSentLarge :: Int, _msgRecv :: Int, + _msgRecvGet :: Int, + _msgGet :: Int, + _msgGetNoMsg :: Int, + _msgGetAuth :: Int, + _msgGetDuplicate :: Int, + _msgGetProhibited :: Int, _msgExpired :: Int, _activeQueues :: PeriodStatsData RecipientId, + _subscribedQueues :: PeriodStatsData RecipientId, _msgSentNtf :: Int, _msgRecvNtf :: Int, _activeQueuesNtf :: PeriodStatsData RecipientId, @@ -86,38 +112,51 @@ data ServerStatsData = ServerStatsData } deriving (Show) -newServerStats :: UTCTime -> STM ServerStats +newServerStats :: UTCTime -> IO ServerStats newServerStats ts = do - fromTime <- newTVar ts - qCreated <- newTVar 0 - qSecured <- newTVar 0 - qDeletedAll <- newTVar 0 - qDeletedNew <- newTVar 0 - qDeletedSecured <- newTVar 0 - qSub <- newTVar 0 - qSubAuth <- newTVar 0 - qSubDuplicate <- newTVar 0 - qSubProhibited <- newTVar 0 - msgSent <- newTVar 0 - msgSentAuth <- newTVar 0 - msgSentQuota <- newTVar 0 - msgSentLarge <- newTVar 0 - msgRecv <- newTVar 0 - msgExpired <- newTVar 0 + fromTime <- newTVarIO ts + qCreated <- newTVarIO 0 + qSecured <- newTVarIO 0 + qDeletedAll <- newTVarIO 0 + qDeletedNew <- newTVarIO 0 + qDeletedSecured <- newTVarIO 0 + qSub <- newTVarIO 0 + qSubNoMsg <- newTVarIO 0 + qSubAuth <- newTVarIO 0 + qSubDuplicate <- newTVarIO 0 + qSubProhibited <- newTVarIO 0 + ntfCreated <- newTVarIO 0 + ntfDeleted <- newTVarIO 0 + ntfSub <- newTVarIO 0 + ntfSubAuth <- newTVarIO 0 + ntfSubDuplicate <- newTVarIO 0 + msgSent <- newTVarIO 0 + msgSentAuth <- newTVarIO 0 + msgSentQuota <- newTVarIO 0 + msgSentLarge <- newTVarIO 0 + msgRecv <- newTVarIO 0 + msgRecvGet <- newTVarIO 0 + msgGet <- newTVarIO 0 + msgGetNoMsg <- newTVarIO 0 + msgGetAuth <- newTVarIO 0 + msgGetDuplicate <- newTVarIO 0 + msgGetProhibited <- newTVarIO 0 + msgExpired <- newTVarIO 0 activeQueues <- newPeriodStats - msgSentNtf <- newTVar 0 - msgRecvNtf <- newTVar 0 + subscribedQueues <- newPeriodStats + msgSentNtf <- newTVarIO 0 + msgRecvNtf <- newTVarIO 0 activeQueuesNtf <- newPeriodStats - msgNtfs <- newTVar 0 - msgNtfNoSub <- newTVar 0 - msgNtfLost <- newTVar 0 + msgNtfs <- newTVarIO 0 + msgNtfNoSub <- newTVarIO 0 + msgNtfLost <- newTVarIO 0 pRelays <- newProxyStats pRelaysOwn <- newProxyStats pMsgFwds <- newProxyStats pMsgFwdsOwn <- newProxyStats - pMsgFwdsRecv <- newTVar 0 - qCount <- newTVar 0 - msgCount <- newTVar 0 + pMsgFwdsRecv <- newTVarIO 0 + qCount <- newTVarIO 0 + msgCount <- newTVarIO 0 pure ServerStats { fromTime, @@ -127,16 +166,29 @@ newServerStats ts = do qDeletedNew, qDeletedSecured, qSub, + qSubNoMsg, qSubAuth, qSubDuplicate, qSubProhibited, + ntfCreated, + ntfDeleted, + ntfSub, + ntfSubAuth, + ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, + msgRecvGet, + msgGet, + msgGetNoMsg, + msgGetAuth, + msgGetDuplicate, + msgGetProhibited, msgExpired, activeQueues, + subscribedQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, @@ -152,38 +204,51 @@ newServerStats ts = do msgCount } -getServerStatsData :: ServerStats -> STM ServerStatsData +getServerStatsData :: ServerStats -> IO ServerStatsData getServerStatsData s = do - _fromTime <- readTVar $ fromTime s - _qCreated <- readTVar $ qCreated s - _qSecured <- readTVar $ qSecured s - _qDeletedAll <- readTVar $ qDeletedAll s - _qDeletedNew <- readTVar $ qDeletedNew s - _qDeletedSecured <- readTVar $ qDeletedSecured s - _qSub <- readTVar $ qSub s - _qSubAuth <- readTVar $ qSubAuth s - _qSubDuplicate <- readTVar $ qSubDuplicate s - _qSubProhibited <- readTVar $ qSubProhibited s - _msgSent <- readTVar $ msgSent s - _msgSentAuth <- readTVar $ msgSentAuth s - _msgSentQuota <- readTVar $ msgSentQuota s - _msgSentLarge <- readTVar $ msgSentLarge s - _msgRecv <- readTVar $ msgRecv s - _msgExpired <- readTVar $ msgExpired s + _fromTime <- readTVarIO $ fromTime s + _qCreated <- readTVarIO $ qCreated s + _qSecured <- readTVarIO $ qSecured s + _qDeletedAll <- readTVarIO $ qDeletedAll s + _qDeletedNew <- readTVarIO $ qDeletedNew s + _qDeletedSecured <- readTVarIO $ qDeletedSecured s + _qSub <- readTVarIO $ qSub s + _qSubNoMsg <- readTVarIO $ qSubNoMsg s + _qSubAuth <- readTVarIO $ qSubAuth s + _qSubDuplicate <- readTVarIO $ qSubDuplicate s + _qSubProhibited <- readTVarIO $ qSubProhibited s + _ntfCreated <- readTVarIO $ ntfCreated s + _ntfDeleted <- readTVarIO $ ntfDeleted s + _ntfSub <- readTVarIO $ ntfSub s + _ntfSubAuth <- readTVarIO $ ntfSubAuth s + _ntfSubDuplicate <- readTVarIO $ ntfSubDuplicate s + _msgSent <- readTVarIO $ msgSent s + _msgSentAuth <- readTVarIO $ msgSentAuth s + _msgSentQuota <- readTVarIO $ msgSentQuota s + _msgSentLarge <- readTVarIO $ msgSentLarge s + _msgRecv <- readTVarIO $ msgRecv s + _msgRecvGet <- readTVarIO $ msgRecvGet s + _msgGet <- readTVarIO $ msgGet s + _msgGetNoMsg <- readTVarIO $ msgGetNoMsg s + _msgGetAuth <- readTVarIO $ msgGetAuth s + _msgGetDuplicate <- readTVarIO $ msgGetDuplicate s + _msgGetProhibited <- readTVarIO $ msgGetProhibited s + _msgExpired <- readTVarIO $ msgExpired s _activeQueues <- getPeriodStatsData $ activeQueues s - _msgSentNtf <- readTVar $ msgSentNtf s - _msgRecvNtf <- readTVar $ msgRecvNtf s + _subscribedQueues <- getPeriodStatsData $ subscribedQueues s + _msgSentNtf <- readTVarIO $ msgSentNtf s + _msgRecvNtf <- readTVarIO $ msgRecvNtf s _activeQueuesNtf <- getPeriodStatsData $ activeQueuesNtf s - _msgNtfs <- readTVar $ msgNtfs s - _msgNtfNoSub <- readTVar $ msgNtfNoSub s - _msgNtfLost <- readTVar $ msgNtfLost s + _msgNtfs <- readTVarIO $ msgNtfs s + _msgNtfNoSub <- readTVarIO $ msgNtfNoSub s + _msgNtfLost <- readTVarIO $ msgNtfLost s _pRelays <- getProxyStatsData $ pRelays s _pRelaysOwn <- getProxyStatsData $ pRelaysOwn s _pMsgFwds <- getProxyStatsData $ pMsgFwds s _pMsgFwdsOwn <- getProxyStatsData $ pMsgFwdsOwn s - _pMsgFwdsRecv <- readTVar $ pMsgFwdsRecv s - _qCount <- readTVar $ qCount s - _msgCount <- readTVar $ msgCount s + _pMsgFwdsRecv <- readTVarIO $ pMsgFwdsRecv s + _qCount <- readTVarIO $ qCount s + _msgCount <- readTVarIO $ msgCount s pure ServerStatsData { _fromTime, @@ -193,16 +258,29 @@ getServerStatsData s = do _qDeletedNew, _qDeletedSecured, _qSub, + _qSubNoMsg, _qSubAuth, _qSubDuplicate, _qSubProhibited, + _ntfCreated, + _ntfDeleted, + _ntfSub, + _ntfSubAuth, + _ntfSubDuplicate, _msgSent, _msgSentAuth, _msgSentQuota, _msgSentLarge, _msgRecv, + _msgRecvGet, + _msgGet, + _msgGetNoMsg, + _msgGetAuth, + _msgGetDuplicate, + _msgGetProhibited, _msgExpired, _activeQueues, + _subscribedQueues, _msgSentNtf, _msgRecvNtf, _activeQueuesNtf, @@ -227,16 +305,29 @@ setServerStats s d = do writeTVar (qDeletedNew s) $! _qDeletedNew d writeTVar (qDeletedSecured s) $! _qDeletedSecured d writeTVar (qSub s) $! _qSub d + writeTVar (qSubNoMsg s) $! _qSubNoMsg d writeTVar (qSubAuth s) $! _qSubAuth d writeTVar (qSubDuplicate s) $! _qSubDuplicate d writeTVar (qSubProhibited s) $! _qSubProhibited d + writeTVar (ntfCreated s) $! _ntfCreated d + writeTVar (ntfDeleted s) $! _ntfDeleted d + writeTVar (ntfSub s) $! _ntfSub d + writeTVar (ntfSubAuth s) $! _ntfSubAuth d + writeTVar (ntfSubDuplicate s) $! _ntfSubDuplicate d writeTVar (msgSent s) $! _msgSent d writeTVar (msgSentAuth s) $! _msgSentAuth d writeTVar (msgSentQuota s) $! _msgSentQuota d writeTVar (msgSentLarge s) $! _msgSentLarge d writeTVar (msgRecv s) $! _msgRecv d + writeTVar (msgRecvGet s) $! _msgRecvGet d + writeTVar (msgGet s) $! _msgGet d + writeTVar (msgGetNoMsg s) $! _msgGetNoMsg d + writeTVar (msgGetAuth s) $! _msgGetAuth d + writeTVar (msgGetDuplicate s) $! _msgGetDuplicate d + writeTVar (msgGetProhibited s) $! _msgGetProhibited d writeTVar (msgExpired s) $! _msgExpired d setPeriodStats (activeQueues s) (_activeQueues d) + setPeriodStats (subscribedQueues s) (_subscribedQueues d) writeTVar (msgSentNtf s) $! _msgSentNtf d writeTVar (msgRecvNtf s) $! _msgRecvNtf d setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d) @@ -262,14 +353,26 @@ instance StrEncoding ServerStatsData where "qDeletedSecured=" <> strEncode (_qDeletedSecured d), "qCount=" <> strEncode (_qCount d), "qSub=" <> strEncode (_qSub d), + "qSubNoMsg=" <> strEncode (_qSubNoMsg d), "qSubAuth=" <> strEncode (_qSubAuth d), "qSubDuplicate=" <> strEncode (_qSubDuplicate d), "qSubProhibited=" <> strEncode (_qSubProhibited d), + "ntfCreated=" <> strEncode (_ntfCreated d), + "ntfDeleted=" <> strEncode (_ntfDeleted d), + "ntfSub=" <> strEncode (_ntfSub d), + "ntfSubAuth=" <> strEncode (_ntfSubAuth d), + "ntfSubDuplicate=" <> strEncode (_ntfSubDuplicate d), "msgSent=" <> strEncode (_msgSent d), "msgSentAuth=" <> strEncode (_msgSentAuth d), "msgSentQuota=" <> strEncode (_msgSentQuota d), "msgSentLarge=" <> strEncode (_msgSentLarge d), "msgRecv=" <> strEncode (_msgRecv d), + "msgRecvGet=" <> strEncode (_msgRecvGet d), + "msgGet=" <> strEncode (_msgGet d), + "msgGetNoMsg=" <> strEncode (_msgGetNoMsg d), + "msgGetAuth=" <> strEncode (_msgGetAuth d), + "msgGetDuplicate=" <> strEncode (_msgGetDuplicate d), + "msgGetProhibited=" <> strEncode (_msgGetProhibited d), "msgExpired=" <> strEncode (_msgExpired d), "msgSentNtf=" <> strEncode (_msgSentNtf d), "msgRecvNtf=" <> strEncode (_msgRecvNtf d), @@ -278,6 +381,8 @@ instance StrEncoding ServerStatsData where "msgNtfLost=" <> strEncode (_msgNtfLost d), "activeQueues:", strEncode (_activeQueues d), + "subscribedQueues:", + strEncode (_subscribedQueues d), "activeQueuesNtf:", strEncode (_activeQueuesNtf d), "pRelays:", @@ -299,14 +404,26 @@ instance StrEncoding ServerStatsData where <|> ((,,) <$> ("qDeletedAll=" *> strP <* A.endOfLine) <*> ("qDeletedNew=" *> strP <* A.endOfLine) <*> ("qDeletedSecured=" *> strP <* A.endOfLine)) _qCount <- opt "qCount=" _qSub <- opt "qSub=" + _qSubNoMsg <- opt "qSubNoMsg=" _qSubAuth <- opt "qSubAuth=" _qSubDuplicate <- opt "qSubDuplicate=" _qSubProhibited <- opt "qSubProhibited=" + _ntfCreated <- opt "ntfCreated=" + _ntfDeleted <- opt "ntfDeleted=" + _ntfSub <- opt "ntfSub=" + _ntfSubAuth <- opt "ntfSubAuth=" + _ntfSubDuplicate <- opt "ntfSubDuplicate=" _msgSent <- "msgSent=" *> strP <* A.endOfLine _msgSentAuth <- opt "msgSentAuth=" _msgSentQuota <- opt "msgSentQuota=" _msgSentLarge <- opt "msgSentLarge=" _msgRecv <- "msgRecv=" *> strP <* A.endOfLine + _msgRecvGet <- opt "msgRecvGet=" + _msgGet <- opt "msgGet=" + _msgGetNoMsg <- opt "msgGetNoMsg=" + _msgGetAuth <- opt "msgGetAuth=" + _msgGetDuplicate <- opt "msgGetDuplicate=" + _msgGetProhibited <- opt "msgGetProhibited=" _msgExpired <- opt "msgExpired=" _msgSentNtf <- opt "msgSentNtf=" _msgRecvNtf <- opt "msgRecvNtf=" @@ -321,6 +438,10 @@ instance StrEncoding ServerStatsData where _week <- "weekMsgQueues=" *> strP <* A.endOfLine _month <- "monthMsgQueues=" *> strP <* optional A.endOfLine pure PeriodStatsData {_day, _week, _month} + _subscribedQueues <- + optional ("subscribedQueues:" <* A.endOfLine) >>= \case + Just _ -> strP <* optional A.endOfLine + _ -> pure newPeriodStatsData _activeQueuesNtf <- optional ("activeQueuesNtf:" <* A.endOfLine) >>= \case Just _ -> strP <* optional A.endOfLine @@ -339,14 +460,26 @@ instance StrEncoding ServerStatsData where _qDeletedNew, _qDeletedSecured, _qSub, + _qSubNoMsg, _qSubAuth, _qSubDuplicate, _qSubProhibited, + _ntfCreated, + _ntfDeleted, + _ntfSub, + _ntfSubAuth, + _ntfSubDuplicate, _msgSent, _msgSentAuth, _msgSentQuota, _msgSentLarge, _msgRecv, + _msgRecvGet, + _msgGet, + _msgGetNoMsg, + _msgGetAuth, + _msgGetDuplicate, + _msgGetProhibited, _msgExpired, _msgSentNtf, _msgRecvNtf, @@ -354,6 +487,7 @@ instance StrEncoding ServerStatsData where _msgNtfNoSub, _msgNtfLost, _activeQueues, + _subscribedQueues, _activeQueuesNtf, _pRelays, _pRelaysOwn, @@ -376,11 +510,11 @@ data PeriodStats a = PeriodStats month :: TVar (Set a) } -newPeriodStats :: STM (PeriodStats a) +newPeriodStats :: IO (PeriodStats a) newPeriodStats = do - day <- newTVar S.empty - week <- newTVar S.empty - month <- newTVar S.empty + day <- newTVarIO S.empty + week <- newTVarIO S.empty + month <- newTVarIO S.empty pure PeriodStats {day, week, month} data PeriodStatsData a = PeriodStatsData @@ -393,11 +527,11 @@ data PeriodStatsData a = PeriodStatsData newPeriodStatsData :: PeriodStatsData a newPeriodStatsData = PeriodStatsData {_day = S.empty, _week = S.empty, _month = S.empty} -getPeriodStatsData :: PeriodStats a -> STM (PeriodStatsData a) +getPeriodStatsData :: PeriodStats a -> IO (PeriodStatsData a) getPeriodStatsData s = do - _day <- readTVar $ day s - _week <- readTVar $ week s - _month <- readTVar $ month s + _day <- readTVarIO $ day s + _week <- readTVarIO $ week s + _month <- readTVarIO $ month s pure PeriodStatsData {_day, _week, _month} setPeriodStats :: PeriodStats a -> PeriodStatsData a -> STM () @@ -451,13 +585,13 @@ data ProxyStats = ProxyStats pErrorsOther :: TVar Int } -newProxyStats :: STM ProxyStats +newProxyStats :: IO ProxyStats newProxyStats = do - pRequests <- newTVar 0 - pSuccesses <- newTVar 0 - pErrorsConnect <- newTVar 0 - pErrorsCompat <- newTVar 0 - pErrorsOther <- newTVar 0 + pRequests <- newTVarIO 0 + pSuccesses <- newTVarIO 0 + pErrorsConnect <- newTVarIO 0 + pErrorsCompat <- newTVarIO 0 + pErrorsOther <- newTVarIO 0 pure ProxyStats {pRequests, pSuccesses, pErrorsConnect, pErrorsCompat, pErrorsOther} data ProxyStatsData = ProxyStatsData @@ -472,13 +606,13 @@ data ProxyStatsData = ProxyStatsData newProxyStatsData :: ProxyStatsData newProxyStatsData = ProxyStatsData {_pRequests = 0, _pSuccesses = 0, _pErrorsConnect = 0, _pErrorsCompat = 0, _pErrorsOther = 0} -getProxyStatsData :: ProxyStats -> STM ProxyStatsData +getProxyStatsData :: ProxyStats -> IO ProxyStatsData getProxyStatsData s = do - _pRequests <- readTVar $ pRequests s - _pSuccesses <- readTVar $ pSuccesses s - _pErrorsConnect <- readTVar $ pErrorsConnect s - _pErrorsCompat <- readTVar $ pErrorsCompat s - _pErrorsOther <- readTVar $ pErrorsOther s + _pRequests <- readTVarIO $ pRequests s + _pSuccesses <- readTVarIO $ pSuccesses s + _pErrorsConnect <- readTVarIO $ pErrorsConnect s + _pErrorsCompat <- readTVarIO $ pErrorsCompat s + _pErrorsOther <- readTVarIO $ pErrorsOther s pure ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} getResetProxyStatsData :: ProxyStats -> STM ProxyStatsData diff --git a/src/Simplex/Messaging/Session.hs b/src/Simplex/Messaging/Session.hs index 3ce5a35c8..45c182046 100644 --- a/src/Simplex/Messaging/Session.hs +++ b/src/Simplex/Messaging/Session.hs @@ -5,9 +5,6 @@ module Simplex.Messaging.Session where import Control.Concurrent.STM -import Control.Monad -import Data.Composition ((.:.)) -import Data.Functor (($>)) import Data.Time (UTCTime) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -31,14 +28,10 @@ getSessVar sessSeq sessKey vs sessionVarTs = maybe (Left <$> newSessionVar) (pur pure v removeSessVar :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM () -removeSessVar = void .:. removeSessVar' -{-# INLINE removeSessVar #-} - -removeSessVar' :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM Bool -removeSessVar' v sessKey vs = +removeSessVar v sessKey vs = TM.lookup sessKey vs >>= \case - Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs $> True - _ -> pure False + Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs + _ -> pure () tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a) tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar) diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index 2f6e0cf8a..1bc9bcb60 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -1,11 +1,13 @@ module Simplex.Messaging.TMap ( TMap, - empty, + emptyIO, singleton, clear, Simplex.Messaging.TMap.null, Simplex.Messaging.TMap.lookup, + lookupIO, member, + memberIO, insert, delete, lookupInsert, @@ -24,9 +26,9 @@ import qualified Data.Map.Strict as M type TMap k a = TVar (Map k a) -empty :: STM (TMap k a) -empty = newTVar M.empty -{-# INLINE empty #-} +emptyIO :: IO (TMap k a) +emptyIO = newTVarIO M.empty +{-# INLINE emptyIO #-} singleton :: k -> a -> STM (TMap k a) singleton k v = newTVar $ M.singleton k v @@ -44,10 +46,18 @@ lookup :: Ord k => k -> TMap k a -> STM (Maybe a) lookup k m = M.lookup k <$> readTVar m {-# INLINE lookup #-} +lookupIO :: Ord k => k -> TMap k a -> IO (Maybe a) +lookupIO k m = M.lookup k <$> readTVarIO m +{-# INLINE lookupIO #-} + member :: Ord k => k -> TMap k a -> STM Bool member k m = M.member k <$> readTVar m {-# INLINE member #-} +memberIO :: Ord k => k -> TMap k a -> IO Bool +memberIO k m = M.member k <$> readTVarIO m +{-# INLINE memberIO #-} + insert :: Ord k => k -> a -> TMap k a -> STM () insert k v m = modifyTVar' m $ M.insert k v {-# INLINE insert #-} diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index d7f81f563..58843b7f5 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -285,7 +285,7 @@ getTLS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> getTLS tlsPeer cfg tlsServerCerts cxt = withTlsUnique tlsPeer cxt newTLS where newTLS tlsUniq = do - tlsBuffer <- atomically newTBuffer + tlsBuffer <- newTBuffer tlsALPN <- T.getNegotiatedProtocol cxt pure TLS {tlsContext = cxt, tlsALPN, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer} diff --git a/src/Simplex/Messaging/Transport/Buffer.hs b/src/Simplex/Messaging/Transport/Buffer.hs index 6de9326f8..a612afafc 100644 --- a/src/Simplex/Messaging/Transport/Buffer.hs +++ b/src/Simplex/Messaging/Transport/Buffer.hs @@ -17,10 +17,10 @@ data TBuffer = TBuffer getLock :: TMVar () } -newTBuffer :: STM TBuffer +newTBuffer :: IO TBuffer newTBuffer = do - buffer <- newTVar "" - getLock <- newTMVar () + buffer <- newTVarIO "" + getLock <- newTMVarIO () pure TBuffer {buffer, getLock} withBufferLock :: TBuffer -> IO a -> IO a diff --git a/src/Simplex/Messaging/Transport/HTTP2.hs b/src/Simplex/Messaging/Transport/HTTP2.hs index 9c6cd7abc..3b741e6ce 100644 --- a/src/Simplex/Messaging/Transport/HTTP2.hs +++ b/src/Simplex/Messaging/Transport/HTTP2.hs @@ -75,7 +75,7 @@ instance HTTP2BodyChunk HS.Request where getHTTP2Body :: HTTP2BodyChunk a => a -> Int -> IO HTTP2Body getHTTP2Body r n = do - bodyBuffer <- atomically newTBuffer + bodyBuffer <- newTBuffer let getPart n' = getBuffered bodyBuffer n' Nothing $ getBodyChunk r bodyHead <- getPart n let bodySize = fromMaybe 0 $ getBodySize r diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index 71757ca6d..d8d3d495d 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -104,13 +104,13 @@ attachHTTP2Client config host port disconnected bufferSize tls = getVerifiedHTTP getVerifiedHTTP2ClientWith :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((TLS -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client) getVerifiedHTTP2ClientWith config host port disconnected setup = - (atomically mkHTTPS2Client >>= runClient) + (mkHTTPS2Client >>= runClient) `E.catch` \(e :: IOException) -> pure . Left $ HCIOError e where - mkHTTPS2Client :: STM HClient + mkHTTPS2Client :: IO HClient mkHTTPS2Client = do - connected <- newTVar False - reqQ <- newTBQueue $ qSize config + connected <- newTVarIO False + reqQ <- newTBQueueIO $ qSize config pure HClient {connected, disconnected, host, port, config, reqQ} runClient :: HClient -> IO (Either HTTP2ClientError HTTP2Client) diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index ffde39991..0b4da7833 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -76,7 +76,7 @@ serverTransportConfig TransportServerConfig {logTLSErrors} = -- All accepted connections are passed to the passed function. runTransportServer :: forall c. Transport c => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO () runTransportServer started port params cfg server = do - ss <- atomically newSocketState + ss <- newSocketState runTransportServerState ss started port params cfg server runTransportServerState :: forall c . Transport c => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO () @@ -85,7 +85,7 @@ runTransportServerState ss started port = runTransportServerSocketState ss start -- | Run a transport server with provided connection setup and handler. runTransportServerSocket :: Transport a => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO () runTransportServerSocket started getSocket threadLabel serverParams cfg server = do - ss <- atomically newSocketState + ss <- newSocketState runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server -- | Run a transport server with provided connection setup and handler. @@ -109,7 +109,7 @@ tlsServerCredentials serverParams = case T.sharedCredentials $ T.serverShared se -- | Run TCP server without TLS runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO () runTCPServer started port server = do - ss <- atomically newSocketState + ss <- newSocketState runTCPServerSocket ss started (startTCPServer started port) server -- | Wrap socket provider in a TCP server bracket. @@ -148,8 +148,8 @@ safeAccept sock = type SocketState = (TVar Int, TVar Int, TVar (IntMap (Weak ThreadId))) -newSocketState :: STM SocketState -newSocketState = (,,) <$> newTVar 0 <*> newTVar 0 <*> newTVar mempty +newSocketState :: IO SocketState +newSocketState = (,,) <$> newTVarIO 0 <*> newTVarIO 0 <*> newTVarIO mempty closeServer :: TMVar Bool -> TVar (IntMap (Weak ThreadId)) -> Socket -> IO () closeServer started clients sock = do diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 8684c787c..5d0a2c00a 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -225,23 +225,23 @@ connectionRequestTests = queueV1NoPort #== ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-1&dh=" <> url testDhKeyStr <> "&srv=jjbyvoemxysm7qxap7m5d5m35jzv5qq6gnlv7s4rsn7tdwwmuqciwpid.onion") queueV1NoPort #== ("smp://1234-w==@smp.simplex.im,jjbyvoemxysm7qxap7m5d5m35jzv5qq6gnlv7s4rsn7tdwwmuqciwpid.onion/3456-w==#" <> testDhKeyStr) it "should serialize and parse connection invitations and contact addresses" $ do - connectionRequest #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest #== ("https://simplex.chat/invitation#/?v=2-6&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestSK #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueStrSK <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest1 #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queue1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest2queues #==# ("simplex:/invitation#/?v=2-6&smp=" <> url (queueStr <> ";" <> queueStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestNew #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueNewStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestNew1 #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueNew1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest2queuesNew #==# ("simplex:/invitation#/?v=2-6&smp=" <> url (queueNewStr <> ";" <> queueNewStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest #== ("https://simplex.chat/invitation#/?v=2-7&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequestSK #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueStrSK <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest1 #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queue1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest2queues #==# ("simplex:/invitation#/?v=2-7&smp=" <> url (queueStr <> ";" <> queueStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequestNew #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueNewStr <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequestNew1 #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueNew1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest2queuesNew #==# ("simplex:/invitation#/?v=2-7&smp=" <> url (queueNewStr <> ";" <> queueNewStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) connectionRequestV1 #== ("https://simplex.chat/invitation#/?v=1&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestClientDataEmpty #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri <> "&data=" <> url "{}") - contactAddress #==# ("simplex:/contact#/?v=2-6&smp=" <> url queueStr) - contactAddress #== ("https://simplex.chat/contact#/?v=2-6&smp=" <> url queueStr) - contactAddress2queues #==# ("simplex:/contact#/?v=2-6&smp=" <> url (queueStr <> ";" <> queueStr)) - contactAddressNew #==# ("simplex:/contact#/?v=2-6&smp=" <> url queueNewStr) - contactAddress2queuesNew #==# ("simplex:/contact#/?v=2-6&smp=" <> url (queueNewStr <> ";" <> queueNewStr)) + connectionRequestClientDataEmpty #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri <> "&data=" <> url "{}") + contactAddress #==# ("simplex:/contact#/?v=2-7&smp=" <> url queueStr) + contactAddress #== ("https://simplex.chat/contact#/?v=2-7&smp=" <> url queueStr) + contactAddress2queues #==# ("simplex:/contact#/?v=2-7&smp=" <> url (queueStr <> ";" <> queueStr)) + contactAddressNew #==# ("simplex:/contact#/?v=2-7&smp=" <> url queueNewStr) + contactAddress2queuesNew #==# ("simplex:/contact#/?v=2-7&smp=" <> url (queueNewStr <> ";" <> queueNewStr)) contactAddressV2 #==# ("simplex:/contact#/?v=2&smp=" <> url queueStr) contactAddressV2 #== ("https://simplex.chat/contact#/?v=1&smp=" <> url queueStr) -- adjusted to v2 contactAddressV2 #== ("https://simplex.chat/contact#/?v=1-2&smp=" <> url queueStr) -- adjusted to v2 contactAddressV2 #== ("https://simplex.chat/contact#/?v=2-2&smp=" <> url queueStr) - contactAddressClientData #==# ("simplex:/contact#/?v=2-6&smp=" <> url queueStr <> "&data=" <> url "{\"type\":\"group_link\", \"group_link_id\":\"abc\"}") + contactAddressClientData #==# ("simplex:/contact#/?v=2-7&smp=" <> url queueStr <> "&data=" <> url "{\"type\":\"group_link\", \"group_link_id\":\"abc\"}") diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 19f4977fc..4d61d8463 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -38,6 +38,7 @@ module AgentTests.FunctionalAPITests rfGet, sfGet, nGet, + getInAnyOrder, (##>), (=##>), pattern CON, @@ -78,7 +79,7 @@ import SMPAgentClient import SMPClient (cfg, prevRange, prevVersion, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerProxy, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) import qualified Simplex.Messaging.Agent as A -import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), UserNetworkInfo (..), UserNetworkType (..), waitForUserNetwork) +import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), ServerQueueInfo (..), UserNetworkInfo (..), UserNetworkType (..), waitForUserNetwork) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore) import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT) import qualified Simplex.Messaging.Agent.Protocol as A @@ -244,7 +245,7 @@ inAnyOrder g rs = withFrozenCallStack $ do createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c) createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn) -joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE ConnId +joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE (ConnId, SndQueueSecured) joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId Nothing enableNtfs cReq connInfo PQSupportOn sendMessage :: AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> AE AgentMsgId @@ -257,20 +258,25 @@ functionalAPITests :: ATransport -> Spec functionalAPITests t = do describe "Establishing duplex connection" $ do testMatrix2 t runAgentClientTest - xit "should connect when server with multiple identities is stored" $ + it "should connect when server with multiple identities is stored" $ withSmpServer t testServerMultipleIdentities - xit "should connect with two peers" $ + it "should connect with two peers" $ withSmpServer t testAgentClient3 - xit "should establish connection without PQ encryption and enable it" $ + it "should establish connection without PQ encryption and enable it" $ withSmpServer t testEnablePQEncryption + describe "Duplex connection - delivery stress test" $ do + describe "one way (50)" $ testMatrix2Stress t $ runAgentClientStressTestOneWay 50 + xdescribe "one way (1000)" $ testMatrix2Stress t $ runAgentClientStressTestOneWay 1000 + describe "two way concurrently (50)" $ testMatrix2Stress t $ runAgentClientStressTestConc 25 + xdescribe "two way concurrently (1000)" $ testMatrix2Stress t $ runAgentClientStressTestConc 500 describe "Establishing duplex connection, different PQ settings" $ do - testPQMatrix2 t $ runAgentClientTestPQ True + testPQMatrix2 t $ runAgentClientTestPQ False True describe "Establishing duplex connection v2, different Ratchet versions" $ testRatchetMatrix2 t runAgentClientTest describe "Establish duplex connection via contact address" $ testMatrix2 t runAgentClientContactTest describe "Establish duplex connection via contact address, different PQ settings" $ do - testPQMatrix2NoInv t $ runAgentClientContactTestPQ True PQSupportOn + testPQMatrix2NoInv t $ runAgentClientContactTestPQ False True PQSupportOn describe "Establish duplex connection via contact address v2, different Ratchet versions" $ testRatchetMatrix2 t runAgentClientContactTest describe "Establish duplex connection via contact address, different PQ settings" $ do @@ -290,43 +296,43 @@ functionalAPITests t = do testAllowConnectionClientRestart t describe "Message delivery" $ do describe "update connection agent version on received messages" $ do - xit "should increase if compatible, shouldn't decrease" $ + it "should increase if compatible, shouldn't decrease" $ testIncreaseConnAgentVersion t - xit "should increase to max compatible version" $ + it "should increase to max compatible version" $ testIncreaseConnAgentVersionMaxCompatible t - xit "should increase when connection was negotiated on different versions" $ + it "should increase when connection was negotiated on different versions" $ testIncreaseConnAgentVersionStartDifferentVersion t -- TODO PQ tests for upgrading connection to PQ encryption - xit "should deliver message after client restart" $ + it "should deliver message after client restart" $ testDeliverClientRestart t - xit "should deliver messages to the user once, even if repeat delivery is made by the server (no ACK)" $ + it "should deliver messages to the user once, even if repeat delivery is made by the server (no ACK)" $ testDuplicateMessage t - xit "should report error via msg integrity on skipped messages" $ + it "should report error via msg integrity on skipped messages" $ testSkippedMessages t - xit "should connect to the server when server goes up if it initially was down" $ + it "should connect to the server when server goes up if it initially was down" $ testDeliveryAfterSubscriptionError t - xit "should deliver messages if one of connections has quota exceeded" $ + it "should deliver messages if one of connections has quota exceeded" $ testMsgDeliveryQuotaExceeded t describe "message expiration" $ do - xit "should expire one message" $ testExpireMessage t - xit "should expire multiple messages" $ testExpireManyMessages t - xit "should expire one message if quota is exceeded" $ testExpireMessageQuota t - xit "should expire multiple messages if quota is exceeded" $ testExpireManyMessagesQuota t + it "should expire one message" $ testExpireMessage t + it "should expire multiple messages" $ testExpireManyMessages t + it "should expire one message if quota is exceeded" $ testExpireMessageQuota t + it "should expire multiple messages if quota is exceeded" $ testExpireManyMessagesQuota t describe "Ratchet synchronization" $ do - xit "should report ratchet de-synchronization, synchronize ratchets" $ + it "should report ratchet de-synchronization, synchronize ratchets" $ testRatchetSync t - xit "should synchronize ratchets after server being offline" $ + it "should synchronize ratchets after server being offline" $ testRatchetSyncServerOffline t - xit "should synchronize ratchets after client restart" $ + it "should synchronize ratchets after client restart" $ testRatchetSyncClientRestart t - xit "should synchronize ratchets after suspend/foreground" $ + it "should synchronize ratchets after suspend/foreground" $ testRatchetSyncSuspendForeground t - xit "should synchronize ratchets when clients start synchronization simultaneously" $ + it "should synchronize ratchets when clients start synchronization simultaneously" $ testRatchetSyncSimultaneous t describe "Subscription mode OnlyCreate" $ do it "messages delivered only when polled (v8 - slow handshake)" $ withSmpServer t testOnlyCreatePullSlowHandshake - xit "messages delivered only when polled" $ + it "messages delivered only when polled" $ withSmpServer t testOnlyCreatePull describe "Inactive client disconnection" $ do it "should disconnect clients without subs if they were inactive longer than TTL" $ @@ -336,14 +342,14 @@ functionalAPITests t = do it "should NOT disconnect active clients" $ testActiveClientNotDisconnected t describe "Suspending agent" $ do - xit "should update client when agent is suspended" $ + it "should update client when agent is suspended" $ withSmpServer t testSuspendingAgent - xit "should complete sending messages when agent is suspended" $ + it "should complete sending messages when agent is suspended" $ testSuspendingAgentCompleteSending t - xit "should suspend agent on timeout, even if pending messages not sent" $ + it "should suspend agent on timeout, even if pending messages not sent" $ testSuspendingAgentTimeout t describe "Batching SMP commands" $ do - xit "should subscribe to multiple (200) subscriptions with batching" $ + it "should subscribe to multiple (200) subscriptions with batching" $ testBatchedSubscriptions 200 10 t skip "faster version of the previous test (200 subscriptions gets very slow with test coverage)" $ it "should subscribe to multiple (6) subscriptions with batching" $ @@ -351,6 +357,9 @@ functionalAPITests t = do it "should subscribe to multiple connections with pending messages" $ withSmpServer t $ testBatchedPendingMessages 10 5 + describe "Batch send messages" $ do + it "should send multiple messages to the same connection" $ withSmpServer t testSendMessagesB + it "should send messages to the 2 connections" $ withSmpServer t testSendMessagesB2 describe "Async agent commands" $ do describe "connect using async agent commands" $ testBasicMatrix2 t testAsyncCommands @@ -362,7 +371,7 @@ functionalAPITests t = do testDeleteConnectionAsync t it "join connection when reply queue creation fails (v8 - slow handshake)" $ testJoinConnectionAsyncReplyErrorV8 t - xit "join connection when reply queue creation fails" $ + it "join connection when reply queue creation fails" $ testJoinConnectionAsyncReplyError t describe "delete connection waiting for delivery" $ do it "should delete connection immediately if there are no pending messages" $ @@ -376,58 +385,59 @@ functionalAPITests t = do it "should delete connection by timeout, message in progress can be delivered" $ testWaitDeliveryTimeout2 t describe "Users" $ do - xit "should create and delete user with connections" $ + it "should create and delete user with connections" $ withSmpServer t testUsers - xit "should create and delete user without connections" $ + it "should create and delete user without connections" $ withSmpServer t testDeleteUserQuietly - xit "should create and delete user with connections when server connection fails" $ + it "should create and delete user with connections when server connection fails" $ testUsersNoServer t - xit "should connect two users and switch session mode" $ + it "should connect two users and switch session mode" $ withSmpServer t testTwoUsers describe "Connection switch" $ do - xdescribe "should switch delivery to the new queue" $ + describe "should switch delivery to the new queue" $ testServerMatrix2 t testSwitchConnection - xdescribe "should switch to new queue asynchronously" $ + describe "should switch to new queue asynchronously" $ testServerMatrix2 t testSwitchAsync describe "should delete connection during switch" $ testServerMatrix2 t testSwitchDelete - xdescribe "should abort switch in Started phase" $ + describe "should abort switch in Started phase" $ testServerMatrix2 t testAbortSwitchStarted - xdescribe "should abort switch in Started phase, reinitiate immediately" $ + describe "should abort switch in Started phase, reinitiate immediately" $ testServerMatrix2 t testAbortSwitchStartedReinitiate - xdescribe "should prohibit to abort switch in Secured phase" $ + describe "should prohibit to abort switch in Secured phase" $ testServerMatrix2 t testCannotAbortSwitchSecured - xdescribe "should switch two connections simultaneously" $ + describe "should switch two connections simultaneously" $ testServerMatrix2 t testSwitch2Connections - xdescribe "should switch two connections simultaneously, abort one" $ + describe "should switch two connections simultaneously, abort one" $ testServerMatrix2 t testSwitch2ConnectionsAbort1 describe "SMP basic auth" $ do let v4 = prevVersion basicAuthSMPVersion - forM_ (nub [prevVersion authCmdsSMPVersion, authCmdsSMPVersion, prevVersion currentServerSMPRelayVersion]) $ \v -> do + forM_ (nub [prevVersion authCmdsSMPVersion, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do let baseId = if v >= sndAuthKeySMPVersion then 1 else 3 + sqSecured = if v >= sndAuthKeySMPVersion then True else False describe ("v" <> show v <> ": with server auth") $ do -- allow NEW | server auth, v | clnt1 auth, v | clnt2 auth, v | 2 - success, 1 - JOIN fail, 0 - NEW fail - it "success " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 2 - it "disabled " $ testBasicAuth t False (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 0 - it "NEW fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Nothing, v) (Just "abcd", v) baseId `shouldReturn` 0 - it "NEW fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "wrong", v) (Just "abcd", v) baseId `shouldReturn` 0 - it "NEW fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v4) (Just "abcd", v) baseId `shouldReturn` 0 - it "JOIN fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Nothing, v) baseId `shouldReturn` 1 - it "JOIN fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "wrong", v) baseId `shouldReturn` 1 - it "JOIN fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v4) baseId `shouldReturn` 1 + it "success " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 2 + it "disabled " $ testBasicAuth t False (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "NEW fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Nothing, v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "NEW fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "wrong", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "NEW fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v4) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "JOIN fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Nothing, v) sqSecured baseId `shouldReturn` 1 + it "JOIN fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "wrong", v) sqSecured baseId `shouldReturn` 1 + it "JOIN fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v4) sqSecured baseId `shouldReturn` 1 describe ("v" <> show v <> ": no server auth") $ do - it "success " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v) baseId `shouldReturn` 2 - it "srv disabled" $ testBasicAuth t False (Nothing, v) (Nothing, v) (Nothing, v) baseId `shouldReturn` 0 - it "version srv " $ testBasicAuth t True (Nothing, v4) (Nothing, v) (Nothing, v) 3 `shouldReturn` 2 - it "version fst " $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v) baseId `shouldReturn` 2 - it "version snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v4) 3 `shouldReturn` 2 - it "version both" $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v4) 3 `shouldReturn` 2 - it "version all " $ testBasicAuth t True (Nothing, v4) (Nothing, v4) (Nothing, v4) 3 `shouldReturn` 2 - it "auth fst " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Nothing, v) baseId `shouldReturn` 2 - it "auth fst 2 " $ testBasicAuth t True (Nothing, v4) (Just "abcd", v) (Nothing, v) 3 `shouldReturn` 2 - it "auth snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Just "abcd", v) baseId `shouldReturn` 2 - it "auth both " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 2 - it "auth, disabled" $ testBasicAuth t False (Nothing, v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 0 + it "success " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v) sqSecured baseId `shouldReturn` 2 + it "srv disabled" $ testBasicAuth t False (Nothing, v) (Nothing, v) (Nothing, v) sqSecured baseId `shouldReturn` 0 + it "version srv " $ testBasicAuth t True (Nothing, v4) (Nothing, v) (Nothing, v) False 3 `shouldReturn` 2 + it "version fst " $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v) False baseId `shouldReturn` 2 + it "version snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v4) sqSecured 3 `shouldReturn` 2 + it "version both" $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v4) False 3 `shouldReturn` 2 + it "version all " $ testBasicAuth t True (Nothing, v4) (Nothing, v4) (Nothing, v4) False 3 `shouldReturn` 2 + it "auth fst " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Nothing, v) sqSecured baseId `shouldReturn` 2 + it "auth fst 2 " $ testBasicAuth t True (Nothing, v4) (Just "abcd", v) (Nothing, v) False 3 `shouldReturn` 2 + it "auth snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Just "abcd", v) sqSecured baseId `shouldReturn` 2 + it "auth both " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 2 + it "auth, disabled" $ testBasicAuth t False (Nothing, v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 describe "SMP server test via agent API" $ do it "should pass without basic auth" $ testSMPServerConnectionTest t Nothing (noAuthSrv testSMPServer2) `shouldReturn` Nothing let srv1 = testSMPServer2 {keyHash = "1234"} @@ -444,19 +454,19 @@ functionalAPITests t = do it "should return the same data for both peers" $ withSmpServer t testRatchetAdHash describe "Delivery receipts" $ do - xit "should send and receive delivery receipt" $ withSmpServer t testDeliveryReceipts - xit "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t + it "should send and receive delivery receipt" $ withSmpServer t testDeliveryReceipts + it "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t it "send delivery receipts concurrently with messages" $ testDeliveryReceiptsConcurrent t describe "user network info" $ do it "should wait for user network" testWaitForUserNetwork it "should not reset online to offline if happens too quickly" testDoNotResetOnlineToOffline it "should resume multiple threads" testResumeMultipleThreads describe "SMP queue info" $ do - xit "server should respond with queue and subscription information" $ + it "server should respond with queue and subscription information" $ withSmpServer t testServerQueueInfo -testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> AgentMsgId -> IO Int -testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 baseId = do +testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> SndQueueSecured -> AgentMsgId -> IO Int +testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 sqSecured baseId = do let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = V.mkVersionRange batchCmdsSMPVersion srvVersion} canCreate1 = canCreateQueue allowNewQueues srv clnt1 canCreate2 = canCreateQueue allowNewQueues srv clnt2 @@ -464,7 +474,7 @@ testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 baseId = do | canCreate1 && canCreate2 = 2 | canCreate1 = 1 | otherwise = 0 - created <- withSmpServerConfigOn t testCfg testPort $ \_ -> testCreateQueueAuth srvVersion clnt1 clnt2 baseId + created <- withSmpServerConfigOn t testCfg testPort $ \_ -> testCreateQueueAuth srvVersion clnt1 clnt2 sqSecured baseId created `shouldBe` expected pure created @@ -473,30 +483,43 @@ canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = let v = basicAuthSMPVersion in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) -testMatrix2 :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testMatrix2 :: HasCallStack => ATransport -> (PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2 t runTest = do - xit "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True - it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True - xit "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn False - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQSupportOff False - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff False - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff False + it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True True + it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn False True + it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn True False + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQSupportOff False False + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff False False + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff False False -testBasicMatrix2 :: HasCallStack => ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testMatrix2Stress :: HasCallStack => ATransport -> (PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testMatrix2Stress t runTest = do + it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aCfg aCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True True + it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aProxyCfgV8 aProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn False True + it "current" $ withSmpServer t $ runTestCfg2 aCfg aCfg 1 $ runTest PQSupportOn True False + it "prev" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfgVPrev 3 $ runTest PQSupportOff False False + it "prev to current" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfg 3 $ runTest PQSupportOff False False + it "current to prev" $ withSmpServer t $ runTestCfg2 aCfg aCfgVPrev 3 $ runTest PQSupportOff False False + where + aCfg = agentCfg {messageRetryInterval = fastMessageRetryInterval} + aProxyCfgV8 = agentProxyCfgV8 {messageRetryInterval = fastMessageRetryInterval} + aCfgVPrev = agentCfgVPrev {messageRetryInterval = fastMessageRetryInterval} + +testBasicMatrix2 :: HasCallStack => ATransport -> (SndQueueSecured -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testBasicMatrix2 t runTest = do - xit "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfgVPrevPQ 3 $ runTest - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfg 3 $ runTest - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrevPQ 3 $ runTest + it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest True + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfgVPrevPQ 3 $ runTest False + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfg 3 $ runTest False + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrevPQ 3 $ runTest False -testRatchetMatrix2 :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testRatchetMatrix2 :: HasCallStack => ATransport -> (PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testRatchetMatrix2 t runTest = do - xit "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True - it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True - xit "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn False - xit "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 1 $ runTest PQSupportOff False - xit "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 1 $ runTest PQSupportOff False - xit "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 1 $ runTest PQSupportOff False + it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True True + it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn False True + it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn True False + it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 1 $ runTest PQSupportOff True False + it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 1 $ runTest PQSupportOff True False + it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 1 $ runTest PQSupportOff True False testServerMatrix2 :: HasCallStack => ATransport -> (InitialAgentServers -> IO ()) -> Spec testServerMatrix2 t runTest = do @@ -571,15 +594,16 @@ withAgentClients3 runTest = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> runTest a b c -runAgentClientTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientTest pqSupport viaProxy alice bob baseId = - runAgentClientTestPQ viaProxy (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId +runAgentClientTest :: HasCallStack => PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientTest pqSupport sqSecured viaProxy alice bob baseId = + runAgentClientTestPQ sqSecured viaProxy (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId -runAgentClientTestPQ :: HasCallStack => Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () -runAgentClientTestPQ viaProxy (alice, aPQ) (bob, bPQ) baseId = +runAgentClientTestPQ :: HasCallStack => SndQueueSecured -> Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () +runAgentClientTestPQ sqSecured viaProxy (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing aPQ SMSubscribe - aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" bPQ SMSubscribe + (aliceId, sqSecured') <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" bPQ SMSubscribe + liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` CR.connPQEncryption aPQ allowConnection alice bobId confId "alice's connInfo" @@ -616,11 +640,76 @@ runAgentClientTestPQ viaProxy (alice, aPQ) (bob, bPQ) baseId = pqConnectionMode :: InitialKeys -> PQSupport -> Bool pqConnectionMode pqMode1 pqMode2 = supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2 +runAgentClientStressTestOneWay :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientStressTestOneWay n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do + let pqEnc = PQEncryption $ supportPQ pqSupport + (aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob + let proxySrv = if viaProxy then Just testSMPServer else Nothing + message i = "message " <> bshow i + concurrently_ + ( forM_ ([1 .. n] :: [Int64]) $ \i -> do + mId <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags (message i) + liftIO $ do + mId >= i `shouldBe` True + let getEvent = + get alice >>= \case + ("", c, A.SENT mId' srv) -> c == bobId && mId' >= baseId + i && srv == proxySrv `shouldBe` True + ("", c, QCONT) -> do + c == bobId `shouldBe` True + getEvent + r -> expectationFailure $ "wrong message: " <> show r + getEvent + ) + ( forM_ ([1 .. n] :: [Int64]) $ \i -> do + get bob >>= \case + ("", c, Msg' mId pq msg) -> do + liftIO $ c == aliceId && mId >= baseId + i && pq == pqEnc && msg == message i `shouldBe` True + ackMessage bob aliceId mId Nothing + r -> liftIO $ expectationFailure $ "wrong message: " <> show r + ) + liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice" + liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob" + where + msgId = subtract baseId . fst + +runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do + let pqEnc = PQEncryption $ supportPQ pqSupport + (aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob + let proxySrv = if viaProxy then Just testSMPServer else Nothing + message i = "message " <> bshow i + loop a bId mIdVar i = do + when (i <= n) $ do + mId <- msgId <$> A.sendMessage a bId pqEnc SMP.noMsgFlags (message i) + liftIO $ mId >= i `shouldBe` True + let getEvent = do + get a >>= \case + ("", c, A.SENT _ srv) -> liftIO $ c == bId && srv == proxySrv `shouldBe` True + ("", c, QCONT) -> do + liftIO $ c == bId `shouldBe` True + getEvent + ("", c, Msg' mId pq msg) -> do + -- tests that mId increases + liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True + liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True + ackMessage a bId mId Nothing + r -> liftIO $ expectationFailure $ "wrong message: " <> show r + getEvent + amId <- newTVarIO 0 + bmId <- newTVarIO 0 + concurrently_ + (forM_ ([1 .. n * 2] :: [Int64]) $ loop alice bobId amId) + (forM_ ([1 .. n * 2] :: [Int64]) $ loop bob aliceId bmId) + liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice" + liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob" + where + msgId = subtract baseId . fst + testEnablePQEncryption :: HasCallStack => IO () testEnablePQEncryption = withAgentClients2 $ \ca cb -> runRight_ $ do g <- liftIO C.newRandom - (aId, bId) <- makeConnection_ PQSupportOff ca cb + (aId, bId) <- makeConnection_ PQSupportOff True ca cb let a = (ca, aId) b = (cb, bId) (a, 2, "msg 1") \#>\ b @@ -706,20 +795,23 @@ testAgentClient3 = get c =##> \case ("", connId, Msg "c5") -> connId == aIdForC; _ -> False ackMessage c aIdForC 3 Nothing -runAgentClientContactTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientContactTest pqSupport viaProxy alice bob baseId = - runAgentClientContactTestPQ viaProxy pqSupport (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId +runAgentClientContactTest :: HasCallStack => PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientContactTest pqSupport sqSecured viaProxy alice bob baseId = + runAgentClientContactTestPQ sqSecured viaProxy pqSupport (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId -runAgentClientContactTestPQ :: HasCallStack => Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () -runAgentClientContactTestPQ viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId = +runAgentClientContactTestPQ :: HasCallStack => SndQueueSecured -> Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () +runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ - aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" bPQ SMSubscribe - liftIO $ aliceId' `shouldBe` aliceId + (aliceId', sqSecuredJoin) <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" bPQ SMSubscribe + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` reqPQSupport - bobId <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + (bobId, sqSecured') <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get bob liftIO $ pqSup'' `shouldBe` bPQ allowConnection bob aliceId confId "bob's connInfo" @@ -764,11 +856,14 @@ runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId msgId = subtract baseId . fst connectViaContact b pq qInfo = do aId <- A.prepareConnectionToJoin b 1 True qInfo pq - aId' <- A.joinConnection b 1 (Just aId) True qInfo "bob's connInfo" pq SMSubscribe - liftIO $ aId' `shouldBe` aId + (aId', sqSecuredJoin) <- A.joinConnection b 1 (Just aId) True qInfo "bob's connInfo" pq SMSubscribe + liftIO $ do + aId' `shouldBe` aId + sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn - bId <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + (bId, sqSecuredAccept) <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + liftIO $ sqSecuredAccept `shouldBe` False -- agent cfg is v8 ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get b liftIO $ pqSup'' `shouldBe` pq allowConnection b aId confId "bob's connInfo" @@ -789,10 +884,17 @@ runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId ackMessage a bId (baseId + 2) Nothing noMessages :: HasCallStack => AgentClient -> String -> Expectation -noMessages c err = tryGet `shouldReturn` () +noMessages = noMessages_ False + +noMessagesIngoreQCONT :: AgentClient -> String -> Expectation +noMessagesIngoreQCONT = noMessages_ True + +noMessages_ :: Bool -> HasCallStack => AgentClient -> String -> Expectation +noMessages_ ingoreQCONT c err = tryGet `shouldReturn` () where tryGet = 10000 `timeout` get c >>= \case + Just (_, _, QCONT) | ingoreQCONT -> noMessages_ ingoreQCONT c err Just msg -> error $ err <> ": " <> show msg _ -> return () @@ -801,8 +903,10 @@ testRejectContactRequest = withAgentClients2 $ \alice bob -> runRight_ $ do (addrConnId, qInfo) <- A.createConnection alice 1 True SCMContact Nothing IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" PQSupportOn SMSubscribe - liftIO $ aliceId' `shouldBe` aliceId + (aliceId', sqSecured) <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecured `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId PQSupportOn _ "bob's connInfo") <- get alice liftIO $ runExceptT (rejectContact alice "abcd" invId) `shouldReturn` Left (CONN NOT_FOUND) rejectContact alice addrConnId invId @@ -814,15 +918,34 @@ testAsyncInitiatingOffline = alice <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ disposeAgentClient alice - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True + + -- send messages + msgId1 <- A.sendMessage bob aliceId PQEncOn SMP.noMsgFlags "can send 1" + liftIO $ msgId1 `shouldBe` (2, PQEncOff) + get bob ##> ("", aliceId, SENT 2) + msgId2 <- A.sendMessage bob aliceId PQEncOn SMP.noMsgFlags "can send 2" + liftIO $ msgId2 `shouldBe` (3, PQEncOff) + get bob ##> ("", aliceId, SENT 3) + alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB subscribeConnection alice' bobId ("", _, CONF confId _ "bob's connInfo") <- get alice' + -- receive messages + get alice' =##> \case ("", c, Msg' mId pq "can send 1") -> c == bobId && mId == 1 && pq == PQEncOff; _ -> False + ackMessage alice' bobId 1 Nothing + get alice' =##> \case ("", c, Msg' mId pq "can send 2") -> c == bobId && mId == 2 && pq == PQEncOff; _ -> False + ackMessage alice' bobId 2 Nothing + -- for alice msg id 3 is sent confirmation, then they're matched with bob at msg id 4 + + -- allow connection allowConnection alice' bobId confId "alice's connInfo" get alice' ##> ("", bobId, CON) get bob ##> ("", aliceId, INFO "alice's connInfo") get bob ##> ("", aliceId, CON) - exchangeGreetings alice' bobId bob aliceId + exchangeGreetingsMsgId 4 alice' bobId bob aliceId liftIO $ disposeAgentClient alice' testAsyncJoiningOfflineBeforeActivation :: HasCallStack => IO () @@ -830,7 +953,8 @@ testAsyncJoiningOfflineBeforeActivation = withAgent 1 agentCfg initAgentServers testDB $ \alice -> runRight_ $ do bob <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True liftIO $ disposeAgentClient bob ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" @@ -849,7 +973,8 @@ testAsyncBothOffline = do runRight_ $ do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ disposeAgentClient alice - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True liftIO $ disposeAgentClient bob alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB subscribeConnection alice' bobId @@ -880,7 +1005,8 @@ testAsyncServerOffline t = withAgentClients2 $ \alice bob -> do liftIO $ do srv1 `shouldBe` testSMPServer conns1 `shouldBe` [bobId] - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -890,7 +1016,7 @@ testAsyncServerOffline t = withAgentClients2 $ \alice bob -> do testAllowConnectionClientRestart :: HasCallStack => ATransport -> IO () testAllowConnectionClientRestart t = do - let initAgentServersSrv2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer2]} + let initAgentServersSrv2 = initAgentServers {smp = userServers [testSMPServer2]} alice <- getSMPAgentClient' 1 agentCfg initAgentServers testDB bob <- getSMPAgentClient' 2 agentCfg initAgentServersSrv2 testDB2 withSmpServerStoreLogOn t testPort $ \_ -> do @@ -898,7 +1024,8 @@ testAllowConnectionClientRestart t = do withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \_ -> do runRight $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice pure (aliceId, bobId, confId) @@ -911,6 +1038,7 @@ testAllowConnectionClientRestart t = do threadDelay 100000 -- give time to enqueue confirmation (enqueueConfirmation) disposeAgentClient alice + threadDelay 250000 alice2 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB @@ -933,7 +1061,7 @@ testIncreaseConnAgentVersion t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff False alice bob exchangeGreetingsMsgId_ PQEncOff 2 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -998,7 +1126,7 @@ testIncreaseConnAgentVersionMaxCompatible t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff False alice bob exchangeGreetingsMsgId_ PQEncOff 2 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -1028,7 +1156,7 @@ testIncreaseConnAgentVersionStartDifferentVersion t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff False alice bob exchangeGreetingsMsgId_ PQEncOff 2 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -1037,6 +1165,7 @@ testIncreaseConnAgentVersionStartDifferentVersion t = do -- version increases to max compatible disposeAgentClient alice + threadDelay 250000 alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB runRight_ $ do @@ -1192,7 +1321,6 @@ testDeliveryAfterSubscriptionError t = do withAgentClients2 $ \a b -> do Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection a bId Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection b aId - pure () withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do withUP a bId $ \case ("", c, SENT 2) -> c == bId; _ -> False withUP b aId $ \case ("", c, Msg "hello") -> c == aId; _ -> False @@ -1228,13 +1356,13 @@ testMsgDeliveryQuotaExceeded t = testExpireMessage :: HasCallStack => ATransport -> IO () testExpireMessage t = - withAgent 1 agentCfg {messageTimeout = 1, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a -> + withAgent 1 agentCfg {messageTimeout = 1.5, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a -> withAgent 2 agentCfg initAgentServers testDB2 $ \b -> do (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ makeConnection a b nGet a =##> \case ("", "", DOWN _ [c]) -> c == bId; _ -> False nGet b =##> \case ("", "", DOWN _ [c]) -> c == aId; _ -> False 2 <- runRight $ sendMessage a bId SMP.noMsgFlags "1" - threadDelay 1000000 + threadDelay 1500000 3 <- runRight $ sendMessage a bId SMP.noMsgFlags "2" -- this won't expire get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do @@ -1244,7 +1372,7 @@ testExpireMessage t = testExpireManyMessages :: HasCallStack => ATransport -> IO () testExpireManyMessages t = - withAgent 1 agentCfg {messageTimeout = 1, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a -> + withAgent 1 agentCfg {messageTimeout = 2, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a -> withAgent 2 agentCfg initAgentServers testDB2 $ \b -> do (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ makeConnection a b runRight_ $ do @@ -1253,15 +1381,14 @@ testExpireManyMessages t = 2 <- sendMessage a bId SMP.noMsgFlags "1" 3 <- sendMessage a bId SMP.noMsgFlags "2" 4 <- sendMessage a bId SMP.noMsgFlags "3" - liftIO $ threadDelay 1000000 + liftIO $ threadDelay 2000000 5 <- sendMessage a bId SMP.noMsgFlags "4" -- this won't expire get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False - -- get a =##> \case ("", c, MERRS [5, 6] (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False let expected c e = bId == c && (e == TIMEOUT || e == NETWORK) get a >>= \case ("", c, MERR 3 (BROKER _ e)) -> do liftIO $ expected c e `shouldBe` True - get a =##> \case ("", c', MERR 4 (BROKER _ e')) -> expected c' e'; ("", c', MERRS [6] (BROKER _ e')) -> expected c' e'; _ -> False + get a =##> \case ("", c', MERR 4 (BROKER _ e')) -> expected c' e'; ("", c', MERRS [4] (BROKER _ e')) -> expected c' e'; _ -> False ("", c, MERRS [3] (BROKER _ e)) -> do liftIO $ expected c e `shouldBe` True get a =##> \case ("", c', MERR 4 (BROKER _ e')) -> expected c' e'; _ -> False @@ -1273,7 +1400,7 @@ testExpireManyMessages t = withUP b aId $ \case ("", _, MsgErr 2 (MsgSkipped 2 4) "4") -> True; _ -> False ackMessage b aId 2 Nothing -withUP :: AgentClient -> ConnId -> (AEntityTransmission 'AEConn -> Bool) -> ExceptT AgentErrorType IO () +withUP :: HasCallStack => AgentClient -> ConnId -> (AEntityTransmission 'AEConn -> Bool) -> ExceptT AgentErrorType IO () withUP a bId p = liftIO $ getInAnyOrder @@ -1310,7 +1437,7 @@ testExpireMessageQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} testP testExpireManyMessagesQuota :: ATransport -> IO () testExpireManyMessagesQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} testPort $ \_ -> do - a <- getSMPAgentClient' 1 agentCfg {quotaExceededTimeout = 1, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB + a <- getSMPAgentClient' 1 agentCfg {quotaExceededTimeout = 2, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB b <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (aId, bId) <- runRight $ do (aId, bId) <- makeConnection a b @@ -1320,7 +1447,7 @@ testExpireManyMessagesQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} 3 <- sendMessage a bId SMP.noMsgFlags "2" 4 <- sendMessage a bId SMP.noMsgFlags "3" 5 <- sendMessage a bId SMP.noMsgFlags "4" - liftIO $ threadDelay 1000000 + liftIO $ threadDelay 2000000 6 <- sendMessage a bId SMP.noMsgFlags "5" -- this won't expire get a =##> \case ("", c, MERR 3 (SMP _ QUOTA)) -> bId == c; _ -> False get a >>= \case @@ -1530,7 +1657,8 @@ testRatchetSyncSimultaneous t = do testOnlyCreatePullSlowHandshake :: IO () testOnlyCreatePullSlowHandshake = withAgentClientsCfg2 agentProxyCfgV8 agentProxyCfgV8 $ \alice bob -> runRight_ $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMOnlyCreate - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + liftIO $ sqSecured `shouldBe` False Just ("", _, CONF confId _ "bob's connInfo") <- getMsg alice bobId $ timeout 5_000000 $ get alice allowConnection alice bobId confId "alice's connInfo" liftIO $ threadDelay 1_000000 @@ -1564,7 +1692,8 @@ getMsg c cId action = do testOnlyCreatePull :: IO () testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMOnlyCreate - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + liftIO $ sqSecured `shouldBe` True Just ("", _, CONF confId _ "bob's connInfo") <- getMsg alice bobId $ timeout 5_000000 $ get alice allowConnection alice bobId confId "alice's connInfo" liftIO $ threadDelay 1_000000 @@ -1586,20 +1715,22 @@ testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do ackMessage alice bobId 3 Nothing makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection = makeConnection_ PQSupportOn +makeConnection = makeConnection_ PQSupportOn True -makeConnection_ :: PQSupport -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection_ pqEnc alice bob = makeConnectionForUsers_ pqEnc alice 1 bob 1 +makeConnection_ :: PQSupport -> SndQueueSecured -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnection_ pqEnc sqSecured alice bob = makeConnectionForUsers_ pqEnc sqSecured alice 1 bob 1 makeConnectionForUsers :: HasCallStack => AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn +makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn True -makeConnectionForUsers_ :: HasCallStack => PQSupport -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers_ pqSupport alice aliceUserId bob bobUserId = do +makeConnectionForUsers_ :: HasCallStack => PQSupport -> SndQueueSecured -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers_ pqSupport sqSecured alice aliceUserId bob bobUserId = do (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqSupport) SMSubscribe aliceId <- A.prepareConnectionToJoin bob bobUserId True qInfo pqSupport - aliceId' <- A.joinConnection bob bobUserId (Just aliceId) True qInfo "bob's connInfo" pqSupport SMSubscribe - liftIO $ aliceId' `shouldBe` aliceId + (aliceId', sqSecured') <- A.joinConnection bob bobUserId (Just aliceId) True qInfo "bob's connInfo" pqSupport SMSubscribe + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" @@ -1682,7 +1813,6 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False ackMessage b aId 2 Nothing pure (aId, bId) - runRight_ $ do ("", "", DOWN {}) <- nGet a ("", "", DOWN {}) <- nGet b @@ -1690,15 +1820,17 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do 4 <- sendMessage b aId SMP.noMsgFlags "how are you?" liftIO $ threadDelay 100000 liftIO $ suspendAgent b 5000000 - withSmpServerStoreLogOn t testPort $ \_ -> runRight_ @AgentErrorType $ do - pGet b =##> \case ("", c, AEvt SAEConn (SENT 3)) -> c == aId; ("", "", AEvt _ UP {}) -> True; _ -> False - pGet b =##> \case ("", c, AEvt SAEConn (SENT 3)) -> c == aId; ("", "", AEvt _ UP {}) -> True; _ -> False - pGet b =##> \case ("", c, AEvt SAEConn (SENT 4)) -> c == aId; ("", "", AEvt _ UP {}) -> True; _ -> False - ("", "", SUSPENDED) <- nGet b - - pGet a =##> \case ("", c, AEvt _ (Msg "hello too")) -> c == bId; ("", "", AEvt _ UP {}) -> True; _ -> False - pGet a =##> \case ("", c, AEvt _ (Msg "hello too")) -> c == bId; ("", "", AEvt _ UP {}) -> True; _ -> False + -- there will be no UP event for b, because re-subscriptions are suspended until the agent is in foreground + get b =##> \case ("", c, SENT 3) -> c == aId; _ -> False + get b =##> \case ("", c, SENT 4) -> c == aId; _ -> False + nGet b ##> ("", "", SUSPENDED) + liftIO $ + getInAnyOrder + a + [ \case ("", c, AEvt _ (Msg "hello too")) -> c == bId; _ -> False, + \case ("", "", AEvt _ UP {}) -> True; _ -> False + ] ackMessage a bId 3 Nothing get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False ackMessage a bId 4 Nothing @@ -1726,7 +1858,7 @@ testBatchedSubscriptions :: Int -> Int -> ATransport -> IO () testBatchedSubscriptions nCreate nDel t = withAgentClientsCfgServers2 agentCfg agentCfg initAgentServers2 $ \a b -> do conns <- runServers $ do - conns <- replicateM nCreate $ makeConnection_ PQSupportOff a b + conns <- replicateM nCreate $ makeConnection_ PQSupportOff True a b forM_ conns $ \(aId, bId) -> exchangeGreetings_ PQEncOff a bId b aId let (aIds', bIds') = unzip $ take nDel conns delete a bIds' @@ -1804,15 +1936,59 @@ testBatchedPendingMessages nCreate nMsgs = withA = withAgent 1 agentCfg initAgentServers testDB withB = withAgent 2 agentCfg initAgentServers testDB2 -testAsyncCommands :: AgentClient -> AgentClient -> AgentMsgId -> IO () -testAsyncCommands alice bob baseId = +testSendMessagesB :: IO () +testSendMessagesB = withAgentClients2 $ \a b -> runRight_ $ do + (aId, bId) <- makeConnection a b + let msg cId body = Right (cId, PQEncOn, SMP.noMsgFlags, body) + [SentB 2, SentB 3, SentB 4] <- sendMessagesB a ([msg bId "msg 1", msg "" "msg 2", msg "" "msg 3"] :: [Either AgentErrorType MsgReq]) + get a ##> ("", bId, SENT 2) + get a ##> ("", bId, SENT 3) + get a ##> ("", bId, SENT 4) + receiveMsg b aId 2 "msg 1" + receiveMsg b aId 3 "msg 2" + receiveMsg b aId 4 "msg 3" + +testSendMessagesB2 :: IO () +testSendMessagesB2 = withAgentClients3 $ \a b c -> runRight_ $ do + (abId, bId) <- makeConnection a b + (acId, cId) <- makeConnection a c + let msg connId body = Right (connId, PQEncOn, SMP.noMsgFlags, body) + [SentB 2, SentB 3, SentB 4, SentB 2, SentB 3] <- + sendMessagesB a ([msg bId "msg 1", msg "" "msg 2", msg "" "msg 3", msg cId "msg 4", msg "" "msg 5"] :: [Either AgentErrorType MsgReq]) + liftIO $ + getInAnyOrder + a + [ \case ("", cId', AEvt SAEConn (SENT 2)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 3)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 4)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 2)) -> cId' == cId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 3)) -> cId' == cId; _ -> False + ] + receiveMsg b abId 2 "msg 1" + receiveMsg b abId 3 "msg 2" + receiveMsg b abId 4 "msg 3" + receiveMsg c acId 2 "msg 4" + receiveMsg c acId 3 "msg 5" + +pattern SentB :: AgentMsgId -> Either AgentErrorType (AgentMsgId, PQEncryption) +pattern SentB msgId <- Right (msgId, PQEncOn) + +receiveMsg :: AgentClient -> ConnId -> AgentMsgId -> MsgBody -> ExceptT AgentErrorType IO () +receiveMsg c cId msgId msg = do + get c =##> \case ("", cId', Msg' mId' PQEncOn msg') -> cId' == cId && mId' == msgId && msg' == msg; _ -> False + ackMessage c cId msgId Nothing + +testAsyncCommands :: SndQueueSecured -> AgentClient -> AgentClient -> AgentMsgId -> IO () +testAsyncCommands sqSecured alice bob baseId = runRight_ $ do bobId <- createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe ("1", bobId', INV (ACR _ qInfo)) <- get alice liftIO $ bobId' `shouldBe` bobId aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" PQSupportOn SMSubscribe - ("2", aliceId', OK) <- get bob - liftIO $ aliceId' `shouldBe` aliceId + ("2", aliceId', JOINED sqSecured') <- get bob + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecured' `shouldBe` sqSecured ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnectionAsync alice "3" bobId confId "alice's connInfo" get alice =##> \case ("3", _, OK) -> True; _ -> False @@ -1865,14 +2041,15 @@ testAsyncCommandsRestore t = do get alice' =##> \case ("1", _, INV _) -> True; _ -> False pure () -testAcceptContactAsync :: AgentClient -> AgentClient -> AgentMsgId -> IO () -testAcceptContactAsync alice bob baseId = +testAcceptContactAsync :: SndQueueSecured -> AgentClient -> AgentClient -> AgentMsgId -> IO () +testAcceptContactAsync sqSecured alice bob baseId = runRight_ $ do (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecuredJoin) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, REQ invId _ "bob's connInfo") <- get alice bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" PQSupportOn SMSubscribe - get alice =##> \case ("1", c, OK) -> c == bobId; _ -> False + get alice =##> \case ("1", c, JOINED sqSecured') -> c == bobId && sqSecured' == sqSecured; _ -> False ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" get alice ##> ("", bobId, INFO "bob's connInfo") @@ -1947,7 +2124,7 @@ testWaitDeliveryNoPending t = withAgentClients2 $ \alice bob -> liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" where - baseId = 3 + baseId = 1 msgId = subtract baseId testWaitDelivery :: ATransport -> IO () @@ -2001,7 +2178,7 @@ testWaitDelivery t = liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" where - baseId = 3 + baseId = 1 msgId = subtract baseId testWaitDeliveryAUTHErr :: ATransport -> IO () @@ -2044,7 +2221,7 @@ testWaitDeliveryAUTHErr t = liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" where - baseId = 3 + baseId = 1 msgId = subtract baseId testWaitDeliveryTimeout :: ATransport -> IO () @@ -2084,7 +2261,7 @@ testWaitDeliveryTimeout t = liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" where - baseId = 3 + baseId = 1 msgId = subtract baseId testWaitDeliveryTimeout2 :: ATransport -> IO () @@ -2130,12 +2307,12 @@ testWaitDeliveryTimeout2 t = liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" where - baseId = 3 + baseId = 1 msgId = subtract baseId testJoinConnectionAsyncReplyErrorV8 :: HasCallStack => ATransport -> IO () testJoinConnectionAsyncReplyErrorV8 t = do - let initAgentServersSrv2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer2]} + let initAgentServersSrv2 = initAgentServers {smp = userServers [testSMPServer2]} withAgent 1 agentCfgVPrevPQ initAgentServers testDB $ \a -> withAgent 2 agentCfgVPrevPQ initAgentServersSrv2 testDB2 $ \b -> do (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do @@ -2148,7 +2325,7 @@ testJoinConnectionAsyncReplyErrorV8 t = do pure (aId, bId) nGet a =##> \case ("", "", DOWN _ [c]) -> c == bId; _ -> False withSmpServerOn t testPort2 $ do - get b =##> \case ("2", c, OK) -> c == aId; _ -> False + get b =##> \case ("2", c, JOINED sqSecured) -> c == aId && not sqSecured; _ -> False confId <- withSmpServerStoreLogOn t testPort $ \_ -> do pGet a >>= \case ("", "", AEvt _ (UP _ [_])) -> do @@ -2174,7 +2351,7 @@ testJoinConnectionAsyncReplyErrorV8 t = do testJoinConnectionAsyncReplyError :: HasCallStack => ATransport -> IO () testJoinConnectionAsyncReplyError t = do - let initAgentServersSrv2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer2]} + let initAgentServersSrv2 = initAgentServers {smp = userServers [testSMPServer2]} withAgent 1 agentCfg initAgentServers testDB $ \a -> withAgent 2 agentCfg initAgentServersSrv2 testDB2 $ \b -> do (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do @@ -2189,7 +2366,7 @@ testJoinConnectionAsyncReplyError t = do withSmpServerOn t testPort2 $ do confId <- withSmpServerStoreLogOn t testPort $ \_ -> do -- both servers need to be online for connection to progress because of SKEY - get b =##> \case ("2", c, OK) -> c == aId; _ -> False + get b =##> \case ("2", c, JOINED sqSecured) -> c == aId && sqSecured; _ -> False pGet a >>= \case ("", "", AEvt _ (UP _ [_])) -> do ("", _, CONF confId _ "bob's connInfo") <- get a @@ -2217,7 +2394,7 @@ testUsers = withAgentClients2 $ \a b -> runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetings a bId b aId - auId <- createUser a [noAuthSrv testSMPServer] [noAuthSrv testXFTPServer] + auId <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId', bId') <- makeConnectionForUsers a auId b 1 exchangeGreetings a bId' b aId' deleteUser a auId True @@ -2232,7 +2409,7 @@ testDeleteUserQuietly = withAgentClients2 $ \a b -> runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetings a bId b aId - auId <- createUser a [noAuthSrv testSMPServer] [noAuthSrv testXFTPServer] + auId <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId', bId') <- makeConnectionForUsers a auId b 1 exchangeGreetings a bId' b aId' deleteUser a auId False @@ -2244,7 +2421,7 @@ testUsersNoServer t = withAgentClientsCfg2 aCfg agentCfg $ \a b -> do (aId, bId, auId, _aId', bId') <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do (aId, bId) <- makeConnection a b exchangeGreetings a bId b aId - auId <- createUser a [noAuthSrv testSMPServer] [noAuthSrv testXFTPServer] + auId <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId', bId') <- makeConnectionForUsers a auId b 1 exchangeGreetings a bId' b aId' pure (aId, bId, auId, aId', bId') @@ -2643,8 +2820,8 @@ testSwitch2ConnectionsAbort1 servers = do withB :: (AgentClient -> IO a) -> IO a withB = withAgent 2 agentCfg servers testDB2 -testCreateQueueAuth :: HasCallStack => VersionSMP -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> AgentMsgId -> IO Int -testCreateQueueAuth srvVersion clnt1 clnt2 baseId = do +testCreateQueueAuth :: HasCallStack => VersionSMP -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> SndQueueSecured -> AgentMsgId -> IO Int +testCreateQueueAuth srvVersion clnt1 clnt2 sqSecured baseId = do a <- getClient 1 clnt1 testDB b <- getClient 2 clnt2 testDB2 r <- runRight $ do @@ -2655,7 +2832,8 @@ testCreateQueueAuth srvVersion clnt1 clnt2 baseId = do tryError (joinConnection b 1 True qInfo "bob's connInfo" SMSubscribe) >>= \case Left (SMP _ AUTH) -> pure 1 Left e -> throwError e - Right aId -> do + Right (aId, sqSecured') -> do + liftIO $ sqSecured' `shouldBe` sqSecured ("", _, CONF confId _ "bob's connInfo") <- get a allowConnection a bId confId "alice's connInfo" get a ##> ("", bId, CON) @@ -2668,7 +2846,7 @@ testCreateQueueAuth srvVersion clnt1 clnt2 baseId = do pure r where getClient clientId (clntAuth, clntVersion) db = - let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]} + let servers = initAgentServers {smp = userServers' [ProtoServerWithAuth testSMPServer clntAuth]} alpn_ = if clntVersion >= authCmdsSMPVersion then Just supportedSMPHandshakes else Nothing smpCfg = defaultClientConfig alpn_ $ V.mkVersionRange (prevVersion basicAuthSMPVersion) clntVersion sndAuthAlg = if srvVersion >= authCmdsSMPVersion && clntVersion >= authCmdsSMPVersion then C.AuthAlg C.SX25519 else C.AuthAlg C.SEd25519 @@ -2715,7 +2893,7 @@ testDeliveryReceiptsVersion t = do b <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aId, bId) <- runRight $ do - (aId, bId) <- makeConnection_ PQSupportOff a b + (aId, bId) <- makeConnection_ PQSupportOff False a b checkVersion a bId 3 checkVersion b aId 3 (2, _) <- A.sendMessage a bId PQEncOff SMP.noMsgFlags "hello" @@ -2739,8 +2917,8 @@ testDeliveryReceiptsVersion t = do subscribeConnection a' bId subscribeConnection b' aId exchangeGreetingsMsgId_ PQEncOff 4 a' bId b' aId - checkVersion a' bId 6 - checkVersion b' aId 6 + checkVersion a' bId 7 + checkVersion b' aId 7 (6, PQEncOff) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello" get a' ##> ("", bId, SENT 6) get b' =##> \case ("", c, Msg' 6 PQEncOff "hello") -> c == aId; _ -> False @@ -2840,7 +3018,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do ("", "", UP _ _) <- nGet a a `hasClients` 1 - aUserId2 <- createUser a [noAuthSrv testSMPServer] [noAuthSrv testXFTPServer] + aUserId2 <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId2, bId2) <- makeConnectionForUsers a aUserId2 b 1 exchangeGreetings a bId2 b aId2 (aId2', bId2') <- makeConnectionForUsers a aUserId2 b 1 @@ -2889,7 +3067,8 @@ testServerMultipleIdentities :: HasCallStack => IO () testServerMultipleIdentities = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -2900,6 +3079,7 @@ testServerMultipleIdentities = bob' <- liftIO $ do Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe disposeAgentClient bob + threadDelay 250000 getSMPAgentClient' 3 agentCfg initAgentServers testDB2 subscribeConnection bob' aliceId exchangeGreetingsMsgId 4 alice bobId bob' aliceId @@ -2987,7 +3167,8 @@ testServerQueueInfo = do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ threadDelay 200000 checkEmptyQ alice bobId False - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice liftIO $ threadDelay 200000 checkEmptyQ alice bobId True -- secured by sender @@ -3049,14 +3230,14 @@ testServerQueueInfo = do pure () where checkEmptyQ c cId qiSnd' = do - r <- checkQ c cId qiSnd' (Just QSubThread) 0 Nothing + r <- checkQ c cId qiSnd' (Just QNoSub) 0 Nothing liftIO $ r `shouldBe` Nothing checkMsgQ c cId qiSize' = do r <- checkQ c cId True (Just QNoSub) qiSize' (Just MTMessage) liftIO $ isJust r `shouldBe` True pure r checkQ c cId qiSnd' qiSubThread_ qiSize' msgType_ = do - QueueInfo {qiSnd, qiNtf, qiSub, qiSize, qiMsg} <- getConnectionQueueInfo c cId + ServerQueueInfo {info = QueueInfo {qiSnd, qiNtf, qiSub, qiSize, qiMsg}} <- getConnectionQueueInfo c cId liftIO $ do qiSnd `shouldBe` qiSnd' qiNtf `shouldBe` False @@ -3087,7 +3268,7 @@ exchangeGreetings :: HasCallStack => AgentClient -> ConnId -> AgentClient -> Con exchangeGreetings = exchangeGreetings_ PQEncOn exchangeGreetings_ :: HasCallStack => PQEncryption -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () -exchangeGreetings_ pqEnc = exchangeGreetingsMsgId_ pqEnc 4 +exchangeGreetings_ pqEnc = exchangeGreetingsMsgId_ pqEnc 2 exchangeGreetingsMsgId :: HasCallStack => Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetingsMsgId = exchangeGreetingsMsgId_ PQEncOn diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 603ffd3c0..cc79faeca 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -5,6 +5,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -49,6 +50,7 @@ import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Text.Encoding (encodeUtf8) +import Database.SQLite.Simple.QQ (sql) import NtfClient import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, testDB2, testNtfServer, testNtfServer2) import SMPClient (cfg, cfgVPrev, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn) @@ -56,13 +58,14 @@ import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMes import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, SENT) -import Simplex.Messaging.Agent.Store.SQLite (closeSQLiteStore, getSavedNtfToken, reopenSQLiteStore) +import Simplex.Messaging.Agent.Store.SQLite (closeSQLiteStore, getSavedNtfToken, reopenSQLiteStore, withTransaction) +import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..)) import Simplex.Messaging.Notifications.Server.Push.APNS -import Simplex.Messaging.Notifications.Types (NtfToken (..)) +import Simplex.Messaging.Notifications.Types (NtfTknAction (..), NtfToken (..)) import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), NtfServer, ProtocolServer (..), SMPMsgMeta (..), SubscriptionMode (..)) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) @@ -88,9 +91,21 @@ notificationTests t = do it "should allow the second registration with different credentials and delete the first after verification" $ withAPNSMockServer $ \apns -> withNtfServer t $ testNtfTokenSecondRegistration apns - it "should re-register token when notification server is restarted" $ + it "should verify token after notification server is restarted" $ withAPNSMockServer $ \apns -> testNtfTokenServerRestart t apns + it "should re-verify token after notification server is restarted" $ + withAPNSMockServer $ \apns -> + testNtfTokenServerRestartReverify t apns + it "should re-verify token after notification server is restarted when first request timed-out" $ + withAPNSMockServer $ \apns -> + testNtfTokenServerRestartReverifyTimeout t apns + it "should re-register token when notification server is restarted" $ + withAPNSMockServer $ \apns -> + testNtfTokenServerRestartReregister t apns + it "should re-register token when notification server is restarted when first request timed-out" $ + withAPNSMockServer $ \apns -> + testNtfTokenServerRestartReregisterTimeout t apns it "should work with multiple configured servers" $ withAPNSMockServer $ \apns -> testNtfTokenMultipleServers t apns @@ -105,7 +120,7 @@ notificationTests t = do describe "Managing notification subscriptions" $ do describe "should create notification subscription for existing connection" $ testNtfMatrix t testNotificationSubscriptionExistingConnection - xdescribe "should create notification subscription for new connection" $ + describe "should create notification subscription for new connection" $ testNtfMatrix t testNotificationSubscriptionNewConnection it "should change notifications mode" $ withSmpServer t $ @@ -116,19 +131,19 @@ notificationTests t = do withAPNSMockServer $ \apns -> withNtfServer t $ testChangeToken apns describe "Notifications server store log" $ - xit "should save and restore tokens and subscriptions" $ + it "should save and restore tokens and subscriptions" $ withSmpServer t $ withAPNSMockServer $ \apns -> testNotificationsStoreLog t apns describe "Notifications after SMP server restart" $ - xit "should resume subscriptions after SMP server is restarted" $ + it "should resume subscriptions after SMP server is restarted" $ withAPNSMockServer $ \apns -> withNtfServer t $ testNotificationsSMPRestart t apns describe "Notifications after SMP server restart" $ it "should resume batched subscriptions after SMP server is restarted" $ withAPNSMockServer $ \apns -> withNtfServer t $ testNotificationsSMPRestartBatch 100 t apns - xdescribe "should switch notifications to the new queue" $ + describe "should switch notifications to the new queue" $ testServerMatrix2 t $ \servers -> withAPNSMockServer $ \apns -> withNtfServer t $ testSwitchNotifications servers apns @@ -146,7 +161,7 @@ notificationTests t = do testNtfMatrix :: HasCallStack => ATransport -> (APNSMockServer -> AgentMsgId -> AgentClient -> AgentClient -> IO ()) -> Spec testNtfMatrix t runTest = do describe "next and current" $ do - xit "curr servers; curr clients" $ runNtfTestCfg t 1 cfg ntfServerCfg agentCfg agentCfg runTest + it "curr servers; curr clients" $ runNtfTestCfg t 1 cfg ntfServerCfg agentCfg agentCfg runTest it "curr servers; prev clients" $ runNtfTestCfg t 3 cfg ntfServerCfg agentCfgVPrevPQ agentCfgVPrevPQ runTest it "prev servers; prev clients" $ runNtfTestCfg t 3 cfgVPrev ntfServerCfgVPrev agentCfgVPrevPQ agentCfgVPrevPQ runTest it "prev servers; curr clients" $ runNtfTestCfg t 3 cfgVPrev ntfServerCfgVPrev agentCfg agentCfg runTest @@ -251,7 +266,7 @@ testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO () testNtfTokenServerRestart t APNSMockServer {apnsQ} = do let tkn = DeviceToken PPApnsTest "abcd" ntfData <- withAgent 1 agentCfg initAgentServers testDB $ \a -> - withNtfServer t . runRight $ do + withNtfServerStoreLog t $ \_ -> runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ @@ -262,16 +277,131 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do withAgent 2 agentCfg initAgentServers testDB $ \a' -> -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, -- so that repeat verification happens without restarting the clients, when notification arrives - withNtfServer t . runRight_ $ do + withNtfServerStoreLog t $ \_ -> runRight_ $ do verification <- ntfData .-> "verification" nonce <- C.cbNonce <$> ntfData .-> "nonce" - Left (NTF _ AUTH) <- tryE $ verifyNtfToken a' tkn nonce verification - APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- + verifyNtfToken a' tkn nonce verification + NTActive <- checkNtfToken a' tkn + pure () + +testNtfTokenServerRestartReverify :: ATransport -> APNSMockServer -> IO () +testNtfTokenServerRestartReverify t APNSMockServer {apnsQ} = do + let tkn = DeviceToken PPApnsTest "abcd" + withAgent 1 agentCfg initAgentServers testDB $ \a -> do + ntfData <- withNtfServerStoreLog t $ \_ -> runRight $ do + NTRegistered <- registerNtfToken a tkn NMPeriodic + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ - verification' <- ntfData' .-> "verification" - nonce' <- C.cbNonce <$> ntfData' .-> "nonce" - liftIO $ sendApnsResponse' APNSRespOk - verifyNtfToken a' tkn nonce' verification' + liftIO $ sendApnsResponse APNSRespOk + pure ntfData + runRight_ $ do + verification <- ntfData .-> "verification" + nonce <- C.cbNonce <$> ntfData .-> "nonce" + Left (BROKER _ NETWORK) <- tryE $ verifyNtfToken a tkn nonce verification + pure () + threadDelay 1000000 + withAgent 2 agentCfg initAgentServers testDB $ \a' -> + -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, + -- so that repeat verification happens without restarting the clients, when notification arrives + withNtfServerStoreLog t $ \_ -> runRight_ $ do + NTActive <- registerNtfToken a' tkn NMPeriodic + NTActive <- checkNtfToken a' tkn + pure () + +testNtfTokenServerRestartReverifyTimeout :: ATransport -> APNSMockServer -> IO () +testNtfTokenServerRestartReverifyTimeout t APNSMockServer {apnsQ} = do + let tkn = DeviceToken PPApnsTest "abcd" + withAgent 1 agentCfg initAgentServers testDB $ \a@AgentClient {agentEnv = Env {store}} -> do + (nonce, verification) <- withNtfServerStoreLog t $ \_ -> runRight $ do + NTRegistered <- registerNtfToken a tkn NMPeriodic + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + liftIO $ sendApnsResponse APNSRespOk + verification <- ntfData .-> "verification" + nonce <- C.cbNonce <$> ntfData .-> "nonce" + verifyNtfToken a tkn nonce verification + pure (nonce, verification) + -- this emulates the situation when server verified token but the client did not receive the response + Just NtfToken {ntfTknStatus = NTActive, ntfTknAction = Just NTACheck, ntfDhSecret = Just dhSecret} <- withTransaction store getSavedNtfToken + Right code <- pure $ NtfRegCode <$> C.cbDecrypt dhSecret nonce verification + withTransaction store $ \db -> + DB.execute + db + [sql| + UPDATE ntf_tokens + SET tkn_status = ?, tkn_action = ? + WHERE provider = ? AND device_token = ? + |] + (NTConfirmed, Just (NTAVerify code), PPApnsTest, "abcd" :: ByteString) + Just NtfToken {ntfTknStatus = NTConfirmed, ntfTknAction = Just (NTAVerify _)} <- withTransaction store getSavedNtfToken + pure () + threadDelay 1000000 + withAgent 2 agentCfg initAgentServers testDB $ \a' -> + -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, + -- so that repeat verification happens without restarting the clients, when notification arrives + withNtfServerStoreLog t $ \_ -> runRight_ $ do + NTActive <- registerNtfToken a' tkn NMPeriodic + NTActive <- checkNtfToken a' tkn + pure () + +testNtfTokenServerRestartReregister :: ATransport -> APNSMockServer -> IO () +testNtfTokenServerRestartReregister t APNSMockServer {apnsQ} = do + let tkn = DeviceToken PPApnsTest "abcd" + withAgent 1 agentCfg initAgentServers testDB $ \a -> + withNtfServerStoreLog t $ \_ -> runRight $ do + NTRegistered <- registerNtfToken a tkn NMPeriodic + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just _}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + liftIO $ sendApnsResponse APNSRespOk + -- the new agent is created as otherwise when running the tests in CI the old agent was keeping the connection to the server + threadDelay 1000000 + withAgent 2 agentCfg initAgentServers testDB $ \a' -> + -- server stopped before token is verified, and client might have lost verification notification. + -- so that repeat registration happens when client is restarted. + withNtfServerStoreLog t $ \_ -> runRight_ $ do + NTRegistered <- registerNtfToken a' tkn NMPeriodic + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + liftIO $ sendApnsResponse APNSRespOk + verification <- ntfData .-> "verification" + nonce <- C.cbNonce <$> ntfData .-> "nonce" + verifyNtfToken a' tkn nonce verification + NTActive <- checkNtfToken a' tkn + pure () + +testNtfTokenServerRestartReregisterTimeout :: ATransport -> APNSMockServer -> IO () +testNtfTokenServerRestartReregisterTimeout t APNSMockServer {apnsQ} = do + let tkn = DeviceToken PPApnsTest "abcd" + withAgent 1 agentCfg initAgentServers testDB $ \a@AgentClient {agentEnv = Env {store}} -> do + withNtfServerStoreLog t $ \_ -> runRight $ do + NTRegistered <- registerNtfToken a tkn NMPeriodic + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just _}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + liftIO $ sendApnsResponse APNSRespOk + -- this emulates the situation when server registered token but the client did not receive the response + withTransaction store $ \db -> + DB.execute + db + [sql| + UPDATE ntf_tokens + SET tkn_id = NULL, tkn_dh_secret = NULL, tkn_status = ?, tkn_action = ? + WHERE provider = ? AND device_token = ? + |] + (NTNew, Just NTARegister, PPApnsTest, "abcd" :: ByteString) + Just NtfToken {ntfTokenId = Nothing, ntfTknStatus = NTNew, ntfTknAction = Just NTARegister} <- withTransaction store getSavedNtfToken + pure () + threadDelay 1000000 + withAgent 2 agentCfg initAgentServers testDB $ \a' -> + -- server stopped before token is verified, and client might have lost verification notification. + -- so that repeat registration happens when client is restarted. + withNtfServerStoreLog t $ \_ -> runRight_ $ do + NTRegistered <- registerNtfToken a' tkn NMPeriodic + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + liftIO $ sendApnsResponse APNSRespOk + verification <- ntfData .-> "verification" + nonce <- C.cbNonce <$> ntfData .-> "nonce" + verifyNtfToken a' tkn nonce verification NTActive <- checkNtfToken a' tkn pure () @@ -347,7 +477,7 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} baseId ali (bobId, aliceId, nonce, message) <- runRight $ do -- establish connection (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, _sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") @@ -374,27 +504,27 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} baseId ali -- alice client already has subscription for the connection Left (CMD PROHIBITED _) <- runExceptT $ getNotificationMessage alice nonce message - threadDelay 200000 + threadDelay 500000 suspendAgent alice 0 closeSQLiteStore store - threadDelay 200000 + threadDelay 1000000 -- aliceNtf client doesn't have subscription and is allowed to get notification message withAgent 3 aliceCfg initAgentServers testDB $ \aliceNtf -> runRight_ $ do - (_, [SMPMsgMeta {msgFlags = MsgFlags True}]) <- getNotificationMessage aliceNtf nonce message + (_, Just SMPMsgMeta {msgFlags = MsgFlags True}) <- getNotificationMessage aliceNtf nonce message pure () - threadDelay 200000 + threadDelay 1000000 reopenSQLiteStore store foregroundAgent alice - threadDelay 200000 + threadDelay 500000 runRight_ $ do get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False ackMessage alice bobId (baseId + 1) Nothing -- delete notification subscription toggleConnectionNtfs alice bobId False - liftIO $ threadDelay 250000 + liftIO $ threadDelay 500000 -- send message 2 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello again" get bob ##> ("", aliceId, SENT $ baseId + 2) @@ -414,7 +544,7 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} baseId alice bo liftIO $ threadDelay 50000 (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ threadDelay 1000000 - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, _sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe liftIO $ threadDelay 750000 void $ messageNotificationData alice apnsQ ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -461,7 +591,8 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = withAgentClients2 $ \alice bob -> runRight_ $ do -- establish connection (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") @@ -515,7 +646,7 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = -- no notifications should follow noNotification apnsQ where - baseId = 3 + baseId = 1 msgId = subtract baseId testChangeToken :: APNSMockServer -> IO () @@ -523,7 +654,8 @@ testChangeToken APNSMockServer {apnsQ} = withAgent 1 agentCfg initAgentServers t (aliceId, bobId) <- withAgent 2 agentCfg initAgentServers testDB $ \alice -> runRight $ do -- establish connection (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") @@ -554,7 +686,7 @@ testChangeToken APNSMockServer {apnsQ} = withAgent 1 agentCfg initAgentServers t -- no notifications should follow noNotification apnsQ where - baseId = 3 + baseId = 1 msgId = subtract baseId testNotificationsStoreLog :: ATransport -> APNSMockServer -> IO () diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index caab0637a..5f6beb034 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -261,7 +261,7 @@ testClientStub :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) testClientStub = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g - atomically $ smpClientStub g sessId subModeSMPVersion Nothing + smpClientStub g sessId subModeSMPVersion Nothing clientStubV7 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) clientStubV7 = do @@ -269,7 +269,7 @@ clientStubV7 = do sessId <- atomically $ C.randomBytes 32 g (rKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g thAuth_ <- testTHandleAuth authCmdsSMPVersion g rKey - atomically $ smpClientStub g sessId authCmdsSMPVersion thAuth_ + smpClientStub g sessId authCmdsSMPVersion thAuth_ randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUB = randomSUB_ C.SEd25519 subModeSMPVersion diff --git a/tests/CoreTests/RetryIntervalTests.hs b/tests/CoreTests/RetryIntervalTests.hs index 7097df989..da96d0208 100644 --- a/tests/CoreTests/RetryIntervalTests.hs +++ b/tests/CoreTests/RetryIntervalTests.hs @@ -2,6 +2,8 @@ module CoreTests.RetryIntervalTests where +import Control.Concurrent (threadDelay) +import Control.Concurrent.Async (concurrently_) import Control.Concurrent.STM import Control.Monad (when) import Data.Time.Clock (UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds) @@ -13,6 +15,10 @@ retryIntervalTests = do describe "Retry interval with 2 modes and lock" $ do testRetryIntervalSameMode testRetryIntervalSwitchMode + describe "Foreground retry interval" $ do + testRetryForeground + testRetryToBackground + testRetrySkipWhenForeground testRI :: RetryInterval2 testRI = @@ -23,12 +29,15 @@ testRI = increaseAfter = 40000, maxInterval = 40000 }, - riFast = - RetryInterval - { initialInterval = 10000, - increaseAfter = 20000, - maxInterval = 40000 - } + riFast = testFastRI + } + +testFastRI :: RetryInterval +testFastRI = + RetryInterval + { initialInterval = 10000, + increaseAfter = 20000, + maxInterval = 40000 } testRetryIntervalSameMode :: Spec @@ -81,6 +90,67 @@ testRetryIntervalSwitchMode = (40000, 40000) ] +testRetryForeground :: Spec +testRetryForeground = + it "should increase elapased time and interval" $ do + intervals <- newTVarIO [] + reportedIntervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + let isForeground = pure True + withRetryForeground testFastRI isForeground (pure True) $ \delay loop -> do + ints <- addInterval intervals ts + atomically $ modifyTVar' reportedIntervals (delay :) + when (length ints < 8) $ loop + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4] + (reverse <$> readTVarIO reportedIntervals) + `shouldReturn` [ 10000, 10000, 15000, 22500, 33750, 40000, 40000, 40000] + +testRetryToBackground :: Spec +testRetryToBackground = + it "should not change interval when moving to background" $ do + intervals <- newTVarIO [] + reportedIntervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + foreground <- newTVarIO True + concurrently_ + ( do + threadDelay 50000 + atomically $ writeTVar foreground False + ) + ( withRetryForeground testFastRI (readTVar foreground) (pure True) $ \delay loop -> do + ints <- addInterval intervals ts + atomically $ modifyTVar' reportedIntervals (delay :) + when (length ints < 8) $ loop + ) + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4] + (reverse <$> readTVarIO reportedIntervals) + `shouldReturn` [ 10000, 10000, 15000, 22500, 33750, 40000, 40000, 40000] + +testRetrySkipWhenForeground :: Spec +testRetrySkipWhenForeground = + it "should repeat loop as soon as moving to foreground" $ do + intervals <- newTVarIO [] + reportedIntervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + foreground <- newTVarIO False + concurrently_ + ( do + threadDelay 65000 + atomically $ writeTVar foreground True + threadDelay 10000 + atomically $ writeTVar foreground False + threadDelay 100000 + atomically $ writeTVar foreground True + ) + ( withRetryForeground testFastRI (readTVar foreground) (pure True) $ \delay loop -> do + ints <- addInterval intervals ts + atomically $ modifyTVar' reportedIntervals (delay :) + when (length ints < 12) $ loop + ) + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 0, 1, 1, 1, 2, 3, 1] + (reverse <$> readTVarIO reportedIntervals) + `shouldReturn` [ 10000, 10000, 15000, 22500, 33750, 10000, 10000, 15000, 22500, 33750, 40000, 10000] + addInterval :: TVar [Int] -> TVar UTCTime -> IO [Int] addInterval intervals ts = do ts' <- getCurrentTime diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 9f7c4932e..24d54fc8e 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} module CoreTests.TRcvQueuesTests where @@ -30,19 +31,19 @@ tRcvQueuesTests = do describe "queue transfer" $ do it "getDelSessQueues-batchAddQueues preserves total length" removeSubsTest -checkDataInvariant :: RQ.TRcvQueues -> IO Bool +checkDataInvariant :: RQ.Queue q => RQ.TRcvQueues q -> IO Bool checkDataInvariant trq = atomically $ do conns <- readTVar $ RQ.getConnections trq qs <- readTVar $ RQ.getRcvQueues trq -- three invariant checks - let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> connId q == cId) qs))) (M.keys conns) - inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (connId q) conns)) (M.assocs qs) + let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> RQ.connId' q == cId) qs))) (M.keys conns) + inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (RQ.connId' q) conns)) (M.assocs qs) inv3 = all (\(k, q) -> RQ.qKey q == k) (M.assocs qs) pure $ inv1 && inv2 && inv3 hasConnTest :: IO () hasConnTest = do - trq <- atomically RQ.empty + trq <- RQ.empty atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq checkDataInvariant trq `shouldReturn` True atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq @@ -56,7 +57,7 @@ hasConnTest = do hasConnTestBatch :: IO () hasConnTestBatch = do - trq <- atomically RQ.empty + trq <- RQ.empty let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1", dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@beta" "c3"] atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True @@ -67,7 +68,7 @@ hasConnTestBatch = do batchIdempotentTest :: IO () batchIdempotentTest = do - trq <- atomically RQ.empty + trq <- RQ.empty let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1", dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@beta" "c3"] atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True @@ -76,11 +77,11 @@ batchIdempotentTest = do atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True readTVarIO (RQ.getRcvQueues trq) `shouldReturn` qs' - fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn`cs' -- connections get duplicated, but that doesn't appear to affect anybody + fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn` cs' -- connections get duplicated, but that doesn't appear to affect anybody deleteConnTest :: IO () deleteConnTest = do - trq <- atomically RQ.empty + trq <- RQ.empty atomically $ do RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq @@ -94,7 +95,7 @@ deleteConnTest = do getSessQueuesTest :: IO () getSessQueuesTest = do - trq <- atomically RQ.empty + trq <- RQ.empty atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq checkDataInvariant trq `shouldReturn` True atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq @@ -103,32 +104,40 @@ getSessQueuesTest = do checkDataInvariant trq `shouldReturn` True atomically $ RQ.addQueue (dummyRQ 1 "smp://1234-w==@beta" "c4") trq checkDataInvariant trq `shouldReturn` True - atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c1"] - atomically (RQ.getSessQueues (1, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` [] - atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "nope") trq) `shouldReturn` [] - atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"] + let tSess1 = (0, "smp://1234-w==@alpha", Just "c1") + RQ.getSessQueues tSess1 trq `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c1"] + atomically (RQ.hasSessQueues tSess1 trq) `shouldReturn` True + let tSess2 = (1, "smp://1234-w==@alpha", Just "c1") + RQ.getSessQueues tSess2 trq `shouldReturn` [] + atomically (RQ.hasSessQueues tSess2 trq) `shouldReturn` False + let tSess3 = (0, "smp://1234-w==@alpha", Just "nope") + RQ.getSessQueues tSess3 trq `shouldReturn` [] + atomically (RQ.hasSessQueues tSess3 trq) `shouldReturn` False + let tSess4 = (0, "smp://1234-w==@alpha", Nothing) + RQ.getSessQueues tSess4 trq `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"] + atomically (RQ.hasSessQueues tSess4 trq) `shouldReturn`True getDelSessQueuesTest :: IO () getDelSessQueuesTest = do - trq <- atomically RQ.empty + trq <- RQ.empty let qs = - [ dummyRQ 0 "smp://1234-w==@alpha" "c1", - dummyRQ 0 "smp://1234-w==@alpha" "c2", - dummyRQ 0 "smp://1234-w==@beta" "c3", - dummyRQ 1 "smp://1234-w==@beta" "c4" + [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1"), + ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2"), + ("1", dummyRQ 0 "smp://1234-w==@beta" "c3"), + ("1", dummyRQ 1 "smp://1234-w==@beta" "c4") ] atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True -- no user - atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) checkDataInvariant trq `shouldReturn` True -- wrong user - atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) checkDataInvariant trq `shouldReturn` True -- connections intact atomically (RQ.hasConn "c1" trq) `shouldReturn` True atomically (RQ.hasConn "c2" trq) `shouldReturn` True - atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"]) + atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"]) checkDataInvariant trq `shouldReturn` True -- connections gone atomically (RQ.hasConn "c1" trq) `shouldReturn` False @@ -139,31 +148,31 @@ getDelSessQueuesTest = do removeSubsTest :: IO () removeSubsTest = do - aq <- atomically RQ.empty + aq <- RQ.empty let qs = - [ dummyRQ 0 "smp://1234-w==@alpha" "c1", - dummyRQ 0 "smp://1234-w==@alpha" "c2", - dummyRQ 0 "smp://1234-w==@beta" "c3", - dummyRQ 1 "smp://1234-w==@beta" "c4" + [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1"), + ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2"), + ("1", dummyRQ 0 "smp://1234-w==@beta" "c3"), + ("1", dummyRQ 1 "smp://1234-w==@beta" "c4") ] atomically $ RQ.batchAddQueues aq qs - pq <- atomically RQ.empty + pq <- RQ.empty atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) -totalSize :: RQ.TRcvQueues -> RQ.TRcvQueues -> STM (Int, Int) +totalSize :: RQ.TRcvQueues q -> RQ.TRcvQueues q -> STM (Int, Int) totalSize a b = do qsizeA <- M.size <$> readTVar (RQ.getRcvQueues a) qsizeB <- M.size <$> readTVar (RQ.getRcvQueues b) diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 7cb2a88c5..0bb050cbe 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -11,6 +11,7 @@ module SMPAgentClient where import Data.List.NonEmpty (NonEmpty) +import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import NtfClient (ntfTestPort) @@ -20,7 +21,7 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Client (ProtocolClientConfig (..), SMPProxyFallback, SMPProxyMode, defaultNetworkConfig, defaultSMPClientConfig) import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig) -import Simplex.Messaging.Protocol (NtfServer, ProtoServerWithAuth) +import Simplex.Messaging.Protocol (NtfServer, ProtoServerWithAuth (..), ProtocolServer) import Simplex.Messaging.Transport import XFTPClient (testXFTPServer) @@ -48,14 +49,14 @@ testNtfServer2 = "ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6 initAgentServers :: InitialAgentServers initAgentServers = InitialAgentServers - { smp = userServers [noAuthSrv testSMPServer], + { smp = userServers [testSMPServer], ntf = [testNtfServer], - xftp = userServers [noAuthSrv testXFTPServer], + xftp = userServers [testXFTPServer], netCfg = defaultNetworkConfig {tcpTimeout = 500_000, tcpConnectTimeout = 500_000} } initAgentServers2 :: InitialAgentServers -initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]} +initAgentServers2 = initAgentServers {smp = userServers [testSMPServer, testSMPServer2]} initAgentServersProxy :: SMPProxyMode -> SMPProxyFallback -> InitialAgentServers initAgentServersProxy smpProxyMode smpProxyFallback = @@ -71,8 +72,6 @@ agentCfg = ntfCfg = defaultNTFClientConfig {qSize = 1, defaultTransport = (ntfTestPort, transport @TLS), networkConfig}, reconnectInterval = fastRetryInterval, persistErrorInterval = 1, - ntfWorkerDelay = 100, - ntfSMPWorkerDelay = 100, caCertificateFile = "tests/fixtures/ca.crt", privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt" @@ -89,5 +88,11 @@ fastRetryInterval = defaultReconnectInterval {initialInterval = 50_000} fastMessageRetryInterval :: RetryInterval2 fastMessageRetryInterval = RetryInterval2 {riFast = fastRetryInterval, riSlow = fastRetryInterval} -userServers :: NonEmpty (ProtoServerWithAuth p) -> Map UserId (NonEmpty (ProtoServerWithAuth p)) -userServers srvs = M.fromList [(1, srvs)] +userServers :: NonEmpty (ProtocolServer p) -> Map UserId (NonEmpty (ServerCfg p)) +userServers = userServers' . L.map noAuthSrv + +userServers' :: NonEmpty (ProtoServerWithAuth p) -> Map UserId (NonEmpty (ServerCfg p)) +userServers' srvs = M.fromList [(1, L.map (presetServerCfg True) srvs)] + +noAuthSrvCfg :: ProtocolServer p -> ServerCfg p +noAuthSrvCfg = presetServerCfg True . noAuthSrv diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 144ad8b10..736016b3b 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -100,7 +100,6 @@ cfg = { transports = [], smpHandshakeTimeout = 60000000, tbqSize = 1, - -- serverTbqSize = 1, msgQueueQuota = 4, queueIdBytes = 24, msgIdBytes = 24, diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index e05ff884d..8044d23f7 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -34,7 +34,8 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (pattern PQSupportOn) import qualified Simplex.Messaging.Crypto.Ratchet as CR -import Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Protocol (EncRcvMsgBody (..), MsgBody, RcvMessage (..), SubscriptionMode (..), maxMessageLength, noMsgFlags) +import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Transport import Simplex.Messaging.Util (bshow, tshow) @@ -101,27 +102,29 @@ smpProxyTests = do it "500x20" . twoServersFirstProxy $ 500 `inParrallel` deliver 20 describe "agent API" $ do describe "one server" $ do - xit "always via proxy" . oneServer $ + it "always via proxy" . oneServer $ agentDeliverMessageViaProxy ([srv1], SPMAlways, True) ([srv1], SPMAlways, True) C.SEd448 "hello 1" "hello 2" 1 - xit "without proxy" . oneServer $ + it "without proxy" . oneServer $ agentDeliverMessageViaProxy ([srv1], SPMNever, False) ([srv1], SPMNever, False) C.SEd448 "hello 1" "hello 2" 1 describe "two servers" $ do - xit "always via proxy" . twoServers $ + it "always via proxy" . twoServers $ agentDeliverMessageViaProxy ([srv1], SPMAlways, True) ([srv2], SPMAlways, True) C.SEd448 "hello 1" "hello 2" 1 - xit "both via proxy" . twoServers $ + it "both via proxy" . twoServers $ agentDeliverMessageViaProxy ([srv1], SPMUnknown, True) ([srv2], SPMUnknown, True) C.SEd448 "hello 1" "hello 2" 1 - xit "first via proxy" . twoServers $ + it "first via proxy" . twoServers $ agentDeliverMessageViaProxy ([srv1], SPMUnknown, True) ([srv2], SPMNever, False) C.SEd448 "hello 1" "hello 2" 1 - xit "without proxy" . twoServers $ + it "without proxy" . twoServers $ agentDeliverMessageViaProxy ([srv1], SPMNever, False) ([srv2], SPMNever, False) C.SEd448 "hello 1" "hello 2" 1 - xit "first via proxy for unknown" . twoServers $ + it "first via proxy for unknown" . twoServers $ agentDeliverMessageViaProxy ([srv1], SPMUnknown, True) ([srv1, srv2], SPMUnknown, False) C.SEd448 "hello 1" "hello 2" 1 it "without proxy with fallback" . twoServers_ proxyCfg cfgV7 $ agentDeliverMessageViaProxy ([srv1], SPMUnknown, False) ([srv2], SPMUnknown, False) C.SEd448 "hello 1" "hello 2" 3 it "fails when fallback is prohibited" . twoServers_ proxyCfg cfgV7 $ agentViaProxyVersionError - xit "retries sending when destination or proxy relay is offline" $ + it "retries sending when destination or proxy relay is offline" $ agentViaProxyRetryOffline + it "retries sending when destination relay session disconnects in proxy" $ + agentViaProxyRetryNoSession describe "stress test 1k" $ do let deliver nAgents nMsgs = agentDeliverMessagesViaProxyConc (replicate nAgents [srv1]) (map bshow [1 :: Int .. nMsgs]) it "2 agents, 250 messages" . oneServer $ deliver 2 250 @@ -157,7 +160,7 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do -- prepare receiving queue (rPub, rPriv) <- atomically $ C.generateAuthKeyPair alg g (rdhPub, rdhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g - QIK {rcvId, sndId, rcvPublicDhKey = srvDh} <- runExceptT' $ createSMPQueue rc (rPub, rPriv) rdhPub (Just "correct") SMSubscribe False + SMP.QIK {rcvId, sndId, rcvPublicDhKey = srvDh} <- runExceptT' $ createSMPQueue rc (rPub, rPriv) rdhPub (Just "correct") SMSubscribe False let dec = decryptMsgV3 $ C.dh' srvDh rdhPriv -- get proxy session sess0 <- runExceptT' $ connectSMPProxiedRelay pc relayServ (Just "correct") @@ -204,7 +207,8 @@ agentDeliverMessageViaProxy aTestCfg@(aSrvs, _, aViaProxy) bTestCfg@(bSrvs, _, b withAgent 1 aCfg (servers aTestCfg) testDB $ \alice -> withAgent 2 aCfg (servers bTestCfg) testDB2 $ \bob -> runRight_ $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe - aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (aliceId, sqSecured) <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn allowConnection alice bobId confId "alice's connInfo" @@ -234,7 +238,7 @@ agentDeliverMessageViaProxy aTestCfg@(aSrvs, _, aViaProxy) bTestCfg@(bSrvs, _, b where msgId = subtract baseId . fst aCfg = agentCfg {sndAuthAlg = C.AuthAlg alg, rcvAuthAlg = C.AuthAlg alg} - servers (srvs, smpProxyMode, _) = (initAgentServersProxy smpProxyMode SPFAllow) {smp = userServers $ L.map noAuthSrv srvs} + servers (srvs, smpProxyMode, _) = (initAgentServersProxy smpProxyMode SPFAllow) {smp = userServers srvs} agentDeliverMessagesViaProxyConc :: [NonEmpty SMPServer] -> [MsgBody] -> IO () agentDeliverMessagesViaProxyConc agentServers msgs = @@ -258,7 +262,8 @@ agentDeliverMessagesViaProxyConc agentServers msgs = -- otherwise the CONF messages would get mixed with MSG prePair alice bob = do (bobId, qInfo) <- runExceptT' $ A.createConnection alice 1 True SCMInvitation Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe - aliceId <- runExceptT' $ A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (aliceId, sqSecured) <- runExceptT' $ A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ sqSecured `shouldBe` True confId <- get alice >>= \case ("", _, A.CONF confId pqSup' _ "bob's connInfo") -> do @@ -299,7 +304,7 @@ agentDeliverMessagesViaProxyConc agentServers msgs = logDebug "run finished" pqEnc = CR.PQEncOn aCfg = agentCfg {sndAuthAlg = C.AuthAlg C.SEd448, rcvAuthAlg = C.AuthAlg C.SEd448} - servers srvs = (initAgentServersProxy SPMAlways SPFAllow) {smp = userServers $ L.map noAuthSrv srvs} + servers srvs = (initAgentServersProxy SPMAlways SPFAllow) {smp = userServers srvs} agentViaProxyVersionError :: IO () agentViaProxyVersionError = @@ -310,7 +315,7 @@ agentViaProxyVersionError = A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe pure () where - servers srvs = (initAgentServersProxy SPMUnknown SPFProhibit) {smp = userServers $ L.map noAuthSrv srvs} + servers srvs = (initAgentServersProxy SPMUnknown SPFProhibit) {smp = userServers srvs} agentViaProxyRetryOffline :: IO () agentViaProxyRetryOffline = do @@ -326,7 +331,8 @@ agentViaProxyRetryOffline = do withServer $ \_ -> do (aliceId, bobId) <- withServer2 $ \_ -> runRight $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe - aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (aliceId, sqSecured) <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn allowConnection alice bobId confId "alice's connInfo" @@ -355,11 +361,15 @@ agentViaProxyRetryOffline = do -- proxy relay down 4 <- msgId <$> A.sendMessage bob aliceId pqEnc noMsgFlags msg2 bob `down` aliceId - withServer2 $ \_ -> runRight_ $ do - bob `up` aliceId - get bob ##> ("", aliceId, A.SENT (baseId + 4) bProxySrv) - get alice =##> \case ("", c, Msg' _ pq msg2') -> c == bobId && pq == pqEnc && msg2 == msg2'; _ -> False - ackMessage alice bobId (baseId + 4) Nothing + withServer2 $ \_ -> do + getInAnyOrder + bob + [ \case ("", "", AEvt SAENone (UP _ [c])) -> c == aliceId; _ -> False, + \case ("", c, AEvt SAEConn (A.SENT mId srv)) -> c == aliceId && mId == baseId + 4 && srv == bProxySrv; _ -> False + ] + runRight_ $ do + get alice =##> \case ("", c, Msg' _ pq msg2') -> c == bobId && pq == pqEnc && msg2 == msg2'; _ -> False + ackMessage alice bobId (baseId + 4) Nothing where withServer :: (ThreadId -> IO a) -> IO a withServer = withServer_ testStoreLogFile testStoreMsgsFile testPort @@ -370,22 +380,42 @@ agentViaProxyRetryOffline = do a `up` cId = nGet a =##> \case ("", "", UP _ [c]) -> c == cId; _ -> False a `down` cId = nGet a =##> \case ("", "", DOWN _ [c]) -> c == cId; _ -> False aCfg = agentCfg {messageRetryInterval = fastMessageRetryInterval} - baseId = 3 + baseId = 1 msgId = subtract baseId . fst - servers srv = (initAgentServersProxy SPMAlways SPFProhibit) {smp = userServers $ L.map noAuthSrv [srv]} + servers srv = (initAgentServersProxy SPMAlways SPFProhibit) {smp = userServers [srv]} + +agentViaProxyRetryNoSession :: IO () +agentViaProxyRetryNoSession = do + let srv1 = SMPServer testHost testPort testKeyHash + srv2 = SMPServer testHost testPort2 testKeyHash + withAgent 1 agentCfg (servers srv1) testDB $ \a -> + withAgent 2 agentCfg (servers srv2) testDB2 $ \b -> do + withSmpServerConfigOn (transport @TLS) proxyCfg testPort $ \_ -> do + (aId, _) <- withServer2 $ \_ -> runRight $ makeConnection a b + nGet b =##> \case ("", "", DOWN _ [c]) -> c == aId; _ -> False + withServer2 $ \_ -> do + nGet b =##> \case ("", "", UP _ [c]) -> c == aId; _ -> False + -- to test retry in case of NO_SESSION error, + -- the client using server 1 as proxy and server 2 as destination + -- should be joining the connection, so the order is swapped here. + _ <- runRight $ makeConnection b a + pure () + where + withServer2 = withSmpServerConfigOn (transport @TLS) proxyCfg {storeLogFile = Just testStoreLogFile2, storeMsgsFile = Just testStoreMsgsFile2} testPort2 + servers srv = (initAgentServersProxy SPMAlways SPFProhibit) {smp = userServers [srv]} testNoProxy :: IO () testNoProxy = do withSmpServerConfigOn (transport @TLS) cfg testPort2 $ \_ -> do testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do - (_, _, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", PRXY testSMPServer Nothing) + (_, _, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", SMP.PRXY testSMPServer Nothing) reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) testProxyAuth :: IO () testProxyAuth = do withSmpServerConfigOn (transport @TLS) proxyCfgAuth testPort $ \_ -> do testSMPClient_ "127.0.0.1" testPort proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do - (_, _s, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", PRXY testSMPServer2 $ Just "wrong") + (_, _s, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", SMP.PRXY testSMPServer2 $ Just "wrong") reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) where proxyCfgAuth = proxyCfg {newQueueBasicAuth = Just "correct"} diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 1fa76dfaa..60aa1dd1c 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -509,19 +509,21 @@ testWithStoreLog at@(ATransport t) = writeTVar senderId1 sId1 writeTVar notifierId nId Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB) - signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") >>= \case - Resp "bcda" _ OK -> pure () - r -> unexpected r - Resp "" _ (Msg mId1 msg1) <- tGet1 h + (mId1, msg1) <- + signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") >>= \case + Resp "" _ (Msg mId1 msg1) -> pure (mId1, msg1) + r -> error $ "unexpected response " <> take 100 (show r) + Resp "bcda" _ OK <- tGet1 h (decryptMsgV3 dhShared mId1 msg1, Right "hello") #== "delivered from queue 1" Resp "" _ (NMSG _ _) <- tGet1 h1 (sId2, rId2, rKey2, dhShared2) <- createAndSecureQueue h sPub2 atomically $ writeTVar senderId2 sId2 - signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too") >>= \case - Resp "cdab" _ OK -> pure () - r -> unexpected r - Resp "" _ (Msg mId2 msg2) <- tGet1 h + (mId2, msg2) <- + signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too") >>= \case + Resp "" _ (Msg mId2 msg2) -> pure (mId2, msg2) + r -> error $ "unexpected response " <> take 100 (show r) + Resp "cdab" _ OK <- tGet1 h (decryptMsgV3 dhShared2 mId2 msg2, Right "hello too") #== "delivered from queue 2" Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, DEL) @@ -608,7 +610,7 @@ testRestoreMessages at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 2 logSize testStoreMsgsFile `shouldReturn` 5 - logSize testServerStatsBackupFile `shouldReturn` 55 + logSize testServerStatsBackupFile `shouldReturn` 71 Right stats1 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats1 [rId] 5 1 @@ -626,7 +628,7 @@ testRestoreMessages at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 -- the last message is not removed because it was not ACK'd logSize testStoreMsgsFile `shouldReturn` 3 - logSize testServerStatsBackupFile `shouldReturn` 55 + logSize testServerStatsBackupFile `shouldReturn` 71 Right stats2 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats2 [rId] 5 3 @@ -645,7 +647,7 @@ testRestoreMessages at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 logSize testStoreMsgsFile `shouldReturn` 0 - logSize testServerStatsBackupFile `shouldReturn` 55 + logSize testServerStatsBackupFile `shouldReturn` 71 Right stats3 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats3 [rId] 5 5 @@ -884,7 +886,7 @@ testMsgExpireOnInterval t = testSMPClient @c $ \sh -> do (sId, rId, rKey, _) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello (should expire)") - threadDelay 2500000 + threadDelay 3000000 testSMPClient @c $ \rh -> do signSendRecv rh rKey ("2", rId, SUB) >>= \case Resp "2" _ OK -> pure ()