Add experimental and preliminary policy-driven session limiting when logging in OAuth 2 sessions. (#5221)

This commit is contained in:
Olivier 'reivilibre
2025-11-25 15:24:02 +00:00
committed by GitHub
18 changed files with 309 additions and 23 deletions

View File

@@ -9,7 +9,8 @@ use std::process::ExitCode;
use clap::Parser;
use figment::Figment;
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, ExperimentalConfig,
MatrixConfig, PolicyConfig,
};
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};
@@ -45,8 +46,12 @@ impl Options {
PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let matrix_config =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let experimental_config =
ExperimentalConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
let policy_factory =
policy_factory_from_config(&config, &matrix_config, &experimental_config)
.await?;
if with_dynamic_data {
let database_config =

View File

@@ -132,7 +132,9 @@ impl Options {
// Load and compile the WASM policies (and fallback to the default embedded one)
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
let policy_factory =
policy_factory_from_config(&config.policy, &config.matrix, &config.experimental)
.await?;
let policy_factory = Arc::new(policy_factory);
load_policy_factory_dynamic_data_continuously(

View File

@@ -13,7 +13,7 @@ use mas_config::{
PolicyConfig, TemplatesConfig,
};
use mas_context::LogContext;
use mas_data_model::{SessionExpirationConfig, SiteConfig};
use mas_data_model::{SessionExpirationConfig, SessionLimitConfig, SiteConfig};
use mas_email::{MailTransport, Mailer};
use mas_handlers::passwords::PasswordManager;
use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
@@ -135,6 +135,7 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) {
pub async fn policy_factory_from_config(
config: &PolicyConfig,
matrix_config: &MatrixConfig,
experimental_config: &ExperimentalConfig,
) -> Result<PolicyFactory, anyhow::Error> {
let policy_file = tokio::fs::File::open(&config.wasm_module)
.await
@@ -147,8 +148,17 @@ pub async fn policy_factory_from_config(
email: config.email_entrypoint.clone(),
};
let data =
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());
let session_limit_config =
experimental_config
.session_limit
.as_ref()
.map(|c| SessionLimitConfig {
soft_limit: c.soft_limit,
hard_limit: c.hard_limit,
});
let data = mas_policy::Data::new(matrix_config.homeserver.clone(), session_limit_config)
.with_rest(config.data.clone());
PolicyFactory::load(policy_file, data, entrypoints)
.await
@@ -225,6 +235,13 @@ pub fn site_config_from_config(
session_expiration,
login_with_email_allowed: account_config.login_with_email_allowed,
plan_management_iframe_uri: experimental_config.plan_management_iframe_uri.clone(),
session_limit: experimental_config
.session_limit
.as_ref()
.map(|c| SessionLimitConfig {
soft_limit: c.soft_limit,
hard_limit: c.hard_limit,
}),
})
}

View File

@@ -4,6 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::num::NonZeroU64;
use chrono::Duration;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -81,6 +83,13 @@ pub struct ExperimentalConfig {
/// validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub plan_management_iframe_uri: Option<String>,
/// Experimental feature to limit the number of application sessions per
/// user.
///
/// Disabled by default.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_limit: Option<SessionLimitConfig>,
}
impl Default for ExperimentalConfig {
@@ -90,6 +99,7 @@ impl Default for ExperimentalConfig {
compat_token_ttl: default_token_ttl(),
inactive_session_expiration: None,
plan_management_iframe_uri: None,
session_limit: None,
}
}
}
@@ -100,9 +110,17 @@ impl ExperimentalConfig {
&& is_default_token_ttl(&self.compat_token_ttl)
&& self.inactive_session_expiration.is_none()
&& self.plan_management_iframe_uri.is_none()
&& self.session_limit.is_none()
}
}
impl ConfigurationSection for ExperimentalConfig {
const PATH: Option<&'static str> = Some("experimental");
}
/// Configuration options for the session limit feature
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct SessionLimitConfig {
pub soft_limit: NonZeroU64,
pub hard_limit: NonZeroU64,
}

View File

@@ -39,7 +39,9 @@ pub use self::{
DeviceCodeGrantState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
},
policy_data::PolicyData,
site_config::{CaptchaConfig, CaptchaService, SessionExpirationConfig, SiteConfig},
site_config::{
CaptchaConfig, CaptchaService, SessionExpirationConfig, SessionLimitConfig, SiteConfig,
},
tokens::{
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
},

View File

@@ -4,7 +4,10 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::num::NonZeroU64;
use chrono::Duration;
use serde::Serialize;
use url::Url;
/// Which Captcha service is being used
@@ -36,6 +39,12 @@ pub struct SessionExpirationConfig {
pub compat_session_inactivity_ttl: Option<Duration>,
}
#[derive(Serialize, Debug, Clone)]
pub struct SessionLimitConfig {
pub soft_limit: NonZeroU64,
pub hard_limit: NonZeroU64,
}
/// Random site configuration we want accessible in various places.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
@@ -99,4 +108,7 @@ pub struct SiteConfig {
/// The iframe URL to show in the plan tab of the UI
pub plan_management_iframe_uri: Option<String>,
/// Limits on the number of application sessions that each user can have
pub session_limit: Option<SessionLimitConfig>,
}

View File

@@ -32,7 +32,7 @@ use super::callback::CallbackDestination;
use crate::{
BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
oauth2::generate_id_token,
session::{SessionOrFallback, load_session_or_fallback},
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};
#[derive(Debug, Error)]
@@ -136,10 +136,13 @@ pub(crate) async fn get(
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
@@ -235,10 +238,13 @@ pub(crate) async fn post(
return Err(RouteError::GrantNotPending(grant.id));
}
let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user).await?;
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&browser_session.user),
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {

View File

@@ -27,7 +27,7 @@ use ulid::Ulid;
use crate::{
BoundActivityTracker, PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};
#[derive(Deserialize, Debug)]
@@ -103,11 +103,14 @@ pub(crate) async fn get(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
// Evaluate the policy
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {
@@ -205,11 +208,14 @@ pub(crate) async fn post(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
// Evaluate the policy
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {

View File

@@ -781,6 +781,7 @@ async fn client_credentials_grant(
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: None,
client,
session_counts: None,
scope: &scope,
grant_type: mas_policy::GrantType::ClientCredentials,
requester: mas_policy::Requester {

View File

@@ -8,9 +8,13 @@
use axum::response::{Html, IntoResponse as _, Response};
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, csrf::CsrfExt};
use mas_data_model::{BrowserSession, Clock};
use mas_data_model::{BrowserSession, Clock, User};
use mas_i18n::DataLocale;
use mas_storage::{BoxRepository, RepositoryError};
use mas_policy::model::SessionCounts;
use mas_storage::{
BoxRepository, RepositoryError, compat::CompatSessionFilter, oauth2::OAuth2SessionFilter,
personal::PersonalSessionFilter,
};
use mas_templates::{AccountInactiveContext, TemplateContext, Templates};
use rand::RngCore;
use thiserror::Error;
@@ -102,3 +106,62 @@ pub async fn load_session_or_fallback(
maybe_session: Some(session),
})
}
/// Get a count of sessions for the given user, for the purposes of session
/// limiting.
///
/// Includes:
/// - OAuth 2 sessions
/// - Compatibility sessions
/// - Personal sessions (unless owned by a different user)
///
/// # Backstory
///
/// Originally, we were only intending to count sessions with devices in this
/// result, because those are the entries that are expensive for Synapse and
/// also would not hinder use of deviceless clients (like Element Admin, an
/// admin dashboard).
///
/// However, to do so, we would need to count only sessions including device
/// scopes. To do this efficiently, we'd need a partial index on sessions
/// including device scopes.
///
/// It turns out that this can't be done cleanly (as we need to, in Postgres,
/// match scope lists where one of the scopes matches one of 2 known prefixes),
/// at least not without somewhat uncomfortable stored functions.
///
/// So for simplicity's sake, we now count all sessions.
/// For practical use cases, it's not likely to make a noticeable difference
/// (and maybe it's good that there's an overall limit).
pub(crate) async fn count_user_sessions_for_limiting(
repo: &mut BoxRepository,
user: &User,
) -> Result<SessionCounts, RepositoryError> {
let oauth2 = repo
.oauth2_session()
.count(OAuth2SessionFilter::new().active_only().for_user(user))
.await? as u64;
let compat = repo
.compat_session()
.count(CompatSessionFilter::new().active_only().for_user(user))
.await? as u64;
// Only include self-owned personal sessions, not administratively-owned ones
let personal = repo
.personal_session()
.count(
PersonalSessionFilter::new()
.active_only()
.for_actor_user(user)
.for_owner_user(user),
)
.await? as u64;
Ok(SessionCounts {
total: oauth2 + compat + personal,
oauth2,
compat,
personal,
})
}

View File

@@ -85,7 +85,7 @@ pub(crate) async fn policy_factory(
email: "email/violation".to_owned(),
};
let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);
let data = mas_policy::Data::new(server_name.to_owned(), None).with_rest(data);
let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
let policy_factory = Arc::new(policy_factory);
@@ -148,6 +148,7 @@ pub fn test_site_config() -> SiteConfig {
session_expiration: None,
login_with_email_allowed: true,
plan_management_iframe_uri: None,
session_limit: None,
}
}

View File

@@ -9,11 +9,12 @@ pub mod model;
use std::sync::Arc;
use arc_swap::ArcSwap;
use mas_data_model::Ulid;
use mas_data_model::{SessionLimitConfig, Ulid};
use opa_wasm::{
Runtime,
wasmtime::{Config, Engine, Module, OptLevel, Store},
};
use serde::Serialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
@@ -87,16 +88,29 @@ impl Entrypoints {
#[derive(Debug)]
pub struct Data {
base: BaseData,
// We will merge this in a custom way, so don't emit as part of the base
rest: Option<serde_json::Value>,
}
#[derive(Serialize, Debug)]
struct BaseData {
server_name: String,
rest: Option<serde_json::Value>,
/// Limits on the number of application sessions that each user can have
session_limit: Option<SessionLimitConfig>,
}
impl Data {
#[must_use]
pub fn new(server_name: String) -> Self {
pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
Self {
server_name,
base: BaseData {
server_name,
session_limit,
},
rest: None,
}
}
@@ -108,9 +122,7 @@ impl Data {
}
fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
let base = serde_json::json!({
"server_name": self.server_name,
});
let base = serde_json::to_value(&self.base)?;
if let Some(rest) = &self.rest {
merge_data(base, rest.clone())
@@ -458,7 +470,7 @@ mod tests {
#[tokio::test]
async fn test_register() {
let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
"allowed_domains": ["element.io", "*.element.io"],
"banned_domains": ["staging.element.io"],
}));
@@ -528,7 +540,7 @@ mod tests {
#[tokio::test]
async fn test_dynamic_data() {
let data = Data::new("example.com".to_owned());
let data = Data::new("example.com".to_owned(), None);
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
@@ -597,7 +609,7 @@ mod tests {
#[tokio::test]
async fn test_big_dynamic_data() {
let data = Data::new("example.com".to_owned());
let data = Data::new("example.com".to_owned(), None);
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))

View File

@@ -49,6 +49,9 @@ pub enum Code {
/// The email address is banned.
EmailBanned,
/// The user has reached their session limit.
TooManySessions,
}
impl Code {
@@ -66,6 +69,7 @@ impl Code {
Self::EmailDomainBanned => "email-domain-banned",
Self::EmailNotAllowed => "email-not-allowed",
Self::EmailBanned => "email-banned",
Self::TooManySessions => "too-many-sessions",
}
}
}
@@ -168,6 +172,10 @@ pub struct AuthorizationGrantInput<'a> {
#[schemars(with = "Option<std::collections::HashMap<String, serde_json::Value>>")]
pub user: Option<&'a User>,
/// How many sessions the user has.
/// Not populated if it's not a user logging in.
pub session_counts: Option<SessionCounts>,
#[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
pub client: &'a Client,
@@ -179,6 +187,16 @@ pub struct AuthorizationGrantInput<'a> {
pub requester: Requester,
}
/// Information about how many sessions the user has
#[derive(Serialize, Debug, JsonSchema)]
pub struct SessionCounts {
pub total: u64,
pub oauth2: u64,
pub compat: u64,
pub personal: u64,
}
/// Input for the email add policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]

View File

@@ -2830,6 +2830,17 @@
"string",
"null"
]
},
"session_limit": {
"description": "Experimental feature to limit the number of application sessions per\n user.\n\n Disabled by default.",
"anyOf": [
{
"$ref": "#/definitions/SessionLimitConfig"
},
{
"type": "null"
}
]
}
}
},
@@ -2863,6 +2874,26 @@
"required": [
"ttl"
]
},
"SessionLimitConfig": {
"description": "Configuration options for the session limit feature",
"type": "object",
"properties": {
"soft_limit": {
"type": "integer",
"format": "uint64",
"minimum": 1
},
"hard_limit": {
"type": "integer",
"format": "uint64",
"minimum": 1
}
},
"required": [
"soft_limit",
"hard_limit"
]
}
}
}

View File

@@ -153,3 +153,20 @@ violation contains {"msg": sprintf(
)} if {
common.requester_banned(input.requester, data.requester)
}
violation contains {
"code": "too-many-sessions",
"msg": "user has too many active sessions",
} if {
# Only apply if session limits are enabled in the config
data.session_limit != null
# Only apply if it's a user logging in (who therefore has countable sessions)
input.session_counts != null
# For OAuth 2 login, a violation occurs when the soft limit has already been
# reached or exceeded.
# We use the soft limit because the user will be able to interactively remove
# sessions to return under the limit.
data.session_limit.soft_limit <= input.session_counts.total
}

View File

@@ -222,3 +222,35 @@ test_mas_scopes if {
with input.grant_type as "authorization_code"
with input.scope as "urn:mas:admin"
}
test_session_limiting if {
authorization_grant.allow with input.user as user
with input.session_counts as {"total": 1}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
authorization_grant.allow with input.user as user
with input.session_counts as {"total": 31}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
not authorization_grant.allow with input.user as user
with input.session_counts as {"total": 32}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
not authorization_grant.allow with input.user as user
with input.session_counts as {"total": 42}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
not authorization_grant.allow with input.user as user
with input.session_counts as {"total": 65}
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
# No limit configured
authorization_grant.allow with input.user as user
with input.session_counts as {"total": 1}
with data.session_limit as null
# Client credentials grant
authorization_grant.allow with input.user as user
with input.session_counts as null
with data.session_limit as {"soft_limit": 32, "hard_limit": 64}
}

View File

@@ -11,6 +11,17 @@
],
"additionalProperties": true
},
"session_counts": {
"description": "How many sessions the user has.\n Not populated if it's not a user logging in.",
"anyOf": [
{
"$ref": "#/definitions/SessionCounts"
},
{
"type": "null"
}
]
},
"client": {
"type": "object",
"additionalProperties": true
@@ -32,6 +43,38 @@
"requester"
],
"definitions": {
"SessionCounts": {
"description": "Information about how many sessions the user has",
"type": "object",
"properties": {
"total": {
"type": "integer",
"format": "uint64",
"minimum": 0
},
"oauth2": {
"type": "integer",
"format": "uint64",
"minimum": 0
},
"compat": {
"type": "integer",
"format": "uint64",
"minimum": 0
},
"personal": {
"type": "integer",
"format": "uint64",
"minimum": 0
}
},
"required": [
"total",
"oauth2",
"compat",
"personal"
]
},
"GrantType": {
"type": "string",
"enum": [

View File

@@ -499,7 +499,7 @@
"context": "pages/policy_violation.html:19:25-62",
"description": "Displayed when an authorization request is denied by the policy"
},
"heading": "The authorization request was denied the policy enforced by this service",
"heading": "The authorization request was denied by the policy enforced by this service",
"@heading": {
"context": "pages/policy_violation.html:18:27-60",
"description": "Displayed when an authorization request is denied by the policy"