diff --git a/Cargo.lock b/Cargo.lock index ce1a9c354..83b9fd35a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3268,14 +3268,12 @@ dependencies = [ "base64ct", "chrono", "crc", - "lettre", "mas-iana", "mas-jose", "oauth2-types", "rand 0.8.5", "rand_chacha 0.3.1", "regex", - "ruma-common", "serde", "serde_json", "thiserror 2.0.17", @@ -3352,6 +3350,7 @@ dependencies = [ "rand 0.8.5", "rand_chacha 0.3.1", "reqwest", + "ruma-common", "rustls", "schemars 0.9.0", "sentry", diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index da7021b11..7e9be8282 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -29,8 +29,6 @@ rand.workspace = true rand_chacha.workspace = true regex.workspace = true woothee.workspace = true -ruma-common.workspace = true -lettre.workspace = true mas-iana.workspace = true mas-jose.workspace = true diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index 738277b84..2ad0488ce 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -4,8 +4,6 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. -use std::str::FromStr as _; - use chrono::{DateTime, Utc}; use mas_iana::oauth::PkceCodeChallengeMethod; use oauth2_types::{ @@ -17,7 +15,6 @@ use rand::{ RngCore, distributions::{Alphanumeric, DistString}, }; -use ruma_common::UserId; use serde::Serialize; use ulid::Ulid; use url::Url; @@ -142,12 +139,6 @@ impl AuthorizationGrantStage { } } -pub enum LoginHint<'a> { - MXID(&'a UserId), - Email(lettre::Address), - None, -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct AuthorizationGrant { pub id: Ulid, @@ -175,31 +166,6 @@ impl std::ops::Deref for AuthorizationGrant { } impl AuthorizationGrant { - /// Parse a `login_hint` - /// - /// Returns `LoginHint::MXID` for valid mxid 'mxid:@john.doe:example.com' - /// - /// Returns `LoginHint::Email` for valid email 'john.doe@example.com' - /// - /// Otherwise returns `LoginHint::None` - #[must_use] - pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> { - let Some(login_hint) = &self.login_hint else { - return LoginHint::None; - }; - - if let Some(value) = login_hint.strip_prefix("mxid:") - && let Ok(mxid) = <&UserId>::try_from(value) - && mxid.server_name() == homeserver - { - LoginHint::MXID(mxid) - } else if let Ok(email) = lettre::Address::from_str(login_hint) { - LoginHint::Email(email) - } else { - LoginHint::None - } - } - /// Mark the authorization grant as exchanged. /// /// # Errors @@ -266,101 +232,3 @@ impl AuthorizationGrant { } } } - -#[cfg(test)] -mod tests { - use rand::SeedableRng; - - use super::*; - use crate::clock::{Clock, MockClock}; - - #[test] - fn no_login_hint() { - let now = MockClock::default().now(); - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - - let grant = AuthorizationGrant { - login_hint: None, - ..AuthorizationGrant::sample(now, &mut rng) - }; - - let hint = grant.parse_login_hint("example.com"); - - assert!(matches!(hint, LoginHint::None)); - } - - #[test] - fn valid_login_hint() { - let now = MockClock::default().now(); - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - - let grant = AuthorizationGrant { - login_hint: Some(String::from("mxid:@example-user:example.com")), - ..AuthorizationGrant::sample(now, &mut rng) - }; - - let hint = grant.parse_login_hint("example.com"); - - assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user")); - } - - #[test] - fn valid_login_hint_with_email() { - let now = MockClock::default().now(); - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - - let grant = AuthorizationGrant { - login_hint: Some(String::from("example@user")), - ..AuthorizationGrant::sample(now, &mut rng) - }; - - let hint = grant.parse_login_hint("example.com"); - - assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user")); - } - - #[test] - fn invalid_login_hint() { - let now = MockClock::default().now(); - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - - let grant = AuthorizationGrant { - login_hint: Some(String::from("example-user")), - ..AuthorizationGrant::sample(now, &mut rng) - }; - - let hint = grant.parse_login_hint("example.com"); - - assert!(matches!(hint, LoginHint::None)); - } - - #[test] - fn valid_login_hint_for_wrong_homeserver() { - let now = MockClock::default().now(); - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - - let grant = AuthorizationGrant { - login_hint: Some(String::from("mxid:@example-user:matrix.org")), - ..AuthorizationGrant::sample(now, &mut rng) - }; - - let hint = grant.parse_login_hint("example.com"); - - assert!(matches!(hint, LoginHint::None)); - } - - #[test] - fn unknown_login_hint_type() { - let now = MockClock::default().now(); - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - - let grant = AuthorizationGrant { - login_hint: Some(String::from("something:anything")), - ..AuthorizationGrant::sample(now, &mut rng) - }; - - let hint = grant.parse_login_hint("example.com"); - - assert!(matches!(hint, LoginHint::None)); - } -} diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index 6221a32fc..798e07e54 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -10,9 +10,7 @@ mod device_code_grant; mod session; pub use self::{ - authorization_grant::{ - AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, LoginHint, Pkce, - }, + authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, device_code_grant::{DeviceCodeGrant, DeviceCodeGrantState}, session::{Session, SessionState}, diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 57f391854..97710930b 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -49,6 +49,7 @@ psl.workspace = true rand_chacha.workspace = true rand.workspace = true reqwest.workspace = true +ruma-common.workspace = true rustls.workspace = true schemars.workspace = true sentry.workspace = true diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 6037bd9ad..a3988b0ce 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -278,9 +278,15 @@ pub(crate) async fn get( // Other cases where we don't have a session, ask for a login repo.save().await?; - url_builder - .redirect(&mas_router::Login::and_then(continue_grant)) - .into_response() + let mut url = mas_router::Login::and_then(continue_grant); + + url = if let Some(login_hint) = grant.login_hint { + url.with_login_hint(login_hint) + } else { + url + }; + + url_builder.redirect(&url).into_response() } Some(user_session) => { diff --git a/crates/handlers/src/views/app.rs b/crates/handlers/src/views/app.rs index 4ae5f5222..36c9cd4f9 100644 --- a/crates/handlers/src/views/app.rs +++ b/crates/handlers/src/views/app.rs @@ -25,6 +25,9 @@ use crate::{ pub struct Params { #[serde(default, flatten)] action: Option, + + #[serde(rename = "org.matrix.msc4198.login_hint")] + unstable_login_hint: Option, } #[tracing::instrument(name = "handlers.views.app.get", skip_all)] @@ -33,7 +36,10 @@ pub async fn get( State(templates): State, activity_tracker: BoundActivityTracker, State(url_builder): State, - Query(Params { action }): Query, + Query(Params { + action, + unstable_login_hint, + }): Query, mut repo: BoxRepository, clock: BoxClock, mut rng: BoxRng, @@ -54,13 +60,13 @@ pub async fn get( // TODO: keep the full path, not just the action let Some(session) = maybe_session else { - return Ok(( - cookie_jar, - url_builder.redirect(&mas_router::Login::and_then( - PostAuthAction::manage_account(action), - )), - ) - .into_response()); + let mut url = mas_router::Login::and_then(PostAuthAction::manage_account(action)); + + if let Some(login_hint) = unstable_login_hint { + url = url.with_login_hint(login_hint); + } + + return Ok((cookie_jar, url_builder.redirect(&url)).into_response()); }; activity_tracker diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 72e1566fe..a4fef8eba 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -17,7 +17,7 @@ use mas_axum_utils::{ cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, }; -use mas_data_model::{BoxClock, BoxRng, Clock, oauth2::LoginHint}; +use mas_data_model::{BoxClock, BoxRng, Clock}; use mas_i18n::DataLocale; use mas_matrix::HomeserverConnection; use mas_router::{UpstreamOAuth2Authorize, UrlBuilder}; @@ -28,14 +28,14 @@ use mas_storage::{ }; use mas_templates::{ AccountInactiveContext, FieldError, FormError, FormState, LoginContext, LoginFormField, - PostAuthContext, PostAuthContextInner, TemplateContext, Templates, ToFormState, + TemplateContext, Templates, ToFormState, }; use opentelemetry::{Key, KeyValue, metrics::Counter}; use rand::Rng; use serde::{Deserialize, Serialize}; use zeroize::Zeroizing; -use super::shared::OptionalPostAuthAction; +use super::shared::{LoginHint, OptionalPostAuthAction, QueryLoginHint}; use crate::{ BoundActivityTracker, Limiter, METER, PreferredLanguage, RequesterFingerprint, SiteConfig, passwords::{PasswordManager, PasswordVerificationResult}, @@ -73,6 +73,7 @@ pub(crate) async fn get( mut repo: BoxRepository, activity_tracker: BoundActivityTracker, Query(query): Query, + Query(query_login_hint): Query, cookie_jar: CookieJar, ) -> Result { let (cookie_jar, maybe_session) = match load_session_or_fallback( @@ -124,6 +125,7 @@ pub(crate) async fn get( &templates, &homeserver, &site_config, + query_login_hint, ) .await } @@ -142,7 +144,7 @@ pub(crate) async fn post( mut repo: BoxRepository, activity_tracker: BoundActivityTracker, requester: RequesterFingerprint, - Query(query): Query, + (Query(query), Query(query_login_hint)): (Query, Query), cookie_jar: CookieJar, user_agent: Option>, Form(form): Form>, @@ -180,6 +182,7 @@ pub(crate) async fn post( &templates, &homeserver, &site_config, + query_login_hint, ) .await; } @@ -206,6 +209,7 @@ pub(crate) async fn post( &templates, &homeserver, &site_config, + query_login_hint, ) .await; }; @@ -226,6 +230,7 @@ pub(crate) async fn post( &templates, &homeserver, &site_config, + query_login_hint, ) .await; } @@ -248,6 +253,7 @@ pub(crate) async fn post( &templates, &homeserver, &site_config, + query_login_hint, ) .await; }; @@ -293,6 +299,7 @@ pub(crate) async fn post( &templates, &homeserver, &site_config, + query_login_hint, ) .await; } @@ -375,7 +382,7 @@ async fn get_user_by_email_or_by_username( fn handle_login_hint( mut ctx: LoginContext, - next: &PostAuthContext, + query_login_hint: &QueryLoginHint, homeserver: &dyn HomeserverConnection, site_config: &SiteConfig, ) -> LoginContext { @@ -386,16 +393,12 @@ fn handle_login_hint( return ctx; } - if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx { - let value = match grant.parse_login_hint(homeserver.homeserver()) { - LoginHint::MXID(mxid) => Some(mxid.localpart().to_owned()), - LoginHint::Email(email) if site_config.login_with_email_allowed => { - Some(email.to_string()) - } - _ => None, - }; - form_state.set_value(LoginFormField::Username, value); - } + let value = match query_login_hint.parse_login_hint(homeserver.homeserver()) { + LoginHint::Mxid(mxid) => Some(mxid.localpart().to_owned()), + LoginHint::Email(email) if site_config.login_with_email_allowed => Some(email.to_string()), + _ => None, + }; + form_state.set_value(LoginFormField::Username, value); ctx } @@ -411,6 +414,7 @@ async fn render( templates: &Templates, homeserver: &dyn HomeserverConnection, site_config: &SiteConfig, + query_login_hint: QueryLoginHint, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); let providers = repo.upstream_oauth_provider().all_enabled().await?; @@ -419,12 +423,13 @@ async fn render( .with_form_state(form_state) .with_upstream_providers(providers); + let ctx = handle_login_hint(ctx, &query_login_hint, homeserver, site_config); + let next = action .load_context(repo) .await .map_err(InternalError::from_anyhow)?; let ctx = if let Some(next) = next { - let ctx = handle_login_hint(ctx, &next, homeserver, site_config); ctx.with_post_action(next) } else { ctx diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 85edf299f..b0c7b8ac1 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. +use std::str::FromStr as _; + use anyhow::Context; use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ @@ -13,6 +15,7 @@ use mas_storage::{ upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; +use ruma_common::UserId; use serde::{Deserialize, Serialize}; use tracing::warn; @@ -107,3 +110,109 @@ impl OptionalPostAuthAction { })) } } + +pub enum LoginHint<'a> { + Mxid(&'a UserId), + Email(lettre::Address), + None, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct QueryLoginHint { + login_hint: Option, +} + +impl QueryLoginHint { + /// Parse a `login_hint` + /// + /// Returns `LoginHint::MXID` for valid mxid 'mxid:@john.doe:example.com' + /// + /// Returns `LoginHint::Email` for valid email 'john.doe@example.com' + /// + /// Otherwise returns `LoginHint::None` + pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> { + let Some(login_hint) = &self.login_hint else { + return LoginHint::None; + }; + + if let Some(value) = login_hint.strip_prefix("mxid:") + && let Ok(mxid) = <&UserId>::try_from(value) + && mxid.server_name() == homeserver + { + LoginHint::Mxid(mxid) + } else if let Ok(email) = lettre::Address::from_str(login_hint) { + LoginHint::Email(email) + } else { + LoginHint::None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn no_login_hint() { + let query_login_hint = QueryLoginHint { login_hint: None }; + + let hint = query_login_hint.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn valid_login_hint() { + let query_login_hint = QueryLoginHint { + login_hint: Some(String::from("mxid:@example-user:example.com")), + }; + + let hint = query_login_hint.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::Mxid(mxid) if mxid.localpart() == "example-user")); + } + + #[test] + fn valid_login_hint_with_email() { + let query_login_hint = QueryLoginHint { + login_hint: Some(String::from("example@user")), + }; + + let hint = query_login_hint.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user")); + } + + #[test] + fn invalid_login_hint() { + let query_login_hint = QueryLoginHint { + login_hint: Some(String::from("example-user")), + }; + + let hint = query_login_hint.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn valid_login_hint_for_wrong_homeserver() { + let query_login_hint = QueryLoginHint { + login_hint: Some(String::from("mxid:@example-user:matrix.org")), + }; + + let hint = query_login_hint.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn unknown_login_hint_type() { + let query_login_hint = QueryLoginHint { + login_hint: Some(String::from("something:anything")), + }; + + let hint = query_login_hint.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } +} diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 6aa18f13d..239f24b25 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -172,20 +172,23 @@ impl SimpleRoute for Healthcheck { } /// `GET|POST /login` -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct Login { + #[serde(flatten)] post_auth_action: Option, + + login_hint: Option, } impl Route for Login { - type Query = PostAuthAction; + type Query = Self; fn route() -> &'static str { "/login" } fn query(&self) -> Option<&Self::Query> { - self.post_auth_action.as_ref() + Some(self) } } @@ -194,6 +197,7 @@ impl Login { pub const fn and_then(action: PostAuthAction) -> Self { Self { post_auth_action: Some(action), + login_hint: None, } } @@ -201,6 +205,7 @@ impl Login { pub const fn and_continue_grant(id: Ulid) -> Self { Self { post_auth_action: Some(PostAuthAction::continue_grant(id)), + login_hint: None, } } @@ -208,6 +213,7 @@ impl Login { pub const fn and_continue_device_code_grant(id: Ulid) -> Self { Self { post_auth_action: Some(PostAuthAction::continue_device_code_grant(id)), + login_hint: None, } } @@ -215,6 +221,7 @@ impl Login { pub const fn and_continue_compat_sso_login(id: Ulid) -> Self { Self { post_auth_action: Some(PostAuthAction::continue_compat_sso_login(id)), + login_hint: None, } } @@ -222,9 +229,16 @@ impl Login { pub const fn and_link_upstream(id: Ulid) -> Self { Self { post_auth_action: Some(PostAuthAction::link_upstream(id)), + login_hint: None, } } + #[must_use] + pub fn with_login_hint(mut self, login_hint: String) -> Self { + self.login_hint = Some(login_hint); + self + } + /// Get a reference to the login's post auth action. #[must_use] pub fn post_auth_action(&self) -> Option<&PostAuthAction> { @@ -241,7 +255,10 @@ impl Login { impl From> for Login { fn from(post_auth_action: Option) -> Self { - Self { post_auth_action } + Self { + post_auth_action, + login_hint: None, + } } }