Pass in session_limit_config directly to policy

Revert changes from
https://github.com/element-hq/matrix-authentication-service/pull/5221. I
assume it was done that way as the "session_limit_config" doesn't change
after the server is created. But this makes downstream usage complicated as
you whenever you create `SiteConfig`, you also have to make sure to configure
whatever else is necessary.

Easier to just pass in `session_limit_config` as necessary whenever
we evaluate the policy
This commit is contained in:
Eric Eastwood
2026-04-06 18:01:39 -05:00
parent c77afd5243
commit 724e0cf5ca
15 changed files with 76 additions and 66 deletions
+2 -7
View File
@@ -9,8 +9,7 @@ use std::process::ExitCode;
use clap::Parser;
use figment::Figment;
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, ExperimentalConfig,
MatrixConfig, PolicyConfig,
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
};
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};
@@ -46,12 +45,8 @@ impl Options {
PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let matrix_config =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let experimental_config =
ExperimentalConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!("Loading and compiling the policy module");
let policy_factory =
policy_factory_from_config(&config, &matrix_config, &experimental_config)
.await?;
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
if with_dynamic_data {
let database_config =
+1 -3
View File
@@ -127,9 +127,7 @@ impl Options {
// Load and compile the WASM policies (and fallback to the default embedded one)
info!("Loading and compiling the policy module");
let policy_factory =
policy_factory_from_config(&config.policy, &config.matrix, &config.experimental)
.await?;
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
let policy_factory = Arc::new(policy_factory);
load_policy_factory_dynamic_data_continuously(
+2 -13
View File
@@ -135,7 +135,6 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) {
pub async fn policy_factory_from_config(
config: &PolicyConfig,
matrix_config: &MatrixConfig,
experimental_config: &ExperimentalConfig,
) -> Result<PolicyFactory, anyhow::Error> {
let policy_file = tokio::fs::File::open(&config.wasm_module)
.await
@@ -149,18 +148,8 @@ pub async fn policy_factory_from_config(
email: config.email_entrypoint.clone(),
};
let session_limit_config =
experimental_config
.session_limit
.as_ref()
.map(|c| SessionLimitConfig {
soft_limit: c.soft_limit,
hard_limit: c.hard_limit,
hard_limit_eviction: c.hard_limit_eviction,
});
let data = mas_policy::Data::new(matrix_config.homeserver.clone(), session_limit_config)
.with_rest(config.data.clone());
let data =
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());
PolicyFactory::load(policy_file, data, entrypoints)
.await
+2
View File
@@ -684,6 +684,7 @@ async fn token_login(
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &browser_session.user,
session_limit: session_limit_config,
login: CompatLogin::Token,
session_replaced,
session_counts,
@@ -811,6 +812,7 @@ async fn user_password_login(
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &user,
session_limit: session_limit_config,
login: CompatLogin::Password,
session_replaced,
session_counts,
@@ -19,7 +19,7 @@ use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
use mas_data_model::{BoxClock, BoxRng, Clock, MatrixUser};
use mas_data_model::{BoxClock, BoxRng, Clock, MatrixUser, SiteConfig};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, model::CompatLogin};
use mas_router::{CompatLoginSsoAction, UrlBuilder};
@@ -53,6 +53,7 @@ pub async fn get(
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
State(site_config): State<SiteConfig>,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
user_agent: Option<TypedHeader<headers::UserAgent>>,
@@ -114,9 +115,12 @@ pub async fn get(
// We can close the repository early, we don't need it at this point
repo.save().await?;
let session_limit_config = site_config.session_limit.as_ref();
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &session.user,
session_limit: session_limit_config,
login: CompatLogin::Sso {
redirect_uri: login.redirect_uri.to_string(),
},
@@ -193,6 +197,7 @@ pub async fn post(
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(site_config): State<SiteConfig>,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
user_agent: Option<TypedHeader<headers::UserAgent>>,
@@ -262,9 +267,12 @@ pub async fn post(
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
let session_limit_config = site_config.session_limit.as_ref();
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &session.user,
session_limit: session_limit_config,
login: CompatLogin::Sso {
redirect_uri: login.redirect_uri.to_string(),
},
@@ -17,7 +17,7 @@ use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng, MatrixUser};
use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng, MatrixUser, SiteConfig};
use mas_keystore::Keystore;
use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
@@ -91,6 +91,7 @@ pub(crate) async fn get(
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
State(site_config): State<SiteConfig>,
mut policy: Policy,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
@@ -145,9 +146,12 @@ pub(crate) async fn get(
// We can close the repository early, we don't need it at this point
repo.save().await?;
let session_limit_config = site_config.session_limit.as_ref();
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
session_limit: session_limit_config,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
@@ -220,6 +224,7 @@ pub(crate) async fn post(
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(key_store): State<Keystore>,
State(site_config): State<SiteConfig>,
mut policy: Policy,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
@@ -275,11 +280,13 @@ pub(crate) async fn post(
return Err(RouteError::GrantNotPending(grant.id));
}
let session_limit_config = site_config.session_limit.as_ref();
let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user).await?;
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&browser_session.user),
session_limit: session_limit_config,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
+7 -1
View File
@@ -18,7 +18,7 @@ use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
use mas_data_model::{BoxClock, BoxRng, MatrixUser};
use mas_data_model::{BoxClock, BoxRng, MatrixUser, SiteConfig};
use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
use mas_router::UrlBuilder;
@@ -53,6 +53,7 @@ pub(crate) async fn get(
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
State(site_config): State<SiteConfig>,
mut repo: BoxRepository,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
@@ -107,6 +108,7 @@ pub(crate) async fn get(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;
let session_limit_config = site_config.session_limit.as_ref();
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
// We can close the repository early, we don't need it at this point
@@ -117,6 +119,7 @@ pub(crate) async fn get(
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_limit: session_limit_config,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
@@ -191,6 +194,7 @@ pub(crate) async fn post(
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
State(site_config): State<SiteConfig>,
mut repo: BoxRepository,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
@@ -246,6 +250,7 @@ pub(crate) async fn post(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;
let session_limit_config = site_config.session_limit.as_ref();
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
// Evaluate the policy
@@ -253,6 +258,7 @@ pub(crate) async fn post(
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_limit: session_limit_config,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
+3
View File
@@ -792,10 +792,13 @@ async fn client_credentials_grant(
.clone()
.unwrap_or_else(|| std::iter::empty::<ScopeToken>().collect());
let session_limit_config = site_config.session_limit.as_ref();
// Make the request go through the policy engine
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: None,
session_limit: session_limit_config,
client,
session_counts: None,
scope: &scope,
+1 -1
View File
@@ -86,7 +86,7 @@ pub(crate) async fn policy_factory(
email: "email/violation".to_owned(),
};
let data = mas_policy::Data::new(server_name.to_owned(), None).with_rest(data);
let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);
let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
let policy_factory = Arc::new(policy_factory);
+6 -12
View File
@@ -9,7 +9,7 @@ pub mod model;
use std::sync::Arc;
use arc_swap::ArcSwap;
use mas_data_model::{SessionLimitConfig, Ulid};
use mas_data_model::Ulid;
use opa_wasm::{
Runtime,
wasmtime::{Config, Engine, Module, OptLevel, Store},
@@ -100,19 +100,13 @@ pub struct Data {
#[derive(Serialize, Debug)]
struct BaseData {
server_name: String,
/// Limits on the number of application sessions that each user can have
session_limit: Option<SessionLimitConfig>,
}
impl Data {
#[must_use]
pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
pub fn new(server_name: String) -> Self {
Self {
base: BaseData {
server_name,
session_limit,
},
base: BaseData { server_name },
rest: None,
}
@@ -507,7 +501,7 @@ mod tests {
#[tokio::test]
async fn test_register() {
let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
"allowed_domains": ["element.io", "*.element.io"],
"banned_domains": ["staging.element.io"],
}));
@@ -572,7 +566,7 @@ mod tests {
#[tokio::test]
async fn test_dynamic_data() {
let data = Data::new("example.com".to_owned(), None);
let data = Data::new("example.com".to_owned());
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
@@ -636,7 +630,7 @@ mod tests {
#[tokio::test]
async fn test_big_dynamic_data() {
let data = Data::new("example.com".to_owned(), None);
let data = Data::new("example.com".to_owned());
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
+9 -1
View File
@@ -11,7 +11,7 @@
use std::net::IpAddr;
use mas_data_model::{Client, User};
use mas_data_model::{Client, SessionLimitConfig, User};
use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -179,6 +179,10 @@ pub struct AuthorizationGrantInput<'a> {
#[schemars(with = "Option<std::collections::HashMap<String, serde_json::Value>>")]
pub user: Option<&'a User>,
/// Limits on the number of application sessions that each user can have
#[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
pub session_limit: Option<&'a SessionLimitConfig>,
/// How many sessions the user has.
/// Not populated if it's not a user logging in.
pub session_counts: Option<SessionCounts>,
@@ -201,6 +205,10 @@ pub struct CompatLoginInput<'a> {
#[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
pub user: &'a User,
/// Limits on the number of application sessions that each user can have
#[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
pub session_limit: Option<&'a SessionLimitConfig>,
/// How many sessions the user has.
pub session_counts: SessionCounts,
@@ -159,7 +159,7 @@ violation contains {
"msg": "user has too many active sessions",
} if {
# Only apply if session limits are enabled in the config
data.session_limit != null
input.session_limit != null
# Only apply if it's a user logging in (who therefore has countable sessions)
input.session_counts != null
@@ -168,5 +168,5 @@ violation contains {
# reached or exceeded.
# We use the soft limit because the user will be able to interactively remove
# sessions to return under the limit.
data.session_limit.soft_limit <= input.session_counts.total
input.session_limit.soft_limit <= input.session_counts.total
}
@@ -226,31 +226,31 @@ test_mas_scopes if {
test_session_limiting if {
authorization_grant.allow with input.user as user
with input.session_counts as {"total": 1}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
authorization_grant.allow with input.user as user
with input.session_counts as {"total": 31}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not authorization_grant.allow with input.user as user
with input.session_counts as {"total": 32}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not authorization_grant.allow with input.user as user
with input.session_counts as {"total": 42}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not authorization_grant.allow with input.user as user
with input.session_counts as {"total": 65}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
# No limit configured
authorization_grant.allow with input.user as user
with input.session_counts as {"total": 1}
with data.session_limit as null
with input.session_limit as null
# Client credentials grant
authorization_grant.allow with input.user as user
with input.session_counts as null
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
}
+4 -4
View File
@@ -30,7 +30,7 @@ violation contains {
"msg": "user has too many active sessions (soft limit)",
} if {
# Only apply if session limits are enabled in the config
data.session_limit != null
input.session_limit != null
# This is a web-based interactive login
is_interactive
@@ -43,7 +43,7 @@ violation contains {
# reached or exceeded.
# We use the soft limit because the user will be able to interactively remove
# sessions to return under the limit.
data.session_limit.soft_limit <= input.session_counts.total
input.session_limit.soft_limit <= input.session_counts.total
}
violation contains {
@@ -51,7 +51,7 @@ violation contains {
"msg": "user has too many active sessions (hard limit)",
} if {
# Only apply if session limits are enabled in the config
data.session_limit != null
input.session_limit != null
# This is not a web-based interactive login
not is_interactive
@@ -64,7 +64,7 @@ violation contains {
# reached or exceeded.
# We don't use the soft limit because the user won't be able to interactively remove
# sessions to return under the limit.
data.session_limit.hard_limit <= input.session_counts.total
input.session_limit.hard_limit <= input.session_counts.total
}
is_interactive if {
+13 -13
View File
@@ -16,38 +16,38 @@ test_session_limiting_sso if {
with input.session_counts as {"total": 1}
with input.login as {"type": "m.login.sso"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
compat_login.allow with input.user as user
with input.session_counts as {"total": 31}
with input.login as {"type": "m.login.sso"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not compat_login.allow with input.user as user
with input.session_counts as {"total": 32}
with input.login as {"type": "m.login.sso"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not compat_login.allow with input.user as user
with input.session_counts as {"total": 42}
with input.login as {"type": "m.login.sso"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not compat_login.allow with input.user as user
with input.session_counts as {"total": 65}
with input.login as {"type": "m.login.sso"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
# No limit configured
compat_login.allow with input.user as user
with input.session_counts as {"total": 1}
with input.login as {"type": "m.login.sso"}
with input.session_replaced as false
with data.session_limit as null
with input.session_limit as null
}
# Test session limiting when using `m.login.password`
@@ -56,32 +56,32 @@ test_session_limiting_password if {
with input.session_counts as {"total": 1}
with input.login as {"type": "m.login.password"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
compat_login.allow with input.user as user
with input.session_counts as {"total": 63}
with input.login as {"type": "m.login.password"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not compat_login.allow with input.user as user
with input.session_counts as {"total": 64}
with input.login as {"type": "m.login.password"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not compat_login.allow with input.user as user
with input.session_counts as {"total": 65}
with input.login as {"type": "m.login.password"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
# No limit configured
compat_login.allow with input.user as user
with input.session_counts as {"total": 1}
with input.login as {"type": "m.login.password"}
with input.session_replaced as false
with data.session_limit as null
with input.session_limit as null
}
test_no_session_limiting_upon_replacement if {
@@ -89,11 +89,11 @@ test_no_session_limiting_upon_replacement if {
with input.session_counts as {"total": 65}
with input.login as {"type": "m.login.password"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
not compat_login.allow with input.user as user
with input.session_counts as {"total": 65}
with input.login as {"type": "m.login.sso"}
with input.session_replaced as false
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
with input.session_limit as {"soft_limit": 32, "hard_limit": 64}
}