diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 8768ae7b1..a82d8f059 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -1,4 +1,4 @@ -// Copyright 2024 New Vector Ltd. +// Copyright 2024, 2025 New Vector Ltd. // Copyright 2022-2024 The Matrix.org Foundation C.I.C. // // SPDX-License-Identifier: AGPL-3.0-only @@ -8,10 +8,14 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; -use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig}; +use mas_config::{ + ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig, +}; use tracing::{info, info_span}; -use crate::util::policy_factory_from_config; +use crate::util::{ + database_pool_from_config, load_policy_factory_dynamic_data, policy_factory_from_config, +}; #[derive(Parser, Debug)] pub(super) struct Options { @@ -22,7 +26,11 @@ pub(super) struct Options { #[derive(Parser, Debug)] enum Subcommand { /// Check that the policies compile - Policy, + Policy { + /// With dynamic data loaded + #[arg(long)] + with_dynamic_data: bool, + }, } impl Options { @@ -30,13 +38,19 @@ impl Options { pub async fn run(self, figment: &Figment) -> anyhow::Result { use Subcommand as SC; match self.subcommand { - SC::Policy => { + SC::Policy { with_dynamic_data } => { let _span = info_span!("cli.debug.policy").entered(); let config = PolicyConfig::extract_or_default(figment)?; let matrix_config = MatrixConfig::extract(figment)?; info!("Loading and compiling the policy module"); let policy_factory = policy_factory_from_config(&config, &matrix_config).await?; + if with_dynamic_data { + let database_config = DatabaseConfig::extract(figment)?; + let pool = database_pool_from_config(&database_config).await?; + load_policy_factory_dynamic_data(&policy_factory, &pool).await?; + } + let _instance = policy_factory.instantiate().await?; } } diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index d58fcb9da..811027594 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -26,9 +26,9 @@ use crate::{ app_state::AppState, lifecycle::LifecycleManager, util::{ - database_pool_from_config, mailer_from_config, password_manager_from_config, - policy_factory_from_config, site_config_from_config, templates_from_config, - test_mailer_in_background, + database_pool_from_config, load_policy_factory_dynamic_data_continuously, + mailer_from_config, password_manager_from_config, policy_factory_from_config, + site_config_from_config, templates_from_config, test_mailer_in_background, }, }; @@ -130,6 +130,14 @@ impl Options { let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?; let policy_factory = Arc::new(policy_factory); + load_policy_factory_dynamic_data_continuously( + &policy_factory, + &pool, + shutdown.soft_shutdown_token(), + shutdown.task_tracker(), + ) + .await?; + let url_builder = UrlBuilder::new( config.http.public_base.clone(), config.http.issuer.clone(), diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 0f155afc7..118cc4a1b 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -4,7 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use anyhow::Context; use mas_config::{ @@ -17,11 +17,14 @@ use mas_email::{MailTransport, Mailer}; use mas_handlers::passwords::PasswordManager; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; +use mas_storage::RepositoryAccess; +use mas_storage_pg::PgRepository; use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates}; use sqlx::{ ConnectOptions, PgConnection, PgPool, postgres::{PgConnectOptions, PgPoolOptions}, }; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tracing::{Instrument, log::LevelFilter}; pub async fn password_manager_from_config( @@ -346,6 +349,66 @@ pub async fn database_connection_from_config( .context("could not connect to the database") } +/// Update the policy factory dynamic data from the database and spawn a task to +/// periodically update it +// XXX: this could be put somewhere else? +pub async fn load_policy_factory_dynamic_data_continuously( + policy_factory: &Arc, + pool: &PgPool, + cancellation_token: CancellationToken, + task_tracker: &TaskTracker, +) -> Result<(), anyhow::Error> { + let policy_factory = policy_factory.clone(); + let pool = pool.clone(); + + load_policy_factory_dynamic_data(&policy_factory, &pool).await?; + + task_tracker.spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + + loop { + tokio::select! { + () = cancellation_token.cancelled() => { + return; + } + _ = interval.tick() => {} + } + + if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await { + tracing::error!( + error = ?err, + "Failed to load policy factory dynamic data" + ); + cancellation_token.cancel(); + return; + } + } + }); + + Ok(()) +} + +/// Update the policy factory dynamic data from the database +#[tracing::instrument(name = "policy.load_dynamic_data", skip_all, err(Debug))] +pub async fn load_policy_factory_dynamic_data( + policy_factory: &PolicyFactory, + pool: &PgPool, +) -> Result<(), anyhow::Error> { + let mut repo = PgRepository::from_pool(pool) + .await + .context("Failed to acquire database connection")?; + + if let Some(data) = repo.policy_data().get().await? { + let id = data.id; + let updated = policy_factory.set_dynamic_data(data).await?; + if updated { + tracing::info!(policy_data.id = %id, "Loaded dynamic policy data from the database"); + } + } + + Ok(()) +} + #[cfg(test)] mod tests { use rand::SeedableRng;