Support MSC4198 login_hint in account management uri (#5516)

This commit is contained in:
Quentin Gliech
2026-03-03 16:25:35 +01:00
committed by GitHub
10 changed files with 177 additions and 170 deletions

3
Cargo.lock generated
View File

@@ -3268,14 +3268,12 @@ dependencies = [
"base64ct",
"chrono",
"crc",
"lettre",
"mas-iana",
"mas-jose",
"oauth2-types",
"rand 0.8.5",
"rand_chacha 0.3.1",
"regex",
"ruma-common",
"serde",
"serde_json",
"thiserror 2.0.17",
@@ -3352,6 +3350,7 @@ dependencies = [
"rand 0.8.5",
"rand_chacha 0.3.1",
"reqwest",
"ruma-common",
"rustls",
"schemars 0.9.0",
"sentry",

View File

@@ -29,8 +29,6 @@ rand.workspace = true
rand_chacha.workspace = true
regex.workspace = true
woothee.workspace = true
ruma-common.workspace = true
lettre.workspace = true
mas-iana.workspace = true
mas-jose.workspace = true

View File

@@ -4,8 +4,6 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::str::FromStr as _;
use chrono::{DateTime, Utc};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{
@@ -17,7 +15,6 @@ use rand::{
RngCore,
distributions::{Alphanumeric, DistString},
};
use ruma_common::UserId;
use serde::Serialize;
use ulid::Ulid;
use url::Url;
@@ -142,12 +139,6 @@ impl AuthorizationGrantStage {
}
}
pub enum LoginHint<'a> {
MXID(&'a UserId),
Email(lettre::Address),
None,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationGrant {
pub id: Ulid,
@@ -175,31 +166,6 @@ impl std::ops::Deref for AuthorizationGrant {
}
impl AuthorizationGrant {
/// Parse a `login_hint`
///
/// Returns `LoginHint::MXID` for valid mxid 'mxid:@john.doe:example.com'
///
/// Returns `LoginHint::Email` for valid email 'john.doe@example.com'
///
/// Otherwise returns `LoginHint::None`
#[must_use]
pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> {
let Some(login_hint) = &self.login_hint else {
return LoginHint::None;
};
if let Some(value) = login_hint.strip_prefix("mxid:")
&& let Ok(mxid) = <&UserId>::try_from(value)
&& mxid.server_name() == homeserver
{
LoginHint::MXID(mxid)
} else if let Ok(email) = lettre::Address::from_str(login_hint) {
LoginHint::Email(email)
} else {
LoginHint::None
}
}
/// Mark the authorization grant as exchanged.
///
/// # Errors
@@ -266,101 +232,3 @@ impl AuthorizationGrant {
}
}
}
#[cfg(test)]
mod tests {
use rand::SeedableRng;
use super::*;
use crate::clock::{Clock, MockClock};
#[test]
fn no_login_hint() {
let now = MockClock::default().now();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let grant = AuthorizationGrant {
login_hint: None,
..AuthorizationGrant::sample(now, &mut rng)
};
let hint = grant.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
#[test]
fn valid_login_hint() {
let now = MockClock::default().now();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let grant = AuthorizationGrant {
login_hint: Some(String::from("mxid:@example-user:example.com")),
..AuthorizationGrant::sample(now, &mut rng)
};
let hint = grant.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
}
#[test]
fn valid_login_hint_with_email() {
let now = MockClock::default().now();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let grant = AuthorizationGrant {
login_hint: Some(String::from("example@user")),
..AuthorizationGrant::sample(now, &mut rng)
};
let hint = grant.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user"));
}
#[test]
fn invalid_login_hint() {
let now = MockClock::default().now();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let grant = AuthorizationGrant {
login_hint: Some(String::from("example-user")),
..AuthorizationGrant::sample(now, &mut rng)
};
let hint = grant.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
#[test]
fn valid_login_hint_for_wrong_homeserver() {
let now = MockClock::default().now();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let grant = AuthorizationGrant {
login_hint: Some(String::from("mxid:@example-user:matrix.org")),
..AuthorizationGrant::sample(now, &mut rng)
};
let hint = grant.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
#[test]
fn unknown_login_hint_type() {
let now = MockClock::default().now();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let grant = AuthorizationGrant {
login_hint: Some(String::from("something:anything")),
..AuthorizationGrant::sample(now, &mut rng)
};
let hint = grant.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
}

View File

@@ -10,9 +10,7 @@ mod device_code_grant;
mod session;
pub use self::{
authorization_grant::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, LoginHint, Pkce,
},
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
client::{Client, InvalidRedirectUriError, JwksOrJwksUri},
device_code_grant::{DeviceCodeGrant, DeviceCodeGrantState},
session::{Session, SessionState},

View File

@@ -49,6 +49,7 @@ psl.workspace = true
rand_chacha.workspace = true
rand.workspace = true
reqwest.workspace = true
ruma-common.workspace = true
rustls.workspace = true
schemars.workspace = true
sentry.workspace = true

View File

@@ -278,9 +278,15 @@ pub(crate) async fn get(
// Other cases where we don't have a session, ask for a login
repo.save().await?;
url_builder
.redirect(&mas_router::Login::and_then(continue_grant))
.into_response()
let mut url = mas_router::Login::and_then(continue_grant);
url = if let Some(login_hint) = grant.login_hint {
url.with_login_hint(login_hint)
} else {
url
};
url_builder.redirect(&url).into_response()
}
Some(user_session) => {

View File

@@ -25,6 +25,9 @@ use crate::{
pub struct Params {
#[serde(default, flatten)]
action: Option<mas_router::AccountAction>,
#[serde(rename = "org.matrix.msc4198.login_hint")]
unstable_login_hint: Option<String>,
}
#[tracing::instrument(name = "handlers.views.app.get", skip_all)]
@@ -33,7 +36,10 @@ pub async fn get(
State(templates): State<Templates>,
activity_tracker: BoundActivityTracker,
State(url_builder): State<UrlBuilder>,
Query(Params { action }): Query<Params>,
Query(Params {
action,
unstable_login_hint,
}): Query<Params>,
mut repo: BoxRepository,
clock: BoxClock,
mut rng: BoxRng,
@@ -54,13 +60,13 @@ pub async fn get(
// TODO: keep the full path, not just the action
let Some(session) = maybe_session else {
return Ok((
cookie_jar,
url_builder.redirect(&mas_router::Login::and_then(
PostAuthAction::manage_account(action),
)),
)
.into_response());
let mut url = mas_router::Login::and_then(PostAuthAction::manage_account(action));
if let Some(login_hint) = unstable_login_hint {
url = url.with_login_hint(login_hint);
}
return Ok((cookie_jar, url_builder.redirect(&url)).into_response());
};
activity_tracker

View File

@@ -17,7 +17,7 @@ use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
use mas_data_model::{BoxClock, BoxRng, Clock, oauth2::LoginHint};
use mas_data_model::{BoxClock, BoxRng, Clock};
use mas_i18n::DataLocale;
use mas_matrix::HomeserverConnection;
use mas_router::{UpstreamOAuth2Authorize, UrlBuilder};
@@ -28,14 +28,14 @@ use mas_storage::{
};
use mas_templates::{
AccountInactiveContext, FieldError, FormError, FormState, LoginContext, LoginFormField,
PostAuthContext, PostAuthContextInner, TemplateContext, Templates, ToFormState,
TemplateContext, Templates, ToFormState,
};
use opentelemetry::{Key, KeyValue, metrics::Counter};
use rand::Rng;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction;
use super::shared::{LoginHint, OptionalPostAuthAction, QueryLoginHint};
use crate::{
BoundActivityTracker, Limiter, METER, PreferredLanguage, RequesterFingerprint, SiteConfig,
passwords::{PasswordManager, PasswordVerificationResult},
@@ -73,6 +73,7 @@ pub(crate) async fn get(
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
Query(query): Query<OptionalPostAuthAction>,
Query(query_login_hint): Query<QueryLoginHint>,
cookie_jar: CookieJar,
) -> Result<Response, InternalError> {
let (cookie_jar, maybe_session) = match load_session_or_fallback(
@@ -124,6 +125,7 @@ pub(crate) async fn get(
&templates,
&homeserver,
&site_config,
query_login_hint,
)
.await
}
@@ -142,7 +144,7 @@ pub(crate) async fn post(
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
requester: RequesterFingerprint,
Query(query): Query<OptionalPostAuthAction>,
(Query(query), Query(query_login_hint)): (Query<OptionalPostAuthAction>, Query<QueryLoginHint>),
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Form(form): Form<ProtectedForm<LoginForm>>,
@@ -180,6 +182,7 @@ pub(crate) async fn post(
&templates,
&homeserver,
&site_config,
query_login_hint,
)
.await;
}
@@ -206,6 +209,7 @@ pub(crate) async fn post(
&templates,
&homeserver,
&site_config,
query_login_hint,
)
.await;
};
@@ -226,6 +230,7 @@ pub(crate) async fn post(
&templates,
&homeserver,
&site_config,
query_login_hint,
)
.await;
}
@@ -248,6 +253,7 @@ pub(crate) async fn post(
&templates,
&homeserver,
&site_config,
query_login_hint,
)
.await;
};
@@ -293,6 +299,7 @@ pub(crate) async fn post(
&templates,
&homeserver,
&site_config,
query_login_hint,
)
.await;
}
@@ -375,7 +382,7 @@ async fn get_user_by_email_or_by_username<R: RepositoryAccess>(
fn handle_login_hint(
mut ctx: LoginContext,
next: &PostAuthContext,
query_login_hint: &QueryLoginHint,
homeserver: &dyn HomeserverConnection,
site_config: &SiteConfig,
) -> LoginContext {
@@ -386,16 +393,12 @@ fn handle_login_hint(
return ctx;
}
if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx {
let value = match grant.parse_login_hint(homeserver.homeserver()) {
LoginHint::MXID(mxid) => Some(mxid.localpart().to_owned()),
LoginHint::Email(email) if site_config.login_with_email_allowed => {
Some(email.to_string())
}
_ => None,
};
form_state.set_value(LoginFormField::Username, value);
}
let value = match query_login_hint.parse_login_hint(homeserver.homeserver()) {
LoginHint::Mxid(mxid) => Some(mxid.localpart().to_owned()),
LoginHint::Email(email) if site_config.login_with_email_allowed => Some(email.to_string()),
_ => None,
};
form_state.set_value(LoginFormField::Username, value);
ctx
}
@@ -411,6 +414,7 @@ async fn render(
templates: &Templates,
homeserver: &dyn HomeserverConnection,
site_config: &SiteConfig,
query_login_hint: QueryLoginHint,
) -> Result<Response, InternalError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let providers = repo.upstream_oauth_provider().all_enabled().await?;
@@ -419,12 +423,13 @@ async fn render(
.with_form_state(form_state)
.with_upstream_providers(providers);
let ctx = handle_login_hint(ctx, &query_login_hint, homeserver, site_config);
let next = action
.load_context(repo)
.await
.map_err(InternalError::from_anyhow)?;
let ctx = if let Some(next) = next {
let ctx = handle_login_hint(ctx, &next, homeserver, site_config);
ctx.with_post_action(next)
} else {
ctx

View File

@@ -4,6 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::str::FromStr as _;
use anyhow::Context;
use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{
@@ -13,6 +15,7 @@ use mas_storage::{
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
};
use mas_templates::{PostAuthContext, PostAuthContextInner};
use ruma_common::UserId;
use serde::{Deserialize, Serialize};
use tracing::warn;
@@ -107,3 +110,109 @@ impl OptionalPostAuthAction {
}))
}
}
pub enum LoginHint<'a> {
Mxid(&'a UserId),
Email(lettre::Address),
None,
}
#[derive(Debug, Deserialize)]
pub(crate) struct QueryLoginHint {
login_hint: Option<String>,
}
impl QueryLoginHint {
/// Parse a `login_hint`
///
/// Returns `LoginHint::MXID` for valid mxid 'mxid:@john.doe:example.com'
///
/// Returns `LoginHint::Email` for valid email 'john.doe@example.com'
///
/// Otherwise returns `LoginHint::None`
pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> {
let Some(login_hint) = &self.login_hint else {
return LoginHint::None;
};
if let Some(value) = login_hint.strip_prefix("mxid:")
&& let Ok(mxid) = <&UserId>::try_from(value)
&& mxid.server_name() == homeserver
{
LoginHint::Mxid(mxid)
} else if let Ok(email) = lettre::Address::from_str(login_hint) {
LoginHint::Email(email)
} else {
LoginHint::None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn no_login_hint() {
let query_login_hint = QueryLoginHint { login_hint: None };
let hint = query_login_hint.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
#[test]
fn valid_login_hint() {
let query_login_hint = QueryLoginHint {
login_hint: Some(String::from("mxid:@example-user:example.com")),
};
let hint = query_login_hint.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::Mxid(mxid) if mxid.localpart() == "example-user"));
}
#[test]
fn valid_login_hint_with_email() {
let query_login_hint = QueryLoginHint {
login_hint: Some(String::from("example@user")),
};
let hint = query_login_hint.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user"));
}
#[test]
fn invalid_login_hint() {
let query_login_hint = QueryLoginHint {
login_hint: Some(String::from("example-user")),
};
let hint = query_login_hint.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
#[test]
fn valid_login_hint_for_wrong_homeserver() {
let query_login_hint = QueryLoginHint {
login_hint: Some(String::from("mxid:@example-user:matrix.org")),
};
let hint = query_login_hint.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
#[test]
fn unknown_login_hint_type() {
let query_login_hint = QueryLoginHint {
login_hint: Some(String::from("something:anything")),
};
let hint = query_login_hint.parse_login_hint("example.com");
assert!(matches!(hint, LoginHint::None));
}
}

View File

@@ -172,20 +172,23 @@ impl SimpleRoute for Healthcheck {
}
/// `GET|POST /login`
#[derive(Default, Debug, Clone)]
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Login {
#[serde(flatten)]
post_auth_action: Option<PostAuthAction>,
login_hint: Option<String>,
}
impl Route for Login {
type Query = PostAuthAction;
type Query = Self;
fn route() -> &'static str {
"/login"
}
fn query(&self) -> Option<&Self::Query> {
self.post_auth_action.as_ref()
Some(self)
}
}
@@ -194,6 +197,7 @@ impl Login {
pub const fn and_then(action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(action),
login_hint: None,
}
}
@@ -201,6 +205,7 @@ impl Login {
pub const fn and_continue_grant(id: Ulid) -> Self {
Self {
post_auth_action: Some(PostAuthAction::continue_grant(id)),
login_hint: None,
}
}
@@ -208,6 +213,7 @@ impl Login {
pub const fn and_continue_device_code_grant(id: Ulid) -> Self {
Self {
post_auth_action: Some(PostAuthAction::continue_device_code_grant(id)),
login_hint: None,
}
}
@@ -215,6 +221,7 @@ impl Login {
pub const fn and_continue_compat_sso_login(id: Ulid) -> Self {
Self {
post_auth_action: Some(PostAuthAction::continue_compat_sso_login(id)),
login_hint: None,
}
}
@@ -222,9 +229,16 @@ impl Login {
pub const fn and_link_upstream(id: Ulid) -> Self {
Self {
post_auth_action: Some(PostAuthAction::link_upstream(id)),
login_hint: None,
}
}
#[must_use]
pub fn with_login_hint(mut self, login_hint: String) -> Self {
self.login_hint = Some(login_hint);
self
}
/// Get a reference to the login's post auth action.
#[must_use]
pub fn post_auth_action(&self) -> Option<&PostAuthAction> {
@@ -241,7 +255,10 @@ impl Login {
impl From<Option<PostAuthAction>> for Login {
fn from(post_auth_action: Option<PostAuthAction>) -> Self {
Self { post_auth_action }
Self {
post_auth_action,
login_hint: None,
}
}
}