diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 4b3a842c8..0605d6cd6 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -392,6 +392,11 @@ where get(self::views::register::steps::verify_email::get) .post(self::views::register::steps::verify_email::post), ) + .route( + mas_router::RegisterToken::route(), + get(self::views::register::steps::registration_token::get) + .post(self::views::register::steps::registration_token::post), + ) .route( mas_router::RegisterDisplayName::route(), get(self::views::register::steps::display_name::get) diff --git a/crates/handlers/src/views/register/steps/mod.rs b/crates/handlers/src/views/register/steps/mod.rs index 1b090abb9..ae57f5a0c 100644 --- a/crates/handlers/src/views/register/steps/mod.rs +++ b/crates/handlers/src/views/register/steps/mod.rs @@ -5,4 +5,5 @@ pub(crate) mod display_name; pub(crate) mod finish; +pub(crate) mod registration_token; pub(crate) mod verify_email; diff --git a/crates/handlers/src/views/register/steps/registration_token.rs b/crates/handlers/src/views/register/steps/registration_token.rs new file mode 100644 index 000000000..eacf343a3 --- /dev/null +++ b/crates/handlers/src/views/register/steps/registration_token.rs @@ -0,0 +1,201 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use anyhow::Context as _; +use axum::{ + Form, + extract::{Path, State}, + response::{Html, IntoResponse, Response}, +}; +use mas_axum_utils::{ + InternalError, + cookies::CookieJar, + csrf::{CsrfExt as _, ProtectedForm}, +}; +use mas_router::{PostAuthAction, UrlBuilder}; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; +use mas_templates::{ + FieldError, RegisterStepsRegistrationTokenContext, RegisterStepsRegistrationTokenFormField, + TemplateContext as _, Templates, ToFormState, +}; +use serde::{Deserialize, Serialize}; +use ulid::Ulid; + +use crate::{PreferredLanguage, views::shared::OptionalPostAuthAction}; + +#[derive(Deserialize, Serialize)] +pub(crate) struct RegistrationTokenForm { + #[serde(default)] + token: String, +} + +impl ToFormState for RegistrationTokenForm { + type Field = mas_templates::RegisterStepsRegistrationTokenFormField; +} + +#[tracing::instrument( + name = "handlers.views.register.steps.registration_token.get", + fields(user_registration.id = %id), + skip_all, +)] +pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, + PreferredLanguage(locale): PreferredLanguage, + State(templates): State, + State(url_builder): State, + mut repo: BoxRepository, + Path(id): Path, + cookie_jar: CookieJar, +) -> Result { + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + + let registration = repo + .user_registration() + .lookup(id) + .await? + .context("Could not find user registration") + .map_err(InternalError::from_anyhow)?; + + // If the registration is completed, we can go to the registration destination + if registration.completed_at.is_some() { + let post_auth_action: Option = registration + .post_auth_action + .map(serde_json::from_value) + .transpose()?; + + return Ok(( + cookie_jar, + OptionalPostAuthAction::from(post_auth_action) + .go_next(&url_builder) + .into_response(), + ) + .into_response()); + } + + // If the registration already has a token, skip this step + if registration.user_registration_token_id.is_some() { + let destination = mas_router::RegisterDisplayName::new(registration.id); + return Ok((cookie_jar, url_builder.redirect(&destination)).into_response()); + } + + let ctx = RegisterStepsRegistrationTokenContext::new() + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + let content = templates.render_register_steps_registration_token(&ctx)?; + + Ok((cookie_jar, Html(content)).into_response()) +} + +#[tracing::instrument( + name = "handlers.views.register.steps.registration_token.post", + fields(user_registration.id = %id), + skip_all, +)] +pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, + PreferredLanguage(locale): PreferredLanguage, + State(templates): State, + State(url_builder): State, + mut repo: BoxRepository, + Path(id): Path, + cookie_jar: CookieJar, + Form(form): Form>, +) -> Result { + let registration = repo + .user_registration() + .lookup(id) + .await? + .context("Could not find user registration") + .map_err(InternalError::from_anyhow)?; + + // If the registration is completed, we can go to the registration destination + if registration.completed_at.is_some() { + let post_auth_action: Option = registration + .post_auth_action + .map(serde_json::from_value) + .transpose()?; + + return Ok(( + cookie_jar, + OptionalPostAuthAction::from(post_auth_action) + .go_next(&url_builder) + .into_response(), + ) + .into_response()); + } + + let form = cookie_jar.verify_form(&clock, form)?; + + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + + // Validate the token + let token = form.token.trim(); + if token.is_empty() { + let ctx = RegisterStepsRegistrationTokenContext::new() + .with_form_state(form.to_form_state().with_error_on_field( + RegisterStepsRegistrationTokenFormField::Token, + FieldError::Required, + )) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + return Ok(( + cookie_jar, + Html(templates.render_register_steps_registration_token(&ctx)?), + ) + .into_response()); + } + + // Look up the token + let Some(registration_token) = repo.user_registration_token().find_by_token(token).await? + else { + let ctx = RegisterStepsRegistrationTokenContext::new() + .with_form_state(form.to_form_state().with_error_on_field( + RegisterStepsRegistrationTokenFormField::Token, + FieldError::Invalid, + )) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + return Ok(( + cookie_jar, + Html(templates.render_register_steps_registration_token(&ctx)?), + ) + .into_response()); + }; + + // Check if the token is still valid + if !registration_token.is_valid(clock.now()) { + tracing::warn!("Registration token isn't valid (expired or already used)"); + let ctx = RegisterStepsRegistrationTokenContext::new() + .with_form_state(form.to_form_state().with_error_on_field( + RegisterStepsRegistrationTokenFormField::Token, + FieldError::Invalid, + )) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + return Ok(( + cookie_jar, + Html(templates.render_register_steps_registration_token(&ctx)?), + ) + .into_response()); + } + + // Associate the token with the registration + let registration = repo + .user_registration() + .set_registration_token(registration, ®istration_token) + .await?; + + repo.save().await?; + + // Continue to the next step + let destination = mas_router::RegisterFinish::new(registration.id); + Ok((cookie_jar, url_builder.redirect(&destination)).into_response()) +} diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index a7efeade9..896f17a52 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -382,6 +382,30 @@ impl From> for PasswordRegister { } } +/// `GET|POST /register/steps/{id}/token` +#[derive(Debug, Clone)] +pub struct RegisterToken { + id: Ulid, +} + +impl RegisterToken { + #[must_use] + pub fn new(id: Ulid) -> Self { + Self { id } + } +} + +impl Route for RegisterToken { + type Query = (); + fn route() -> &'static str { + "/register/steps/{id}/token" + } + + fn path(&self) -> std::borrow::Cow<'static, str> { + format!("/register/steps/{}/token", self.id).into() + } +} + /// `GET|POST /register/steps/{id}/display-name` #[derive(Debug, Clone)] pub struct RegisterDisplayName { diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 54b2f193d..c21096ee6 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -1068,6 +1068,61 @@ impl TemplateContext for RegisterStepsDisplayNameContext { } } +/// Fields of the registration token form +#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RegisterStepsRegistrationTokenFormField { + /// The registration token + Token, +} + +impl FormField for RegisterStepsRegistrationTokenFormField { + fn keep(&self) -> bool { + match self { + Self::Token => true, + } + } +} + +/// The registration token page context +#[derive(Serialize, Default)] +pub struct RegisterStepsRegistrationTokenContext { + form: FormState, +} + +impl RegisterStepsRegistrationTokenContext { + /// Constructs a context for the registration token page + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Set the form state + #[must_use] + pub fn with_form_state( + mut self, + form_state: FormState, + ) -> Self { + self.form = form_state; + self + } +} + +impl TemplateContext for RegisterStepsRegistrationTokenContext { + fn sample( + _now: chrono::DateTime, + _rng: &mut impl Rng, + _locales: &[DataLocale], + ) -> Vec + where + Self: Sized, + { + vec![Self { + form: FormState::default(), + }] + } +} + /// Fields of the account recovery start form #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[serde(rename_all = "snake_case")] diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 431b1f52b..88a72225a 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -42,7 +42,8 @@ pub use self::{ RecoveryExpiredContext, RecoveryFinishContext, RecoveryFinishFormField, RecoveryProgressContext, RecoveryStartContext, RecoveryStartFormField, RegisterContext, RegisterFormField, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField, - RegisterStepsEmailInUseContext, RegisterStepsVerifyEmailContext, + RegisterStepsEmailInUseContext, RegisterStepsRegistrationTokenContext, + RegisterStepsRegistrationTokenFormField, RegisterStepsVerifyEmailContext, RegisterStepsVerifyEmailFormField, SiteBranding, SiteConfigExt, SiteFeatures, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamRegisterFormField, UpstreamSuggestLink, WithCaptcha, WithCsrf, WithLanguage, WithOptionalSession, WithSession, @@ -340,6 +341,9 @@ register_templates! { /// Render the display name page pub fn render_register_steps_display_name(WithLanguage>) { "pages/register/steps/display_name.html" } + /// Render the registration token page + pub fn render_register_steps_registration_token(WithLanguage>) { "pages/register/steps/registration_token.html" } + /// Render the client consent page pub fn render_consent(WithLanguage>>) { "pages/consent.html" } @@ -444,6 +448,7 @@ impl Templates { check::render_register_steps_verify_email(self, now, rng)?; check::render_register_steps_email_in_use(self, now, rng)?; check::render_register_steps_display_name(self, now, rng)?; + check::render_register_steps_registration_token(self, now, rng)?; check::render_consent(self, now, rng)?; check::render_policy_violation(self, now, rng)?; check::render_sso_login(self, now, rng)?; diff --git a/templates/pages/register/steps/registration_token.html b/templates/pages/register/steps/registration_token.html new file mode 100644 index 000000000..d58c82e7f --- /dev/null +++ b/templates/pages/register/steps/registration_token.html @@ -0,0 +1,44 @@ +{# +Copyright 2025 New Vector Ltd. + +SPDX-License-Identifier: AGPL-3.0-only +Please see LICENSE in the repository root for full details. +-#} + +{% extends "base.html" %} + +{% block content %} +
+
+ {{ icon.key_solid() }} +
+
+

{{ _("mas.registration_token.headline") }}

+

{{ _("mas.registration_token.description") }}

+
+
+ +
+
+ {% if form.errors is not empty %} + {% for error in form.errors %} +
+ {{ errors.form_error_message(error=error) }} +
+ {% endfor %} + {% endif %} + + + + {% call(f) field.field(label=_("mas.registration_token.field"), name="token", form_state=form, class="mb-4") %} + + {% endcall %} + + {{ button.button(text=_("action.continue")) }} +
+
+{% endblock content %} diff --git a/translations/en.json b/translations/en.json index 5b2a5ad04..d17e09338 100644 --- a/translations/en.json +++ b/translations/en.json @@ -10,7 +10,7 @@ }, "continue": "Continue", "@continue": { - "context": "form_post.html:25:28-48, pages/consent.html:57:28-48, pages/device_consent.html:124:13-33, pages/device_link.html:40:26-46, pages/login.html:68:30-50, pages/reauth.html:32:28-48, pages/recovery/start.html:38:26-46, pages/register/password.html:74:26-46, pages/register/steps/display_name.html:43:28-48, pages/register/steps/verify_email.html:51:26-46, pages/sso.html:37:28-48" + "context": "form_post.html:25:28-48, pages/consent.html:57:28-48, pages/device_consent.html:124:13-33, pages/device_link.html:40:26-46, pages/login.html:68:30-50, pages/reauth.html:32:28-48, pages/recovery/start.html:38:26-46, pages/register/password.html:74:26-46, pages/register/steps/display_name.html:43:28-48, pages/register/steps/registration_token.html:41:28-48, pages/register/steps/verify_email.html:51:26-46, pages/sso.html:37:28-48" }, "create_account": "Create Account", "@create_account": { @@ -635,6 +635,20 @@ "context": "pages/register/password.html:51:35-95, pages/upstream_oauth2/do_register.html:179:35-95" } }, + "registration_token": { + "description": "Enter a registration token provided by the homeserver administrator.", + "@description": { + "context": "pages/register/steps/registration_token.html:17:25-64" + }, + "field": "Registration token", + "@field": { + "context": "pages/register/steps/registration_token.html:33:35-68" + }, + "headline": "Registration token", + "@headline": { + "context": "pages/register/steps/registration_token.html:16:27-63" + } + }, "scope": { "edit_profile": "Edit your profile and contact details", "@edit_profile": {