Compare commits

..

17 Commits

Author SHA1 Message Date
Ginger 8c1ca272de fix: Config file formatting 2026-07-01 17:04:13 -04:00
Ginger a5ee1fb009 feat: Add support for requesting additional scopes 2026-07-01 16:46:58 -04:00
Ginger f57485b8b7 chore: News fragment 2026-07-01 15:55:19 -04:00
Ginger c3e156e78d feat: Add support for importing email addresses from the IDP 2026-07-01 15:55:19 -04:00
Ginger 30b15bd7be fix: Miscellaneous fixes 2026-07-01 15:51:07 -04:00
Ginger c61473e8ee fix: Hide bare oidc config option 2026-07-01 15:51:07 -04:00
Ginger 96c597138b fix: Adjust OIDC config section comment position 2026-07-01 15:51:07 -04:00
Ginger 586fb2102a feat: Add support for importing profile data from claims 2026-07-01 15:51:06 -04:00
Ginger 5dead6621e refactor: Move profile field setting logic into users service 2026-07-01 15:51:06 -04:00
Ginger 67f5c3e595 fix: Hide password change link when OIDC is enabled 2026-07-01 15:51:06 -04:00
Ginger 51bb90250f feat: Send account selection prompt to IDP when account switch link is clicked 2026-07-01 15:51:06 -04:00
Ginger a9668f30a3 feat: Speedbump when logging in with OIDC with no next target 2026-07-01 15:51:06 -04:00
Ginger 076046171a feat: Allow existing legacy accounts to be linked interactively 2026-07-01 15:51:06 -04:00
Ginger c55db6f9bc chore: Update admin command docs 2026-07-01 15:51:06 -04:00
Ginger 870eeffe93 feat: Implement !admin oidc unlink 2026-07-01 15:51:06 -04:00
Ginger 9db5d1646d feat: Initial implementation of OIDC 2026-07-01 15:51:06 -04:00
Ginger 3069194ffe refactor: Split remote and deactivated users into their own columns 2026-07-01 15:51:06 -04:00
75 changed files with 2659 additions and 758 deletions
Generated
+610 -44
View File
File diff suppressed because it is too large Load Diff
+3
View File
@@ -403,6 +403,9 @@ default-features = false
version = "0.11.0"
default-features = false
[workspace.dependencies.openidconnect]
version = "4.0.1"
# optional opentelemetry, performance measurements, flamegraphs, etc for performance measurements and monitoring
[workspace.dependencies.opentelemetry]
version = "0.32.0"
+1
View File
@@ -0,0 +1 @@
Added support for linking an external identity provider with OIDC. Contributed by @ginger.
+86
View File
@@ -2028,3 +2028,89 @@
# legacy authentication will be unable to log in.
#
#compatibility_mode = "hybrid"
#[global.oauth.oidc]
# Uncommenting this section will enable Continuwuity's support for
# authenticating users using an OpenID Connect-compatible identity provider.
# This is referred to as "delegated authentication".
#
# IMPORTANT NOTE: When delegated authentication is active, Continuwuity will behave as if
# the `global.oauth.compatibility_mode` setting is set to `exclusive`.
# Matrix clients which do not support OAuth login (also referred to as "next-gen auth") will NOT be able
# to log in while delegated authentication is active.
# The OIDC issuer URL. Continuwuity will use OpenID Connect Discovery to
# automatically fetch the identity provider's metadata from this URL.
# Generally you should set this to the base domain your identity provider
# runs on.
#
#discovery_url =
# The OAuth client ID for Continuwuity to use when communicating with the
# identity provider.
#
#client_id =
# The OAuth client secret for Continuwuity to use when communicating with
# the identity provider.
#
#client_secret =
# Additional scopes Continuwuity should request from the IDP. This may be
# necessary to access certain claims. Continuwuity always requests the
# `openid` scope.
#
#additional_scopes = []
# Whether the user should be prompted to choose a localpart
# when signing in for the first time. If this is `false`, Continuwuity
# will attempt to use the value of the `preferred_username_claim`
# (see below) as the user's localpart. Authentication will
# fail if this claim is missing or is not a valid localpart.
#
#prompt_for_localpart = true
# The claim to use for the user's localpart, if `prompt_for_localpart` is
# false.
#
#preferred_username_claim = "preferred_username"
# The claim which will be used to set the user's email address,
# either on initial registration or on every login depending on
# the value of `profile_key_import_mode`. Continuwuity assumes that
# the IDP has taken care of verifying that the user controls the email
# address it provides.
#
# This option does nothing if SMTP is not configured.
#
# If this option is set, and `profile_key_import_mode` is `on_login`,
# users will not be able to change their email addresses themselves.
#
#email_claim = "email"
# Defines how claims returned from the IDP should be mapped to a user's
# profile data. The profile field named in each key will be set from the
# claim named in the corresponding value when the user first registers,
# and possibly on subsequent logins as well, depending on the value of
# `profile_key_import_mode` (see below).
#
# Per-room overrides to the user's display name or avatar will be
# preserved by the import process.
#
# SECURITY NOTE: If the `avatar_url` field is set, Continuwuity will
# perform a HTTP GET to the URL in the mapped claim and use the returned
# file as the user's profile picture. Make sure your users are not able
# to set the value of the mapped claim to an arbitrary URL.
#
#profile_key_map = { displayname = "name" }
# When profile keys should be imported from the IDP's claims.
#
# - "on_registration": Listed keys will be imported once, when the user
# logs in for the first time and their shadow account is created.
# - "on_login": Listed keys will be imported every time the user logs in.
# Additionally, users will not be able to manually edit any listed keys
# through their Matrix client.
#
#profile_key_import_mode = "on_registration"
+13 -1
View File
@@ -10,7 +10,13 @@ ## `!admin debug echo`
## `!admin debug get-auth-chain`
Get the auth_chain of a PDU
Loads the auth_chain of a PDU, reporting how long it took
## `!admin debug show-auth-chain`
Walks & displays the auth_chain of a PDU in a mermaid graph format.
This is useless to basically anyone but developers, and is also probably slow and memory hungry.
## `!admin debug parse-pdu`
@@ -44,6 +50,12 @@ ## `!admin debug get-room-state`
Of course the check is still done on the actual client API.
## `!admin debug get-state-at`
Gets all the room state events at the specified event.
State at event might not be available for some PDUs, such as rejected ones.
## `!admin debug get-signing-keys`
Get and display signing keys from local cache or remote server
+1
View File
@@ -14,6 +14,7 @@ ## Categories
- [`!admin appservices`](appservices/): Commands for managing appservices
- [`!admin users`](users/): Commands for managing local users
- [`!admin token`](token/): Commands for managing registration tokens
- [`!admin oidc`](oidc/): Commands for managing OIDC
- [`!admin rooms`](rooms/): Commands for managing rooms
- [`!admin federation`](federation/): Commands for managing federation
- [`!admin server`](server/): Commands for managing the server
+13
View File
@@ -0,0 +1,13 @@
<!-- This file is generated by `cargo xtask generate-docs`. Do not edit. -->
# `!admin oidc`
Commands for managing OIDC
## `!admin oidc link`
Link a user ID to the given subject claim
## `!admin oidc unlink`
Unlink the given subject claim from its associated user ID
+8 -4
View File
@@ -12,10 +12,6 @@ ## `!admin users reset-password`
Reset user password
## `!admin users issue-password-reset-link`
Issue a self-service password reset link for a user
## `!admin users get-email`
Get a user's associated email address
@@ -96,6 +92,14 @@ ## `!admin users list-users`
List local users in the database
## `!admin users list-invited-rooms`
Lists all the rooms (local and remote) that the specified user is invited to
## `!admin users reject-all-invites`
Manually make a user reject all current invites
## `!admin users list-joined-rooms`
Lists all the rooms (local and remote) that the specified user is joined in
+1 -1
View File
@@ -12,7 +12,7 @@
target:
target.fromToolchainName {
name = (lib.importTOML "${inputs.self}/rust-toolchain.toml").toolchain.channel;
sha256 = "sha256-h+t2xTBz5yt2YIO+1VMIIGlCU7gyp2LYOFvaV1nwOXU=";
sha256 = "sha256-mvUGEOHYJpn3ikC5hckneuGixaC+yGrkMM/liDIDgoU=";
};
in
{
+1 -1
View File
@@ -10,7 +10,7 @@
[toolchain]
profile = "minimal"
channel = "1.96.1"
channel = "1.96.0"
components = [
# For rust-analyzer
"rust-src",
+25 -10
View File
@@ -1,5 +1,5 @@
use clap::Parser;
use conduwuit::Result;
use conduwuit::{Err, Result};
use crate::{
appservice::{self, AppserviceCommand},
@@ -8,6 +8,7 @@
debug::{self, DebugCommand},
federation::{self, FederationCommand},
media::{self, MediaCommand},
oidc::{self, OidcCommand},
query::{self, QueryCommand},
room::{self, RoomCommand},
server::{self, ServerCommand},
@@ -18,44 +19,48 @@
#[derive(Debug, Parser)]
#[command(name = conduwuit_core::BRANDING, version = conduwuit_core::version())]
pub enum AdminCommand {
#[command(subcommand)]
/// Commands for managing appservices
#[command(subcommand)]
Appservices(AppserviceCommand),
#[command(subcommand)]
/// Commands for managing local users
#[command(subcommand)]
Users(UserCommand),
#[command(subcommand)]
/// Commands for managing registration tokens
#[command(subcommand)]
Token(TokenCommand),
/// Commands for managing OIDC
#[command(subcommand)]
Oidc(OidcCommand),
/// Commands for managing rooms
#[command(subcommand)]
Rooms(RoomCommand),
#[command(subcommand)]
/// Commands for managing federation
#[command(subcommand)]
Federation(FederationCommand),
#[command(subcommand)]
/// Commands for managing the server
#[command(subcommand)]
Server(ServerCommand),
#[command(subcommand)]
/// Commands for managing media
#[command(subcommand)]
Media(MediaCommand),
#[command(subcommand)]
/// Commands for checking integrity
#[command(subcommand)]
Check(CheckCommand),
#[command(subcommand)]
/// Commands for debugging things
#[command(subcommand)]
Debug(DebugCommand),
#[command(subcommand)]
/// Low-level queries for database getters and iterators
#[command(subcommand)]
Query(QueryCommand),
}
@@ -80,6 +85,16 @@ pub(super) async fn process(command: AdminCommand, context: &Context<'_>) -> Res
context.bail_restricted()?;
token::process(command, context).await
},
| Oidc(command) => {
// OIDC commands are all restricted
context.bail_restricted()?;
if !context.services.oidc.enabled() {
return Err!("OIDC is not configured");
}
oidc::process(command, context).await
},
| Rooms(command) => room::process(command, context).await,
| Federation(command) => federation::process(command, context).await,
| Server(command) => server::process(command, context).await,
+6 -1
View File
@@ -6,7 +6,12 @@
impl Context<'_> {
pub(super) async fn check_all_users(&self) -> Result {
let timer = tokio::time::Instant::now();
let users = self.services.users.stream().collect::<Vec<_>>().await;
let users = self
.services
.users
.stream_local_users()
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
let total = users.len();
+1 -1
View File
@@ -612,7 +612,7 @@ pub(super) async fn force_device_list_updates(&self) -> Result {
// Force E2EE device list updates for all users
self.services
.users
.stream()
.stream_local_users()
.for_each(async |user_id| self.services.users.mark_device_key_update(&user_id).await)
.await;
+1
View File
@@ -16,6 +16,7 @@
pub(crate) mod debug;
pub(crate) mod federation;
pub(crate) mod media;
pub(crate) mod oidc;
pub(crate) mod query;
pub(crate) mod room;
pub(crate) mod server;
+25
View File
@@ -0,0 +1,25 @@
use conduwuit::Result;
use crate::utils::parse_active_local_user_id;
impl crate::Context<'_> {
pub(super) async fn oidc_link(&self, user_id: String, subject: String) -> Result {
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
self.services.oidc.link_user(&user_id, &subject);
self.write_str(&format!("Subject `{subject}` linked to account `{user_id}`."))
.await?;
Ok(())
}
pub(super) async fn oidc_unlink(&self, subject: String) -> Result {
self.services.oidc.unlink_user(&subject);
self.write_str(&format!("Subject `{subject}` unlinked."))
.await?;
Ok(())
}
}
+22
View File
@@ -0,0 +1,22 @@
mod commands;
use clap::Subcommand;
use conduwuit::Result;
use conduwuit_macros::admin_command_dispatch;
#[admin_command_dispatch]
#[derive(Debug, Subcommand)]
pub enum OidcCommand {
/// Link a user ID to the given subject claim.
#[clap(name = "link")]
OidcLink {
user_id: String,
subject: String,
},
/// Unlink the given subject claim from its associated user ID.
#[clap(name = "unlink")]
OidcUnlink {
subject: String,
},
}
+8 -3
View File
@@ -191,8 +191,13 @@ async fn get_latest_backup(&self, user_id: OwnedUserId) -> Result {
async fn iter_users(&self) -> Result {
let timer = tokio::time::Instant::now();
let result: Vec<OwnedUserId> =
self.services.users.stream().map(Into::into).collect().await;
let result: Vec<OwnedUserId> = self
.services
.users
.stream_local_users()
.map(Into::into)
.collect()
.await;
let query_time = timer.elapsed();
@@ -202,7 +207,7 @@ async fn iter_users(&self) -> Result {
async fn iter_users2(&self) -> Result {
let timer = tokio::time::Instant::now();
let result: Vec<_> = self.services.users.stream().collect().await;
let result: Vec<_> = self.services.users.stream_local_users().collect().await;
let result: Vec<_> = result
.into_iter()
.map(|user_id| String::from_utf8_lossy(user_id.as_bytes()).into_owned())
+1 -1
View File
@@ -44,7 +44,7 @@ pub(super) async fn issue_token(&self, expires: super::TokenExpires) -> Result {
.services
.config
.oauth
.compatibility_mode
.compatibility_mode()
.oauth_available()
{
self.write_str(&format!(
+49 -59
View File
@@ -20,7 +20,7 @@
tag::{TagEvent, TagEventContent, TagInfo},
},
};
use service::users::HashedPassword;
use service::users::{AccountStatus, HashedPassword};
use crate::{
get_room_info,
@@ -60,8 +60,8 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
self.services
.users
.create_local_account(&user_id, password, None)
.await;
.create_local_account(&user_id, Some(password), None)
.await?;
self.write_str(&format!("Created user {user_id}")).await
}
@@ -103,15 +103,12 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) ->
pub(super) async fn suspend(&self, user_id: String) -> Result {
self.bail_restricted()?;
let user_id = parse_local_user_id(self.services, &user_id)?;
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
if user_id == self.services.globals.server_user {
return Err!("Not allowed to suspend the server service account.",);
}
if !self.services.users.exists(&user_id).await {
return Err!("User {user_id} does not exist.");
}
if self.services.users.is_admin(&user_id).await {
return Err!("Admin users cannot be suspended.");
}
@@ -127,15 +124,12 @@ pub(super) async fn suspend(&self, user_id: String) -> Result {
pub(super) async fn unsuspend(&self, user_id: String) -> Result {
self.bail_restricted()?;
let user_id = parse_local_user_id(self.services, &user_id)?;
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
if user_id == self.services.globals.server_user {
return Err!("Not allowed to unsuspend the server service account.",);
}
if !self.services.users.exists(&user_id).await {
return Err!("User {user_id} does not exist.");
}
self.services.users.unsuspend_account(&user_id).await;
self.write_str(&format!("User {user_id} has been unsuspended."))
@@ -147,6 +141,7 @@ pub(super) async fn reset_password(
logout: bool,
username: String,
password: Option<String>,
convert_to_local_account: bool,
) -> Result {
let user_id = parse_local_user_id(self.services, &username)?;
@@ -159,15 +154,37 @@ pub(super) async fn reset_password(
let new_password =
password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH));
let new_password_hash = HashedPassword::new(&new_password)?;
self.services
.users
.set_password(&user_id, Some(HashedPassword::new(&new_password)?));
if convert_to_local_account {
self.services
.users
.convert_to_local_account(&user_id, new_password_hash)
.await?;
} else {
match self.services.users.status(&user_id).await {
| AccountStatus::Active if !self.services.users.is_shadow(&user_id).await => {
self.services
.users
.set_password(&user_id, new_password_hash)
.await?;
},
| AccountStatus::NotFound => {
return Err!("The provided user does not exist.");
},
| _ => {
return Err!(
"The provided user is a shadow or deactivated account. To convert it to \
a local account, pass the --convert-to-local-account flag."
);
},
}
self.write_str(&format!(
"Successfully reset the password for user {user_id}: `{new_password}`"
))
.await?;
self.write_str(&format!(
"Successfully reset the password for user {user_id}: `{new_password}`"
))
.await?;
}
if logout {
self.services
@@ -919,21 +936,16 @@ pub(super) async fn force_leave_remote_room(
pub(super) async fn lock(&self, user_id: String) -> Result {
self.bail_restricted()?;
let user_id = parse_local_user_id(self.services, &user_id)?;
assert!(
self.services.globals.user_is_local(&user_id),
"Parsed user_id must be a local user"
);
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
if user_id == self.services.globals.server_user {
return Err!("Not allowed to lock the server service account.",);
}
if !self.services.users.exists(&user_id).await {
return Err!("User {user_id} does not exist.");
}
if self.services.users.is_admin(&user_id).await {
return Err!("Admin users cannot be locked.");
}
self.services
.users
.lock_account(&user_id, self.sender_or_service_user())
@@ -945,11 +957,8 @@ pub(super) async fn lock(&self, user_id: String) -> Result {
pub(super) async fn unlock(&self, user_id: String) -> Result {
self.bail_restricted()?;
let user_id = parse_local_user_id(self.services, &user_id)?;
assert!(
self.services.globals.user_is_local(&user_id),
"Parsed user_id must be a local user"
);
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
self.services.users.unlock_account(&user_id).await;
self.write_str(&format!("User {user_id} has been unlocked."))
@@ -958,21 +967,16 @@ pub(super) async fn unlock(&self, user_id: String) -> Result {
pub(super) async fn logout(&self, user_id: String) -> Result {
self.bail_restricted()?;
let user_id = parse_local_user_id(self.services, &user_id)?;
assert!(
self.services.globals.user_is_local(&user_id),
"Parsed user_id must be a local user"
);
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
if user_id == self.services.globals.server_user {
return Err!("Not allowed to log out the server service account.",);
}
if !self.services.users.exists(&user_id).await {
return Err!("User {user_id} does not exist.");
}
if self.services.users.is_admin(&user_id).await {
return Err!("You cannot forcefully log out admin users.");
}
self.services
.users
.all_device_ids(&user_id)
@@ -989,18 +993,12 @@ pub(super) async fn logout(&self, user_id: String) -> Result {
pub(super) async fn disable_login(&self, user_id: String) -> Result {
self.bail_restricted()?;
let user_id = parse_local_user_id(self.services, &user_id)?;
assert!(
self.services.globals.user_is_local(&user_id),
"Parsed user_id must be a local user"
);
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
if user_id == self.services.globals.server_user {
return Err!("Not allowed to disable login for the server service account.",);
}
if !self.services.users.exists(&user_id).await {
return Err!("User {user_id} does not exist.");
}
if self.services.users.is_admin(&user_id).await {
return Err!("Admin users cannot have their login disallowed.");
}
@@ -1014,14 +1012,8 @@ pub(super) async fn disable_login(&self, user_id: String) -> Result {
pub(super) async fn enable_login(&self, user_id: String) -> Result {
self.bail_restricted()?;
let user_id = parse_local_user_id(self.services, &user_id)?;
assert!(
self.services.globals.user_is_local(&user_id),
"Parsed user_id must be a local user"
);
if !self.services.users.exists(&user_id).await {
return Err!("User {user_id} does not exist.");
}
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
self.services.users.enable_login(&user_id);
self.write_str(&format!("{user_id} can now log in.")).await
@@ -1129,10 +1121,8 @@ pub(super) async fn change_email(&self, user_id: String, email: Option<String>)
}
pub(super) async fn reset_push_rules(&self, user_id: String) -> Result {
let user_id = parse_local_user_id(self.services, &user_id)?;
if !self.services.users.is_active(&user_id).await {
return Err!("User is not active.");
}
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
recreate_push_rules_and_return(self.services, &user_id).await?;
self.write_str("Reset user's push rules to the server default.")
.await
+3
View File
@@ -27,6 +27,9 @@ pub enum UserCommand {
username: String,
/// New password for the user, if unspecified one is generated
password: Option<String>,
#[arg(long)]
convert_to_local_account: bool,
},
/// Get a user's associated email address.
+1 -8
View File
@@ -54,14 +54,7 @@ pub(crate) async fn parse_active_local_user_id(
user_id: &str,
) -> Result<OwnedUserId> {
let user_id = parse_local_user_id(services, user_id)?;
if !services.users.exists(&user_id).await {
return Err!("User {user_id:?} does not exist on this server.");
}
if services.users.is_deactivated(&user_id).await? {
return Err!("User {user_id:?} is deactivated.");
}
services.users.status(&user_id).await.ensure_active()?;
Ok(user_id)
}
+2 -1
View File
@@ -128,7 +128,8 @@ pub(crate) async fn change_password_route(
services
.users
.set_password(&sender_user, Some(HashedPassword::new(&body.new_password)?));
.set_password(&sender_user, HashedPassword::new(&body.new_password)?)
.await?;
if body.logout_devices {
// Logout all devices except the current one
+4 -4
View File
@@ -72,7 +72,7 @@ pub(crate) async fn register_route(
.determine_registration_user_id(body.username.clone(), None, Some(appservice_info))
.await?;
services.users.create(&user_id, None)?;
services.users.create_shadow_account(&user_id).await?;
user_id
} else {
@@ -97,8 +97,8 @@ pub(crate) async fn register_route(
services
.users
.create_local_account(&user_id, password, identity.email)
.await;
.create_local_account(&user_id, Some(password), identity.email)
.await?;
user_id
};
@@ -106,7 +106,7 @@ pub(crate) async fn register_route(
let (token, device) = if !body.inhibit_login {
// If UIAA is disabled, we can't create a device. In that case only appservices
// can reach this point in the first place, so we return an error for them.
if !services.config.oauth.compatibility_mode.uiaa_available() {
if !services.config.oauth.compatibility_mode().uiaa_available() {
return Err!(Request(AppserviceLoginUnsupported(
"User-interactive appservice registration is not available on this server."
)));
+13 -16
View File
@@ -12,20 +12,18 @@ pub(crate) async fn get_lock_status(
State(services): State<crate::State>,
body: Ruma<is_user_locked::v1::Request>,
) -> Result<is_user_locked::v1::Response> {
let (admin, active) = join(
let (admin, status) = join(
services.users.is_admin(body.identity.expect_sender_user()?),
services.users.is_active(&body.user_id),
services.users.status(&body.user_id),
)
.await;
if !admin {
return Err!(Request(Forbidden("Only server administrators can use this endpoint")));
}
if !services.globals.user_is_local(&body.user_id) {
return Err!(Request(InvalidParam("Can only check the lock status of local users")));
}
if !active {
return Err!(Request(NotFound("Unknown user")));
}
status.ensure_active()?;
Ok(is_user_locked::v1::Response::new(
services.users.is_locked(&body.user_id).await?,
))
@@ -40,9 +38,9 @@ pub(crate) async fn put_lock_status(
) -> Result<lock_user::v1::Response> {
let sender_user = body.identity.expect_sender_user()?;
let (sender_admin, active, target_admin) = join3(
let (sender_admin, status, target_admin) = join3(
services.users.is_admin(sender_user),
services.users.is_active(&body.user_id),
services.users.status(&body.user_id),
services.users.is_admin(&body.user_id),
)
.await;
@@ -50,18 +48,17 @@ pub(crate) async fn put_lock_status(
if !sender_admin {
return Err!(Request(Forbidden("Only server administrators can use this endpoint")));
}
if !services.globals.user_is_local(&body.user_id) {
return Err!(Request(InvalidParam("Can only set the lock status of local users")));
}
if !active {
return Err!(Request(NotFound("Unknown user")));
}
status.ensure_active()?;
if body.user_id == *sender_user {
return Err!(Request(Forbidden("You cannot lock yourself")));
}
if target_admin {
return Err!(Request(Forbidden("You cannot lock another server administrator")));
}
if services.users.is_locked(&body.user_id).await? == body.locked {
// No change
return Ok(lock_user::v1::Response::new(body.locked));
+13 -16
View File
@@ -12,20 +12,18 @@ pub(crate) async fn get_suspended_status(
State(services): State<crate::State>,
body: Ruma<is_user_suspended::v1::Request>,
) -> Result<is_user_suspended::v1::Response> {
let (admin, active) = join(
let (admin, status) = join(
services.users.is_admin(body.identity.expect_sender_user()?),
services.users.is_active(&body.user_id),
services.users.status(&body.user_id),
)
.await;
if !admin {
return Err!(Request(Forbidden("Only server administrators can use this endpoint")));
}
if !services.globals.user_is_local(&body.user_id) {
return Err!(Request(InvalidParam("Can only check the suspended status of local users")));
}
if !active {
return Err!(Request(NotFound("Unknown user")));
}
status.ensure_active()?;
Ok(is_user_suspended::v1::Response::new(
services.users.is_suspended(&body.user_id).await?,
))
@@ -40,9 +38,9 @@ pub(crate) async fn put_suspended_status(
) -> Result<suspend_user::v1::Response> {
let sender_user = body.identity.expect_sender_user()?;
let (sender_admin, active, target_admin) = join3(
let (sender_admin, status, target_admin) = join3(
services.users.is_admin(sender_user),
services.users.is_active(&body.user_id),
services.users.status(&body.user_id),
services.users.is_admin(&body.user_id),
)
.await;
@@ -50,18 +48,17 @@ pub(crate) async fn put_suspended_status(
if !sender_admin {
return Err!(Request(Forbidden("Only server administrators can use this endpoint")));
}
if !services.globals.user_is_local(&body.user_id) {
return Err!(Request(InvalidParam("Can only set the suspended status of local users")));
}
if !active {
return Err!(Request(NotFound("Unknown user")));
}
status.ensure_active()?;
if body.user_id == *sender_user {
return Err!(Request(Forbidden("You cannot suspend yourself")));
}
if target_admin {
return Err!(Request(Forbidden("You cannot suspend another server administrator")));
}
if services.users.is_suspended(&body.user_id).await? == body.suspended {
// No change
return Ok(suspend_user::v1::Response::new(body.suspended));
+7 -2
View File
@@ -7,10 +7,11 @@
api::client::discovery::get_capabilities::{
self,
v3::{
Capabilities, GetLoginTokenCapability, RoomVersionStability, RoomVersionsCapability,
ThirdPartyIdChangesCapability,
Capabilities, GetLoginTokenCapability, ProfileFieldsCapability, RoomVersionStability,
RoomVersionsCapability, ThirdPartyIdChangesCapability,
},
},
assign,
};
use crate::Ruma;
@@ -50,5 +51,9 @@ pub(crate) async fn get_capabilities_route(
capabilities.account_moderation.suspend = true;
}
capabilities.profile_fields = Some(
assign!(ProfileFieldsCapability::new(true), { disallowed: Some(services.oidc.restricted_profile_fields()) }),
);
Ok(get_capabilities::v3::Response::new(capabilities))
}
+1 -1
View File
@@ -37,7 +37,7 @@ pub(crate) fn router(state: crate::State) -> Router<crate::State> {
.layer(middleware::from_fn_with_state(
state,
async |State(state): State<crate::State>, request: Request, next: Next| -> Response {
if state.config.oauth.compatibility_mode.oauth_available() {
if state.config.oauth.compatibility_mode().oauth_available() {
next.run(request).await
} else {
(StatusCode::NOT_FOUND, "OAuth is unavailable on this server").into_response()
+1 -1
View File
@@ -21,7 +21,7 @@ pub(crate) async fn get_authorization_server_metadata_route(
State(services): State<crate::State>,
_body: Ruma<get_authorization_server_metadata::v1::Request>,
) -> Result<get_authorization_server_metadata::v1::Response> {
if !services.config.oauth.compatibility_mode.oauth_available() {
if !services.config.oauth.compatibility_mode().oauth_available() {
return Err!(Request(Unrecognized("OAuth is unavailable on this server")));
}
+58 -289
View File
@@ -1,9 +1,8 @@
use std::collections::BTreeMap;
use axum::extract::State;
use conduwuit::{Err, Result, matrix::pdu::PartialPdu, utils::to_canonical_object};
use conduwuit::{Err, Result};
use conduwuit_service::Services;
use futures::StreamExt;
use ruma::{
UserId,
api::{
@@ -13,11 +12,10 @@
federation,
},
assign,
events::room::member::MembershipState,
presence::PresenceState,
profile::{ProfileFieldName, ProfileFieldValue},
};
use serde_json::{Value, to_value};
use serde_json::Value;
use service::users::ProfileFieldChange;
use crate::Ruma;
@@ -65,13 +63,24 @@ pub(crate) async fn set_profile_field_route(
return Err!(Request(InvalidParam("You may not change a remote user's profile data.")));
}
set_profile_field(
&services,
&body.user_id,
ProfileFieldChange::Set(body.value.clone()),
body.propagate_to.clone(),
)
.await?;
if services
.oidc
.restricted_profile_fields()
.contains(&body.value.field_name())
{
return Err!(Request(Forbidden(
"This profile field is controlled by your identity provider."
)));
}
services
.users
.set_profile_field(
&body.user_id,
ProfileFieldChange::Set(body.value.clone()),
body.propagate_to.clone(),
)
.await?;
Ok(set_profile_field::v3::Response::new())
}
@@ -94,13 +103,24 @@ pub(crate) async fn delete_profile_field_route(
return Err!(Request(InvalidParam("You may not change a remote user's profile data.")));
}
set_profile_field(
&services,
&body.user_id,
ProfileFieldChange::Delete(body.field.clone()),
body.propagate_to.clone(),
)
.await?;
if services
.oidc
.restricted_profile_fields()
.contains(&body.field)
{
return Err!(Request(Forbidden(
"This profile field is controlled by your identity provider."
)));
}
services
.users
.set_profile_field(
&body.user_id,
ProfileFieldChange::Delete(body.field.clone()),
body.propagate_to.clone(),
)
.await?;
Ok(delete_profile_field::v3::Response::new())
}
@@ -110,8 +130,8 @@ async fn fetch_full_profile(
user_id: &UserId,
) -> Option<BTreeMap<String, Value>> {
// If the user exists locally, fetch their local profile
if services.users.exists(user_id).await {
return Some(get_local_profile(services, user_id).await);
if services.users.status(user_id).await.is_found() {
return Some(services.users.get_local_profile(user_id).await);
}
// Otherwise ask their homeserver
@@ -135,13 +155,10 @@ async fn fetch_full_profile(
continue;
};
let _ = set_profile_field(
services,
user_id,
ProfileFieldChange::Set(value),
PropagateTo::None,
)
.await;
let _ = services
.users
.set_profile_field(user_id, ProfileFieldChange::Set(value), PropagateTo::None)
.await;
}
Some(BTreeMap::from_iter(response))
@@ -154,7 +171,7 @@ async fn fetch_profile_field(
) -> Result<Option<ProfileFieldValue>> {
// If the user exists locally, fetch their local profile field
if services.globals.user_is_local(user_id) {
return Ok(get_local_profile_field(services, user_id, field).await);
return Ok(services.users.get_local_profile_field(user_id, field).await);
}
// Otherwise ask their homeserver
@@ -175,13 +192,14 @@ async fn fetch_profile_field(
if let Some(value) = response.get(field.as_str()).map(ToOwned::to_owned) {
if let Ok(value) = ProfileFieldValue::new(field.as_str(), value) {
let _ = set_profile_field(
services,
user_id,
ProfileFieldChange::Set(value.clone()),
PropagateTo::None,
)
.await;
let _ = services
.users
.set_profile_field(
user_id,
ProfileFieldChange::Set(value.clone()),
PropagateTo::None,
)
.await;
Ok(Some(value))
} else {
@@ -190,260 +208,11 @@ async fn fetch_profile_field(
)))
}
} else {
let _ = set_profile_field(
services,
user_id,
ProfileFieldChange::Delete(field),
PropagateTo::None,
)
.await;
let _ = services
.users
.set_profile_field(user_id, ProfileFieldChange::Delete(field), PropagateTo::None)
.await;
Ok(None)
}
}
pub(crate) async fn get_local_profile(
services: &Services,
user_id: &UserId,
) -> BTreeMap<String, Value> {
let mut profile = BTreeMap::new();
// Get displayname and avatar_url independently because `all_profile_keys`
// doesn't include them
for field in [ProfileFieldName::AvatarUrl, ProfileFieldName::DisplayName] {
let key = field.as_str().to_owned();
if let Some(value) = get_local_profile_field(services, user_id, field).await {
profile.insert(key, value.value().into_owned());
}
}
// Insert all other profile fields
let mut all_fields = services.users.all_profile_keys(user_id);
while let Some((key, value)) = all_fields.next().await {
profile.insert(key, value);
}
profile
}
pub(crate) async fn get_local_profile_field(
services: &Services,
user_id: &UserId,
field: ProfileFieldName,
) -> Option<ProfileFieldValue> {
let value = match field.clone() {
| ProfileFieldName::AvatarUrl => services
.users
.avatar_url(user_id)
.await
.ok()
.map(to_value)
.transpose()
.expect("converting avatar url to value should succeed"),
| ProfileFieldName::DisplayName => services
.users
.displayname(user_id)
.await
.ok()
.map(to_value)
.transpose()
.expect("converting displayname to value should succeed"),
| other => services
.users
.profile_key(user_id, other.as_str())
.await
.ok(),
}?;
Some(
ProfileFieldValue::new(field.as_str(), value)
.expect("local profile field should be valid"),
)
}
enum ProfileFieldChange {
Set(ProfileFieldValue),
Delete(ProfileFieldName),
}
impl ProfileFieldChange {
fn field_name(&self) -> ProfileFieldName {
match self {
| &Self::Delete(ref name) => name.clone(),
| &Self::Set(ref value) => value.field_name(),
}
}
fn value(&self) -> Option<Value> {
if let Self::Set(value) = self {
Some(value.value().into_owned())
} else {
None
}
}
}
async fn set_profile_field(
services: &Services,
user_id: &UserId,
change: ProfileFieldChange,
propagate_to: PropagateTo,
) -> Result<()> {
const MAX_KEY_LENGTH_BYTES: usize = 255;
const MAX_PROFILE_LENGTH_BYTES: usize = 65536;
let field_name = change.field_name();
// TODO: The spec mentions special error codes (M_PROFILE_TOO_LARGE,
// M_KEY_TOO_LARGE) for profile field size limits, but they're not in its list
// of error codes and Ruma doesn't have them. Should we return those, or is
// M_TOO_LARGE okay?
if field_name.as_str().len() > MAX_KEY_LENGTH_BYTES {
return Err!(Request(TooLarge(
"Individual profile keys must not exceed {MAX_KEY_LENGTH_BYTES} bytes in length."
)));
}
// Serialize the entire profile as canonical JSON, including the new change,
// to check if it exceeds 64 KiB
{
let mut full_profile = get_local_profile(services, user_id).await;
match &change {
| ProfileFieldChange::Set(value) => {
full_profile.insert(
value.field_name().as_str().to_owned(),
value.value().clone().into_owned(),
);
},
| ProfileFieldChange::Delete(key) => {
full_profile.remove(key.as_str());
},
}
if let Ok(canonical_profile) = to_canonical_object(full_profile) {
if serde_json::to_string(&canonical_profile)
.expect("should be able to serialize to string")
.len() > MAX_PROFILE_LENGTH_BYTES
{
return Err!(
"Profile data must not exceed {MAX_PROFILE_LENGTH_BYTES} bytes in length."
);
}
} else {
return Err!(Request(BadJson("Failed to canonicalize profile.")));
}
}
// If the user is local and changed their displayname or avatar_url, update it
// in all their joined rooms. This is done before updating their profile data
// so we can check the old value of the field if `propagate_to` is `unchanged`.
if matches!(field_name, ProfileFieldName::AvatarUrl | ProfileFieldName::DisplayName)
&& matches!(propagate_to, PropagateTo::All | PropagateTo::Unchanged)
&& services.globals.user_is_local(user_id)
{
let current_displayname = services.users.displayname(user_id).await.ok();
let current_avatar_url = services.users.avatar_url(user_id).await.ok();
let mut all_joined_rooms = services.rooms.state_cache.rooms_joined(user_id);
while let Some(room_id) = all_joined_rooms.next().await {
// TODO: this clobbers any custom fields on the event content
let mut current_membership = services
.rooms
.state_accessor
.get_member(&room_id, user_id)
.await
.expect("should be able to fetch membership event for joined room");
assert_eq!(
current_membership.membership,
MembershipState::Join,
"user should be joined"
);
// If `propagate_to` is `unchanged`, and the current value of the field we're
// updating was changed from its global value in this room, skip it.
if matches!(propagate_to, PropagateTo::Unchanged) {
let field_changed_from_global = match field_name {
| ProfileFieldName::AvatarUrl =>
current_membership.avatar_url.as_ref() != current_avatar_url.as_ref(),
| ProfileFieldName::DisplayName =>
current_membership.displayname.as_ref() != current_displayname.as_ref(),
| _ => unreachable!(),
};
if field_changed_from_global {
continue;
}
}
let state_lock = services.rooms.state.mutex.lock(room_id.as_str()).await;
// Preserve keys in accordance with the key copying rules
current_membership.reason = None;
current_membership.join_authorized_via_users_server = None;
match &change {
| ProfileFieldChange::Set(ProfileFieldValue::AvatarUrl(avatar_url)) => {
current_membership.avatar_url = Some(avatar_url.clone());
},
| ProfileFieldChange::Set(ProfileFieldValue::DisplayName(displayname)) => {
current_membership.displayname = Some(displayname.clone());
},
| ProfileFieldChange::Delete(ProfileFieldName::AvatarUrl) => {
current_membership.avatar_url = None;
},
| ProfileFieldChange::Delete(ProfileFieldName::DisplayName) => {
current_membership.displayname = None;
},
| _ => unreachable!(),
}
let _ = services
.rooms
.timeline
.build_and_append_pdu(
PartialPdu::state(user_id.to_string(), &current_membership),
user_id,
Some(&room_id),
&state_lock,
)
.await;
}
if services.config.allow_local_presence {
// Send a presence EDU to indicate the profile changed
let _ = services
.presence
.ping_presence(user_id, &PresenceState::Online)
.await;
}
}
match change {
| ProfileFieldChange::Set(ProfileFieldValue::DisplayName(displayname)) => {
services
.users
.set_displayname(user_id, Some(displayname).filter(|dn| !dn.is_empty()));
},
| ProfileFieldChange::Set(ProfileFieldValue::AvatarUrl(avatar_url)) => {
services
.users
.set_avatar_url(user_id, Some(avatar_url).filter(|av| av.is_valid()));
},
| ProfileFieldChange::Delete(ProfileFieldName::DisplayName) => {
services.users.set_displayname(user_id, None);
},
| ProfileFieldChange::Delete(ProfileFieldName::AvatarUrl) => {
services.users.set_avatar_url(user_id, None);
},
| other =>
services
.users
.set_profile_key(user_id, other.field_name().as_str(), other.value()),
}
Ok(())
}
+1 -1
View File
@@ -149,7 +149,7 @@ pub(crate) async fn report_user_route(
delay_response().await;
if !services.users.is_active_local(&body.user_id).await {
if !services.users.status(&body.user_id).await.is_found() {
// return 200 as to not reveal if the user exists. Recommended by spec.
return Ok(report_user::v3::Response::new());
}
+2 -10
View File
@@ -43,7 +43,7 @@ pub(crate) async fn get_login_types_route(
ClientIp(client): ClientIp,
_body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> {
if !services.config.oauth.compatibility_mode.uiaa_available() {
if !services.config.oauth.compatibility_mode().uiaa_available() {
return Err!(Request(Unrecognized(
"User-interactive authentication is not available on this server."
)));
@@ -88,14 +88,6 @@ pub async fn handle_login(
UserId::parse_with_server_name(user_id_or_localpart, &services.config.server_name)
.map_err(|_| err!(Request(InvalidUsername("User ID is malformed"))))?;
if !services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidParam("User ID does not belong to this homeserver")));
}
if services.users.is_deactivated(&user_id).await? {
return Err!(Request(UserDeactivated("This account has been deactivated.")));
}
if services.users.is_locked(&user_id).await? {
return Err!(Request(UserLocked("This account has been locked.")));
}
@@ -128,7 +120,7 @@ pub(crate) async fn login_route(
ClientIp(client): ClientIp,
body: Ruma<login::v3::Request>,
) -> Result<login::v3::Response> {
if !services.config.oauth.compatibility_mode.uiaa_available() {
if !services.config.oauth.compatibility_mode().uiaa_available() {
return match body.login_info {
| LoginInfo::ApplicationService(_) => {
Err!(Request(AppserviceLoginUnsupported(
+30 -26
View File
@@ -32,22 +32,26 @@ pub(crate) async fn search_users_route(
.min(LIMIT_MAX);
let search_term = body.search_term.to_lowercase();
let mut users = services.users.stream().broad_filter_map(async |user_id| {
let display_name = services.users.displayname(&user_id).await.ok();
let user_id_matches = user_id.as_str().to_lowercase().contains(&search_term);
let mut users = services
.users
.stream_local_users()
.chain(services.users.stream_remote_users())
.broad_filter_map(async |user_id| {
let display_name = services.users.displayname(&user_id).await.ok();
let display_name_matches = display_name
.as_deref()
.map(str::to_lowercase)
.is_some_and(|display_name| display_name.contains(&search_term));
let user_id_matches = user_id.as_str().to_lowercase().contains(&search_term);
if !user_id_matches && !display_name_matches {
return None;
}
let display_name_matches = display_name
.as_deref()
.map(str::to_lowercase)
.is_some_and(|display_name| display_name.contains(&search_term));
let user_in_public_room =
services
if !user_id_matches && !display_name_matches {
return None;
}
let user_in_public_room = services
.rooms
.state_cache
.rooms_joined(&user_id)
@@ -60,22 +64,22 @@ pub(crate) async fn search_users_route(
.await
});
let user_sees_user = services
.rooms
.state_cache
.user_sees_user(sender_user, &user_id);
let user_sees_user = services
.rooms
.state_cache
.user_sees_user(sender_user, &user_id);
pin_mut!(user_in_public_room, user_sees_user);
pin_mut!(user_in_public_room, user_sees_user);
if user_in_public_room.or(user_sees_user).await {
Some(assign!(search_users::v3::User::new(user_id.clone()), {
display_name,
avatar_url: services.users.avatar_url(&user_id).await.ok(),
}))
} else {
None
}
});
if user_in_public_room.or(user_sees_user).await {
Some(assign!(search_users::v3::User::new(user_id.clone()), {
display_name,
avatar_url: services.users.avatar_url(&user_id).await.ok(),
}))
} else {
None
}
});
let results = users.by_ref().take(limit).collect().await;
let limited = users.next().await.is_some();
+8 -7
View File
@@ -7,10 +7,7 @@
api::federation::query::{get_profile_information, get_room_information},
};
use crate::{
Ruma,
client::{get_local_profile, get_local_profile_field},
};
use crate::Ruma;
/// # `GET /_matrix/federation/v1/query/directory`
///
@@ -75,15 +72,19 @@ pub(crate) async fn get_profile_information_route(
let response = if let Some(field) = &body.field {
let mut response = get_profile_information::v1::Response::new();
if let Some(value) =
get_local_profile_field(&services, &body.user_id, field.to_owned()).await
if let Some(value) = services
.users
.get_local_profile_field(&body.user_id, field.to_owned())
.await
{
response.set(value.field_name().as_str().to_owned(), value.value().into_owned());
}
response
} else {
get_local_profile(&services, &body.user_id)
services
.users
.get_local_profile(&body.user_id)
.await
.into_iter()
.collect()
+2 -1
View File
@@ -620,8 +620,9 @@ async fn handle_edu_direct_to_device(
.broad_filter_map(|(target_user_id, map)| async move {
services
.users
.is_active_local(&target_user_id)
.status(&target_user_id)
.await
.is_active()
.then_some((target_user_id, map))
})
.for_each_concurrent(automatic_width(), |(target_user_id, map)| {
+1
View File
@@ -117,6 +117,7 @@ url.workspace = true
parking_lot.workspace = true
lock_api.workspace = true
hyper-util.workspace = true
openidconnect.workspace = true
[target.'cfg(unix)'.dependencies]
nix.workspace = true
+134 -3
View File
@@ -4,7 +4,7 @@
pub mod proxy;
use std::{
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashMap},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
path::PathBuf,
};
@@ -17,10 +17,12 @@
use figment::providers::{Env, Format, Toml};
pub use figment::{Figment, value::Value as FigmentValue};
use lettre::message::Mailbox;
use openidconnect::{ClientId, ClientSecret, Scope};
use regex::RegexSet;
use ruma::{
OwnedRoomId, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId,
api::client::{discovery::discover_support::ContactRole, rtc::RtcTransport},
profile::ProfileFieldName,
serde::Base64,
};
use serde::{Deserialize, Serialize, de::IgnoredAny};
@@ -2419,10 +2421,24 @@ pub struct OauthConfig {
/// legacy authentication will be unable to log in.
///
/// default: "hybrid"
pub compatibility_mode: OAuthMode,
compatibility_mode: OAuthMode,
/// display: hidden
pub oidc: Option<OidcConfig>,
}
#[derive(Clone, Debug, Default, Deserialize)]
impl OauthConfig {
#[must_use]
pub fn compatibility_mode(&self) -> OAuthMode {
if self.oidc.is_some() {
OAuthMode::Exclusive
} else {
self.compatibility_mode
}
}
}
#[derive(Clone, Copy, Debug, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OAuthMode {
Disabled,
@@ -2439,6 +2455,115 @@ pub fn uiaa_available(&self) -> bool { matches!(self, Self::Disabled | Self::Hyb
pub fn oauth_available(&self) -> bool { matches!(self, Self::Hybrid | Self::Exclusive) }
}
#[derive(Clone, Debug, Deserialize)]
#[config_example_generator(
filename = "conduwuit-example.toml",
section = "global.oauth.oidc",
optional = "true",
subheader = "\
# Uncommenting this section will enable Continuwuity's support for
# authenticating users using an OpenID Connect-compatible identity provider.
# This is referred to as \"delegated authentication\".
#
# IMPORTANT NOTE: When delegated authentication is active, Continuwuity will behave as if
# the `global.oauth.compatibility_mode` setting is set to `exclusive`.
# Matrix clients which do not support OAuth login (also referred to as \"next-gen auth\") will \
NOT be able
# to log in while delegated authentication is active."
)]
pub struct OidcConfig {
/// The OIDC issuer URL. Continuwuity will use OpenID Connect Discovery to
/// automatically fetch the identity provider's metadata from this URL.
/// Generally you should set this to the base domain your identity provider
/// runs on.
pub discovery_url: Url,
/// The OAuth client ID for Continuwuity to use when communicating with the
/// identity provider.
pub client_id: ClientId,
/// The OAuth client secret for Continuwuity to use when communicating with
/// the identity provider.
pub client_secret: ClientSecret,
/// Additional scopes Continuwuity should request from the IDP. This may be
/// necessary to access certain claims. Continuwuity always requests the
/// `openid` scope.
///
/// default: []
#[serde(default)]
pub additional_scopes: Vec<Scope>,
/// Whether the user should be prompted to choose a localpart
/// when signing in for the first time. If this is `false`, Continuwuity
/// will attempt to use the value of the `preferred_username_claim`
/// (see below) as the user's localpart. Authentication will
/// fail if this claim is missing or is not a valid localpart.
///
/// default: true
#[serde(default = "true_fn")]
pub prompt_for_localpart: bool,
/// The claim to use for the user's localpart, if `prompt_for_localpart` is
/// false.
///
/// default: "preferred_username"
#[serde(default = "default_preferred_username_claim")]
pub preferred_username_claim: String,
/// The claim which will be used to set the user's email address,
/// either on initial registration or on every login depending on
/// the value of `profile_key_import_mode`. Continuwuity assumes that
/// the IDP has taken care of verifying that the user controls the email
/// address it provides.
///
/// This option does nothing if SMTP is not configured.
///
/// If this option is set, and `profile_key_import_mode` is `on_login`,
/// users will not be able to change their email addresses themselves.
///
/// default: "email"
pub email_claim: Option<String>,
/// Defines how claims returned from the IDP should be mapped to a user's
/// profile data. The profile field named in each key will be set from the
/// claim named in the corresponding value when the user first registers,
/// and possibly on subsequent logins as well, depending on the value of
/// `profile_key_import_mode` (see below).
///
/// Per-room overrides to the user's display name or avatar will be
/// preserved by the import process.
///
/// SECURITY NOTE: If the `avatar_url` field is set, Continuwuity will
/// perform a HTTP GET to the URL in the mapped claim and use the returned
/// file as the user's profile picture. Make sure your users are not able
/// to set the value of the mapped claim to an arbitrary URL.
///
/// default: { displayname = "name" }
#[serde(default = "default_profile_key_map")]
pub profile_key_map: HashMap<String, String>,
/// When profile keys should be imported from the IDP's claims.
///
/// - "on_registration": Listed keys will be imported once, when the user
/// logs in for the first time and their shadow account is created.
/// - "on_login": Listed keys will be imported every time the user logs in.
/// Additionally, users will not be able to manually edit any listed keys
/// through their Matrix client.
///
/// default: "on_registration"
#[serde(default)]
pub profile_key_import_mode: OidcProfileKeyImportMode,
}
#[derive(Clone, Debug, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum OidcProfileKeyImportMode {
#[default]
OnRegistration,
OnLogin,
}
const DEPRECATED_KEYS: &[&str] = &[
"cache_capacity",
"conduit_cache_capacity_modifier",
@@ -2823,3 +2948,9 @@ fn default_client_shutdown_timeout() -> u64 { 15 }
fn default_sender_shutdown_timeout() -> u64 { 5 }
fn default_terms_language() -> String { "en".to_owned() }
fn default_preferred_username_claim() -> String { "preferred_username".to_owned() }
fn default_profile_key_map() -> HashMap<String, String> {
HashMap::from_iter([(ProfileFieldName::DisplayName.to_string(), "name".to_owned())])
}
+16
View File
@@ -124,6 +124,14 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "onetimekeyid_onetimekeys",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "openidsubject_localpart",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "openidsubject_currentpictureurl",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "fallbackkeyid_fallbackkey",
..descriptor::RANDOM_SMALL
@@ -169,6 +177,10 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "registrationtoken_info",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "remoteuserid_remoteuser",
..descriptor::RANDOM
},
Descriptor {
name: "roomid_invitedcount",
..descriptor::RANDOM_SMALL
@@ -398,6 +410,10 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "userid_blurhash",
..descriptor::DROPPED
},
Descriptor {
name: "userid_deactivated",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_dehydrateddevice",
..descriptor::RANDOM_SMALL
+6
View File
@@ -81,6 +81,12 @@ fn generate_example(input: &ItemStruct, args: &[Meta], write: bool) -> Result<To
};
file.write_fmt(format_args!("{section_header}"))
.expect("written to config file");
if let Some(subheader) = settings.get("subheader") {
file.write_all(subheader.as_bytes())
.expect("written to config file");
file.write_all(b"\n\n").expect("written to config file");
}
}
let mut summary: Vec<TokenStream2> = Vec::new();
+1
View File
@@ -120,6 +120,7 @@ reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features
yansi.workspace = true
lettre.workspace = true
serde_urlencoded.workspace = true
openidconnect.workspace = true
[target.'cfg(all(unix, target_os = "linux"))'.dependencies]
sd-notify.workspace = true
+1 -1
View File
@@ -48,7 +48,7 @@ pub async fn create_admin_room(services: &Services) -> Result {
// Create a user for the server
let server_user = services.globals.server_user.as_ref();
services.users.create(server_user, None)?;
services.users.create_shadow_account(server_user).await?;
let mut create_content = if room_version_rules.authorization.use_room_create_sender {
RoomCreateEventContent::new_v1(server_user.into())
+7 -7
View File
@@ -108,17 +108,17 @@ async fn start_appservice(&self, id: String, registration: Registration) -> Resu
self.services.globals.server_name(),
)?;
if !self.services.users.exists(&appservice_user_id).await {
self.services.users.create(&appservice_user_id, None)?;
} else if self
if !self
.services
.users
.is_deactivated(&appservice_user_id)
.status(&appservice_user_id)
.await
.unwrap_or(false)
.is_found()
{
// Reactivate the appservice user if it was accidentally deactivated
self.services.users.set_password(&appservice_user_id, None);
self.services
.users
.create_shadow_account(&appservice_user_id)
.await?;
}
self.registration_info
+16 -9
View File
@@ -54,15 +54,22 @@ impl Service {
async fn set_emergency_access(&self) -> Result {
let server_user = &self.services.globals.server_user;
self.services.users.set_password(
server_user,
self.services
.config
.emergency_password
.as_deref()
.map(HashedPassword::new)
.transpose()?,
);
match &self.services.config.emergency_password {
| Some(emergency_password) => {
let emergency_password = HashedPassword::new(emergency_password)?;
self.services
.users
.convert_to_local_account(server_user, emergency_password)
.await?;
},
| None => {
self.services
.users
.convert_to_shadow_account(server_user)
.await?;
},
}
let (ruleset, pwd_set) = match self.services.config.emergency_password {
| Some(_) => (Ruleset::server_default(server_user), true),
-6
View File
@@ -211,7 +211,6 @@ pub async fn download_audio(
Ok(preview_data)
}
#[cfg(feature = "url_preview")]
pub async fn download_media(&self, url: &str) -> Result<(OwnedMxcUri, usize)> {
use conduwuit::utils::random_string;
use http::header::CONTENT_TYPE;
@@ -268,11 +267,6 @@ pub async fn download_audio(
Err!(FeatureDisabled("url_preview"))
}
#[cfg(not(feature = "url_preview"))]
pub async fn download_media(&self, _url: &str) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[cfg(feature = "url_preview")]
async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
use webpage::HTML;
+57 -4
View File
@@ -41,7 +41,7 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> {
// requires recreating the database from scratch.
if users_count > 0 {
let server_user = &services.globals.server_user;
if !services.users.exists(server_user).await {
if !services.users.status(server_user).await.is_found() {
error!("The {server_user} server user does not exist, and the database is not new.");
return Err!(Database(
"Cannot reuse an existing database after changing the server name, please \
@@ -228,6 +228,18 @@ async fn migrate(services: &Services) -> Result<()> {
.map_err(|e| err!("Failed to run 'fix_local_invite_state' migration': {e}"))?;
}
if services.globals.db.database_version().await < 18 {
services.globals.db.bump_database_version(18);
info!("Migration: Bumped database version to 18");
}
if db["global"].get(SPLIT_USERID_PASSWORD).await.is_not_found() {
info!("Running migration 'split_userid_password'");
split_userid_password(services)
.await
.map_err(|e| err!("Failed to run 'split_userid_password' migration': {e}"))?;
}
assert_eq!(
services.globals.db.database_version().await,
DATABASE_VERSION,
@@ -242,9 +254,9 @@ async fn migrate(services: &Services) -> Result<()> {
if !patterns.is_empty() {
services
.users
.stream()
.stream_local_users()
.filter_map(async |user_id| {
if services.users.is_active_local(&user_id).await {
if services.users.status(&user_id).await.is_found() {
Some(user_id)
} else {
None
@@ -774,7 +786,7 @@ async fn fix_local_invite_state(services: &Services) -> Result {
let db = &services.db;
let cork = db.cork_and_sync();
let userroomid_invitestate = services.db["userroomid_invitestate"].clone();
let userroomid_invitestate = db["userroomid_invitestate"].clone();
// for each user invited to a room
let fixed = userroomid_invitestate.stream()
@@ -818,3 +830,44 @@ async fn fix_local_invite_state(services: &Services) -> Result {
db.db.sort()?;
Ok(())
}
const SPLIT_USERID_PASSWORD: &str = "split_userid_password";
async fn split_userid_password(services: &Services) -> Result {
// Split remote and deactivated users out from the `userid_password` table
let db = &services.db;
let cork = db.cork_and_sync();
let userid_password = db["userid_password"].clone();
let remoteuserid_remoteuser = db["remoteuserid_remoteuser"].clone();
let userid_deactivated = db["userid_deactivated"].clone();
let remote_users = userid_password
.stream::<OwnedUserId, String>()
.ignore_err()
.fold(0_usize, async |mut remote_users, (user_id, hash)| {
if !services.globals.user_is_local(&user_id) {
assert!(hash.is_empty(), "non-empty hash {hash} for remote user {user_id}");
remoteuserid_remoteuser.insert(&user_id, "");
userid_password.remove(&user_id);
remote_users = remote_users.saturating_add(1);
} else if hash.is_empty() {
if !(services.appservice.is_exclusive_user_id(&user_id).await
|| user_id == services.globals.server_user)
{
info!("Marking {user_id} as deactivated");
userid_deactivated.insert(&user_id, "");
}
}
remote_users
})
.await;
drop(cork);
info!(?remote_users, "Split userid_password.");
db["global"].insert(FIXED_LOCAL_INVITE_STATE_MARKER, []);
db.db.sort()?;
Ok(())
}
+1
View File
@@ -28,6 +28,7 @@
pub mod media;
pub mod moderation;
pub mod oauth;
pub mod oidc;
pub mod presence;
pub mod pusher;
pub mod registration_tokens;
+1 -1
View File
@@ -160,7 +160,7 @@ pub enum ErrorCode {
InvalidClientMetadata,
}
#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
pub struct AuthorizationCodeResponse {
pub state: String,
pub code: String,
+471
View File
@@ -0,0 +1,471 @@
use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
use async_trait::async_trait;
use conduwuit::{
Result,
config::{OidcConfig, OidcProfileKeyImportMode},
debug, err, error, info, warn,
};
use database::{Deserialized, Map};
use lettre::Address;
use openidconnect::{
AdditionalClaims, AuthorizationCode, CsrfToken, EmptyExtraTokenFields, EndpointMaybeSet,
EndpointNotSet, EndpointSet, IdTokenClaims, IdTokenFields, IssuerUrl, Nonce,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, StandardErrorResponse,
StandardTokenResponse, TokenResponse,
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreErrorResponseType,
CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreRevocableToken,
CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType,
},
reqwest,
};
use ruma::{
OwnedUserId, UserId,
api::client::profile::PropagateTo,
profile::{ProfileFieldName, ProfileFieldValue},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::{runtime, sync::SetOnce};
use url::Url;
use crate::{
Dep, config, globals, media,
oauth::grant::AuthorizationCodeResponse,
threepid,
users::{self, AccountStatus, ProfileFieldChange},
};
pub struct Service {
services: Services,
runtime: runtime::Handle,
db: Data,
client: Option<OidcClient>,
}
struct Data {
openidsubject_localpart: Arc<Map>,
openidsubject_currentpictureurl: Arc<Map>,
}
struct Services {
config: Dep<config::Service>,
globals: Dep<globals::Service>,
media: Dep<media::Service>,
threepid: Dep<threepid::Service>,
users: Dep<users::Service>,
}
struct OidcClient {
config: OidcConfig,
machine: SetOnce<OidcClientMachine>,
client: reqwest::Client,
}
type OidcClientMachine = openidconnect::Client<
AllClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
StandardTokenResponse<
IdTokenFields<
AllClaims,
EmptyExtraTokenFields,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
>,
CoreTokenType,
>,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
EndpointSet,
EndpointNotSet,
EndpointNotSet,
EndpointNotSet,
EndpointMaybeSet,
EndpointMaybeSet,
>;
pub type Claims = IdTokenClaims<AllClaims, CoreGenderClaim>;
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AllClaims {
#[serde(flatten)]
pub claims: HashMap<String, Value>,
}
impl AdditionalClaims for AllClaims {}
#[derive(Debug, Deserialize, Serialize)]
pub struct PendingSession {
pkce_verifier: PkceCodeVerifier,
nonce: Nonce,
csrf_token: CsrfToken,
}
pub enum SessionCompletionStatus {
NeedsUserId,
Complete(OwnedUserId),
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
config: args.depend::<config::Service>("config"),
globals: args.depend::<globals::Service>("globals"),
media: args.depend::<media::Service>("media"),
threepid: args.depend::<threepid::Service>("threepid"),
users: args.depend::<users::Service>("users"),
},
runtime: args.server.runtime().clone(),
db: Data {
openidsubject_localpart: args.db["openidsubject_localpart"].clone(),
openidsubject_currentpictureurl: args.db["openidsubject_currentpictureurl"].clone(),
},
client: args.server.config.oauth.oidc.as_ref().map(|config| OidcClient {
config: config.clone(),
machine: SetOnce::new(),
// This isn't in the client service because it has to use the `reqwest` shipped by `openidconnect`
client: reqwest::ClientBuilder::new()
.connect_timeout(Duration::from_secs(args.server.config.request_conn_timeout))
.read_timeout(Duration::from_secs(args.server.config.request_timeout))
.timeout(Duration::from_secs(args.server.config.request_total_timeout))
.pool_idle_timeout(Duration::from_secs(args.server.config.request_idle_timeout))
.pool_max_idle_per_host(args.server.config.request_idle_per_host.into())
.user_agent(conduwuit::user_agent())
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(args.server.config.allow_invalid_tls_certificates_yes_i_know_what_the_fuck_i_am_doing_with_this_and_i_know_this_is_insecure)
.build()
.expect("client should build")
}),
}))
}
async fn worker(self: Arc<Self>) -> Result {
if let Some(OidcClient { config, machine, client }) = &self.client {
let redirect_url = self
.services
.config
.get_client_domain()
.join(&format!("{}/oidc/complete", conduwuit::ROUTE_PREFIX))
.expect("redirect url should be valid");
let provider_metadata = CoreProviderMetadata::discover_async(
IssuerUrl::from_url(config.discovery_url.clone()),
client,
)
.await
.map_err(|err| err!("Failed to discover OIDC provider metadata: {err}"))?;
machine
.set(
OidcClientMachine::from_provider_metadata(
provider_metadata,
config.client_id.clone(),
Some(config.client_secret.clone()),
)
.set_redirect_uri(RedirectUrl::from_url(redirect_url)),
)
.expect("machine should be empty");
}
Ok(())
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
const SERVER_MISCONFIGURED: &str =
"Identity server is misconfigured. Contact your homeserver's administrator.";
pub fn enabled(&self) -> bool { self.client.is_some() }
pub fn restricted_profile_fields(&self) -> Vec<ProfileFieldName> {
if let Some(config) = self.client.as_ref().map(|client| &client.config)
&& matches!(config.profile_key_import_mode, OidcProfileKeyImportMode::OnLogin)
{
config
.profile_key_map
.keys()
.map(|key| ProfileFieldName::from(key.as_str()))
.collect()
} else {
vec![]
}
}
pub async fn begin_session(&self, prompt: Option<CoreAuthPrompt>) -> (PendingSession, Url) {
let OidcClient { machine, config, .. } =
self.client.as_ref().expect("oidc should be configured");
let machine = machine.wait().await;
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut auth_url = machine
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
)
.add_scopes(config.additional_scopes.iter().cloned())
.set_pkce_challenge(pkce_challenge);
if let Some(prompt) = prompt {
auth_url = auth_url.add_prompt(prompt);
}
let (auth_url, csrf_token, nonce) = auth_url.url();
(PendingSession { pkce_verifier, nonce, csrf_token }, auth_url)
}
pub async fn exchange_code(
&self,
session: PendingSession,
response: AuthorizationCodeResponse,
) -> Result<Claims, &'static str> {
let Some(OidcClient { machine, client, .. }) = self.client.as_ref() else {
return Err("Delegated authentication is not enabled on this server.");
};
let machine = machine.wait().await;
if session.csrf_token.into_secret() != response.state {
return Err("State mismatch.");
}
let token_response = machine
.exchange_code(AuthorizationCode::new(response.code))
.expect("machine should be configured correctly")
.set_pkce_verifier(session.pkce_verifier)
.request_async(client)
.await
.map_err(|err| {
error!("Failed to exchange OIDC authorization code: {err}");
"Code exchange failed."
})?;
let Some(id_token) = token_response.id_token() else {
error!("Identity server did not return an id token");
return Err(Self::SERVER_MISCONFIGURED);
};
let claims = id_token
.claims(&machine.id_token_verifier(), &session.nonce)
.map_err(|err| {
error!("Failed to verify id token claims: {err}");
Self::SERVER_MISCONFIGURED
})?
.to_owned();
info!(subject = claims.subject().as_str(), "Authenticated subject");
Ok(claims)
}
#[tracing::instrument(skip(self, claims), fields(subject = claims.subject().to_string()))]
pub async fn complete_session(
&self,
claims: &Claims,
supplied_user_id: Option<OwnedUserId>,
) -> Result<SessionCompletionStatus, &'static str> {
let Some(OidcClient { config, .. }) = self.client.as_ref() else {
return Err("Delegated authentication is not enabled on this server.");
};
// this is a truly awful hack but we really need all the claims in a map
let all_claims = serde_json::to_value(claims)
.expect("should be able to serialize claims")
.as_object()
.expect("claims should be an object")
.to_owned();
debug!(?all_claims);
let subject = claims.subject().as_str();
let user_id = if let Ok(localpart) = self
.db
.openidsubject_localpart
.get(subject)
.await
.deserialized::<String>()
{
UserId::parse(format!("@{localpart}:{}", self.services.globals.server_name()))
.expect("saved localpart should be valid")
} else if config.prompt_for_localpart {
if let Some(supplied_user_id) = supplied_user_id {
supplied_user_id
} else {
return Ok(SessionCompletionStatus::NeedsUserId);
}
} else if let Some(preferred_username) = all_claims
.get(&config.preferred_username_claim)
.and_then(|claim| claim.as_str())
{
self.services
.users
.determine_registration_user_id(Some(preferred_username.to_owned()), None, None)
.await
.map_err(|err| {
error!("Preferred username claim is not a valid localpart: {err}");
"Your preferred username is not a valid Matrix user ID localpart. Contact \
your homeserver's administrator."
})?
} else {
error!("Preferred username claim was not present or was not a string");
return Err(Self::SERVER_MISCONFIGURED);
};
info!(?subject, ?user_id, "User {user_id} successfully authorized with OIDC");
// Create a shadow account for the user if necessary
let new_account_registered = match self.services.users.status(&user_id).await {
| AccountStatus::Active => {
// Do nothing, an account already exists
false
},
| AccountStatus::NotFound => {
// Create a new shadow user
self.services
.users
.create_local_account(&user_id, None, None)
.await
.map_err(|err| {
error!("Failed to create a shadow user for {user_id}: {err}");
Self::SERVER_MISCONFIGURED
})?;
info!(?subject, ?user_id, "Shadow user created for {user_id}");
true
},
| AccountStatus::Deactivated => {
return Err("Your account has been deactivated.");
},
};
self.link_user(&user_id, subject);
// Import profile fields
if matches!(config.profile_key_import_mode, OidcProfileKeyImportMode::OnLogin)
|| (matches!(
config.profile_key_import_mode,
OidcProfileKeyImportMode::OnRegistration
) && new_account_registered)
{
if let Some(email_claim) = &config.email_claim {
if let Some(email) = claims.email().map(|email| email.as_str())
&& let Ok(address) = Address::from_str(email)
{
if let Err(err) = self
.services
.threepid
.associate_localpart_email(user_id.localpart(), &address)
.await
{
warn!(?email_claim, ?address, "Failed to associate email address: {err}");
}
} else {
warn!(
?email_claim,
"Email claim was not present or was not a valid email address"
);
}
}
let user_id = user_id.clone();
let subject = claims.subject().to_string();
let profile_key_map = config.profile_key_map.clone();
let openidsubject_currentpictureurl = self.db.openidsubject_currentpictureurl.clone();
let users = self.services.users.clone();
let media = self.services.media.clone();
let import_task = self.runtime.spawn(async move {
for (field, claim) in &profile_key_map {
let Some(value) = all_claims.get(claim).cloned() else {
warn!(?field, ?claim, "IDP provided no value for this mapped claim");
continue;
};
let value = if let Some(picture_url) = value.as_str()
&& field == ProfileFieldName::AvatarUrl.as_str()
&& openidsubject_currentpictureurl
.get(&subject)
.await
.deserialized::<String>()
.ok()
.is_none_or(|current_picture| current_picture != picture_url)
{
match media.download_media(picture_url).await {
| Ok((mxc, size)) => {
openidsubject_currentpictureurl.insert(&subject, picture_url);
info!(?picture_url, ?mxc, ?size, "Downloaded profile picture");
ProfileFieldValue::AvatarUrl(mxc)
},
| Err(err) => {
warn!(
?claim,
?picture_url,
"Failed to download profile picture: {err}"
);
continue;
},
}
} else {
match ProfileFieldValue::new(field, value.clone()) {
| Ok(value) => value,
| Err(err) => {
warn!(
?field,
?claim,
?value,
"Failed to parse claim value for profile field: {err}"
);
continue;
},
}
};
if let Err(err) = users
.set_profile_field(
&user_id,
ProfileFieldChange::Set(value),
PropagateTo::Unchanged,
)
.await
{
warn!(?field, ?claim, "Error while setting profile field: {err}");
}
}
info!("Profile import complete");
});
// Only wait for import to complete if this is a new account,
// so they see the correct profile information in the account panel
if new_account_registered {
let _ = import_task.await;
}
}
Ok(SessionCompletionStatus::Complete(user_id))
}
pub fn link_user(&self, user_id: &UserId, subject: &str) {
self.db
.openidsubject_localpart
.insert(subject, user_id.localpart());
}
pub fn unlink_user(&self, subject: &str) { self.db.openidsubject_localpart.remove(subject); }
}
+1 -1
View File
@@ -248,7 +248,7 @@ pub fn active_local_users_in_room<'a>(
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
self.local_users_in_room(room_id)
.filter_map(async |user_id| {
if self.services.users.is_active(&user_id).await {
if self.services.users.status(&user_id).await.is_active() {
Some(user_id)
} else {
None
+2 -9
View File
@@ -30,15 +30,8 @@ pub async fn update_membership(
) -> Result {
let membership = pdu.get_content::<RoomMemberEventContent>()?;
// Keep track what remote users exist by adding them as "deactivated" users
//
// TODO: use futures to update remote profiles without blocking the membership
// update
#[allow(clippy::collapsible_if)]
if !self.services.globals.user_is_local(user_id)
&& !self.services.users.exists(user_id).await
{
self.services.users.create(user_id, None)?;
if !self.services.globals.user_is_local(user_id) {
self.services.users.record_remote_user(user_id);
}
match &membership.membership {
+7 -1
View File
@@ -255,7 +255,13 @@ async fn notify_local_users(&self, pdu: &PduEvent, pdu_id: &RawPduId, room_id: &
if let Some(state_key) = pdu.state_key() {
match UserId::parse(state_key) {
| Ok(target_user_id) => {
if self.services.users.is_active_local(&target_user_id).await {
if self
.services
.users
.status(&target_user_id)
.await
.is_active()
{
push_targets.insert(target_user_id.clone());
}
},
+4 -2
View File
@@ -11,8 +11,8 @@
account_data, admin, announcements, antispam, appservice, client, config, emergency,
federation, firstrun, globals, key_backups, mailer,
manager::Manager,
media, moderation, oauth, presence, pusher, registration_tokens, resolver, rooms, sending,
server_keys,
media, moderation, oauth, oidc, presence, pusher, registration_tokens, resolver, rooms,
sending, server_keys,
service::{self, Args, Map, Service},
sync, threepid, transactions, uiaa, users,
};
@@ -28,6 +28,7 @@ pub struct Services {
pub key_backups: Arc<key_backups::Service>,
pub media: Arc<media::Service>,
pub oauth: Arc<oauth::Service>,
pub oidc: Arc<oidc::Service>,
pub mailer: Arc<mailer::Service>,
pub presence: Arc<presence::Service>,
pub pusher: Arc<pusher::Service>,
@@ -85,6 +86,7 @@ macro_rules! build {
key_backups: build!(key_backups::Service),
media: build!(media::Service),
oauth: build!(oauth::Service),
oidc: build!(oidc::Service),
mailer: build!(mailer::Service),
presence: build!(presence::Service),
pusher: build!(pusher::Service),
+9 -2
View File
@@ -1,6 +1,6 @@
use std::{borrow::Cow, collections::HashMap, sync::Arc};
use conduwuit::{Err, Error, Result, result::FlatOk};
use conduwuit::{Err, Error, Result, config::OidcProfileKeyImportMode, result::FlatOk};
use database::{Deserialized, Map};
use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};
use lettre::{Address, message::Mailbox};
@@ -87,7 +87,14 @@ impl Service {
/// Check if users are required to have an email address.
pub fn email_requirement(&self) -> EmailRequirement {
if let Some(smtp) = &self.services.config.smtp {
if smtp.require_email_for_registration || smtp.require_email_for_token_registration {
if let Some(oidc) = &self.services.config.oauth.oidc
&& matches!(oidc.profile_key_import_mode, OidcProfileKeyImportMode::OnLogin)
&& oidc.email_claim.is_some()
{
EmailRequirement::Unavailable
} else if smtp.require_email_for_registration
|| smtp.require_email_for_token_registration
{
EmailRequirement::Required
} else {
EmailRequirement::Optional
+1 -1
View File
@@ -314,7 +314,7 @@ async fn create_session(
.services
.config
.oauth
.compatibility_mode
.compatibility_mode()
.uiaa_available()
{
return Err!(Request(Unrecognized(
+150 -88
View File
@@ -6,21 +6,23 @@
warn,
};
use database::{Deserialized, Json};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use futures::{FutureExt, Stream, StreamExt};
use lettre::Address;
use ruma::{
MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
api::client::profile::PropagateTo,
events::{
GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent,
push_rules::PushRulesEvent, room::message::RoomMessageEventContent,
},
profile::ProfileFieldValue,
push::Ruleset,
};
use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigEvent};
use crate::{
appservice::RegistrationInfo,
users::{HashedPassword, UserSuspension},
users::{HashedPassword, UserSuspension, profile::ProfileFieldChange},
};
/// The status of an access token.
@@ -29,6 +31,30 @@ pub enum AccessTokenStatus {
Expired,
}
/// The status of a user account.
#[derive(Clone, Copy)]
pub enum AccountStatus {
NotFound,
Active,
Deactivated,
}
impl AccountStatus {
#[must_use]
pub fn is_found(&self) -> bool { !matches!(self, Self::NotFound) }
#[must_use]
pub fn is_active(&self) -> bool { matches!(self, Self::Active) }
pub fn ensure_active(&self) -> Result<()> {
match self {
| Self::Active => Ok(()),
| Self::Deactivated => Err!(Request(UserDeactivated("This account is deactivated."))),
| Self::NotFound => Err!(Request(NotFound("This account does not exist."))),
}
}
}
impl super::Service {
/// Returns true/false based on whether the recipient/receiving user has
/// ignored the sender.
@@ -89,29 +115,33 @@ pub async fn is_admin(&self, user_id: &UserId) -> bool {
self.services.admin.user_is_admin(user_id).await
}
/// Create a new user account on this homeserver. Set the password to `None`
/// to create a non-local user. Non-local users with a password will return
/// an error.
#[inline]
pub fn create(&self, user_id: &UserId, password: Option<HashedPassword>) -> Result<()> {
if !self.services.globals.user_is_local(user_id) && password.is_some() {
return Err!("Cannot create a nonlocal user with a set password");
/// Create a new shadow user account on this server. Shadow accounts
/// have no password and cannot be logged into.
pub async fn create_shadow_account(&self, user_id: &UserId) -> Result<()> {
assert!(self.services.globals.user_is_local(user_id), "user id must be local");
if self.status(user_id).await.is_found() {
return Err!(Request(UserInUse("An account with this user ID already exists.")));
}
self.set_password(user_id, password);
self.db.userid_password.insert(user_id, "");
Ok(())
}
/// Create a new account for a local human or bot user.
/// Create a new account for a local human or bot user. If `password` is
/// None, the account will be a shadow account.
pub async fn create_local_account(
&self,
user_id: &UserId,
password: HashedPassword,
password: Option<HashedPassword>,
email: Option<Address>,
) {
self.create(user_id, Some(password))
.expect("should be able to save a new local user. what happened?");
) -> Result<()> {
self.create_shadow_account(user_id).await?;
if let Some(password) = password {
self.convert_to_local_account(user_id, password).await?;
}
// Set an initial display name
{
@@ -123,7 +153,13 @@ pub async fn create_local_account(
displayname.push_str(suffix);
}
self.set_displayname(user_id, Some(displayname));
self.set_profile_field(
user_id,
ProfileFieldChange::Set(ProfileFieldValue::DisplayName(displayname)),
PropagateTo::None,
)
.await
.expect("should be able to set display name");
};
// Set default push rules
@@ -230,6 +266,8 @@ pub async fn create_local_account(
}
info!("Created new user account for {user_id}");
Ok(())
}
pub async fn determine_registration_user_id(
@@ -268,28 +306,7 @@ pub async fn determine_registration_user_id(
&supplied_username,
self.services.globals.server_name(),
) {
| Ok(user_id) => {
if let Err(e) = user_id.validate_strict() {
// Unless we are in emergency mode, we should follow synapse's behaviour
// on not allowing things like spaces and UTF-8 characters in
// usernames
if !emergency_mode_enabled {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} contains disallowed characters or \
spaces: {e}"
))));
}
}
// Don't allow registration with user IDs that aren't local
if !self.services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {supplied_username} is not local to this server"
)));
}
user_id
},
| Ok(user_id) => user_id,
| Err(e) => {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} is not valid: {e}"
@@ -297,8 +314,27 @@ pub async fn determine_registration_user_id(
},
};
if self.exists(&user_id).await {
return Err!(Request(UserInUse("User ID is not available.")));
if let Err(e) = user_id.validate_strict() {
// Unless we are in emergency mode, we should follow synapse's behaviour
// on not allowing things like spaces and UTF-8 characters in
// usernames
if !emergency_mode_enabled {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} contains disallowed characters or spaces: \
{e}"
))));
}
}
// Don't allow registration with user IDs that aren't local
if !self.services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {supplied_username} is not local to this server"
)));
}
if self.status(&user_id).await.is_found() {
return Err!(Request(UserInUse("Username is not available.")));
}
// Check that the user ID is/is not in an appservice's namespace
@@ -329,26 +365,24 @@ pub async fn determine_registration_user_id(
)
.unwrap();
if !self.exists(&user_id).await {
if !self.status(&user_id).await.is_found() {
break Ok(user_id);
}
}
}
}
/// Deactivates an account, removing all of their device IDs and unsetting
/// their password.
/// Deactivates an account, removing all of their device IDs.
pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
self.status(user_id).await.ensure_active()?;
// Remove all associated devices
self.all_device_ids(user_id)
.for_each(async |device_id| self.remove_device(user_id, &device_id).await)
.await;
// Set the password to "" to indicate a deactivated account. Hashes will never
// result in an empty string, so the user will not be able to log in again.
// Systems like changing the password without logging in should check if the
// account is deactivated.
self.set_password(user_id, None);
// Mark user as deactivated
self.db.userid_deactivated.insert(user_id, "");
// TODO: Unhook 3PID
Ok(())
@@ -393,25 +427,6 @@ pub async fn lock_account(&self, user_id: &UserId, locking_user: &UserId) {
/// Unlocks an account, allowing the user to log in and use it again.
pub async fn unlock_account(&self, user_id: &UserId) { self.db.userid_lock.remove(user_id); }
/// Check if the provided user ID belongs to an existing (possibly
/// deactivated) account on this homeserver.
#[inline]
pub async fn exists(&self, user_id: &UserId) -> bool {
self.services.globals.user_is_local(user_id)
&& self.db.userid_password.get(user_id).await.is_ok()
}
/// Check if account is deactivated (has an empty password). Returns a
/// NotFound error if the user does not exist.
pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
self.db
.userid_password
.get(user_id)
.map_ok(|val| val.is_empty())
.map_err(|_| err!(Request(NotFound("User does not exist."))))
.await
}
/// Check if account is suspended. Returns false if the user does not exist.
pub async fn is_suspended(&self, user_id: &UserId) -> Result<bool> {
match self
@@ -469,18 +484,33 @@ pub async fn is_login_disabled(&self, user_id: &UserId) -> bool {
.is_ok()
}
/// Check if account is active (not deactivated)
pub async fn is_active(&self, user_id: &UserId) -> bool {
!self.is_deactivated(user_id).await.unwrap_or(true)
/// Checks the activation status of the provided user ID's account.
pub async fn status(&self, user_id: &UserId) -> AccountStatus {
if !self.services.globals.user_is_local(user_id) {
AccountStatus::NotFound
} else if self.db.userid_password.exists(user_id).await.is_ok() {
if self.db.userid_deactivated.exists(user_id).await.is_ok() {
AccountStatus::Deactivated
} else {
AccountStatus::Active
}
} else {
AccountStatus::NotFound
}
}
/// Check if account is a local user, and is active (not deactivated)
pub async fn is_active_local(&self, user_id: &UserId) -> bool {
self.services.globals.user_is_local(user_id) && self.is_active(user_id).await
/// Checks if a user is a shadow.
pub async fn is_shadow(&self, user_id: &UserId) -> bool {
self.db
.userid_password
.get(user_id)
.await
.deserialized::<String>()
.is_ok_and(|hash| hash.is_empty())
}
/// Returns the number of users registered on this server, including
/// deactivated users.
/// deactivated and shadow users.
#[inline]
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
@@ -522,8 +552,8 @@ pub async fn find_from_token(
Some((user_id, device_id, AccessTokenStatus::Valid))
}
/// Returns an iterator over all users on this homeserver.
pub fn stream(&self) -> impl Stream<Item = OwnedUserId> + Send {
/// Returns an iterator over all local users on this homeserver.
pub fn stream_local_users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db.userid_password.keys().ignore_err()
}
@@ -540,16 +570,49 @@ pub fn list_local_users(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ {
}
/// Set a user's password.
pub fn set_password(&self, user_id: &UserId, password: Option<HashedPassword>) {
if let Some(hash) = password {
self.db.userid_password.insert(user_id, hash.0);
} else {
self.db.userid_password.insert(user_id, b"");
pub async fn set_password(&self, user_id: &UserId, password: HashedPassword) -> Result {
self.status(user_id).await.ensure_active()?;
self.db.userid_password.insert(user_id, password.0);
Ok(())
}
/// Convert an existing user to an active local account. This will turn
/// shadow accounts into normal accounts and will clear the deactivation
/// flag on the user if it is set.
pub async fn convert_to_local_account(
&self,
user_id: &UserId,
password: HashedPassword,
) -> Result {
if !self.status(user_id).await.is_found() {
return Err!(Request(NotFound("This account does not exist.")));
}
self.db.userid_deactivated.remove(user_id);
self.db.userid_password.insert(user_id, password.0);
Ok(())
}
/// Convert an existing user to a shadow account. This will clear their
/// password and reactivate them if they are deactivated.
pub async fn convert_to_shadow_account(&self, user_id: &UserId) -> Result {
if !self.status(user_id).await.is_found() {
return Err!(Request(NotFound("This account does not exist.")));
}
self.db.userid_deactivated.remove(user_id);
self.db.userid_password.insert(user_id, "");
Ok(())
}
/// Check a user's password.
pub async fn check_password(&self, user_id: &UserId, password: &str) -> Result<OwnedUserId> {
self.status(user_id).await.ensure_active()?;
let (hash, user_id): (String, OwnedUserId) =
if let Ok(hash) = self.db.userid_password.get(user_id).await.deserialized() {
(hash, user_id.to_owned())
@@ -558,21 +621,20 @@ pub async fn check_password(&self, user_id: &UserId, password: &str) -> Result<O
// better
let lowercase_user_id = UserId::parse(user_id.as_str().to_lowercase()).unwrap();
if let Ok(hash) = self
let hash = self
.db
.userid_password
.get(lowercase_user_id.as_str())
.await
.deserialized()
{
(hash, lowercase_user_id)
} else {
return Err!(Request(Forbidden("This user cannot log in with a password.")));
}
.expect("user should exist");
(hash, lowercase_user_id)
};
if hash.is_empty() {
return Err!(Request(UserDeactivated("This user is deactivated")));
// Cannot log into shadow users
return Err!(Request(Forbidden("This user cannot log in with a password.")));
}
utils::hash::verify_password(password, &hash)
+1 -4
View File
@@ -21,10 +21,7 @@ pub struct DehydratedDevice {
impl super::Service {
/// Creates or recreates the user's dehydrated device.
pub async fn set_dehydrated_device(&self, user_id: &UserId, request: Request) -> Result {
assert!(
self.exists(user_id).await,
"Tried to create dehydrated device for non-existent user"
);
self.status(user_id).await.ensure_active()?;
let existing_id = self.get_dehydrated_device_id(user_id).await;
+12 -21
View File
@@ -4,8 +4,8 @@
};
use conduwuit::{
Err, utils,
utils::{ReadyExt, stream::TryIgnore},
Err, Result,
utils::{self, ReadyExt, stream::TryIgnore},
};
use database::{Deserialized, Ignore, Interfix, Json};
use futures::{Stream, StreamExt};
@@ -18,8 +18,7 @@
use crate::users::increment;
impl super::Service {
/// Adds a new device to a user. The user must exist, otherwise InvalidParam
/// is returned.
/// Adds a new device to a user.
pub async fn create_device(
&self,
user_id: &UserId,
@@ -28,12 +27,8 @@ pub async fn create_device(
token_max_age: Option<Duration>,
initial_device_display_name: Option<String>,
client_ip: Option<String>,
) -> conduwuit::Result<()> {
if !self.exists(user_id).await {
return Err!(Request(InvalidParam(error!(
"Called create_device for non-existent user {user_id}"
))));
}
) -> Result<()> {
self.status(user_id).await.ensure_active()?;
let key = (user_id, device_id);
let mut device = Device::new(device_id.into());
@@ -50,7 +45,7 @@ pub async fn create_device(
/// Removes a device from a user.
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
// Remove dehydrated device if this is the dehydrated device
let _: conduwuit::Result<_> = self
let _ = self
.remove_dehydrated_device(user_id, Some(device_id))
.await;
@@ -97,11 +92,7 @@ pub fn all_device_ids<'a>(
}
/// Gets the access token associated with a device.
pub async fn get_token(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> conduwuit::Result<String> {
pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
let key = (user_id, device_id);
self.db.userdeviceid_token.qry(&key).await.deserialized()
}
@@ -131,7 +122,7 @@ pub async fn set_token(
device_id: &DeviceId,
token: &str,
token_max_age: Option<Duration>,
) -> conduwuit::Result<()> {
) -> Result<()> {
let key = (user_id, device_id);
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
return Err!(Database(error!(
@@ -259,7 +250,7 @@ pub async fn update_device_metadata(
user_id: &UserId,
device_id: &DeviceId,
device: &Device,
) -> conduwuit::Result<()> {
) -> Result<()> {
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.update_device_metadata_no_increment(user_id, device_id, device)
}
@@ -273,7 +264,7 @@ fn update_device_metadata_no_increment(
user_id: &UserId,
device_id: &DeviceId,
device: &Device,
) -> conduwuit::Result<()> {
) -> Result<()> {
let key = (user_id, device_id);
self.db.userdeviceid_metadata.put(key, Json(device));
@@ -312,7 +303,7 @@ pub async fn get_device_metadata(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> conduwuit::Result<Device> {
) -> Result<Device> {
self.db
.userdeviceid_metadata
.qry(&(user_id, device_id))
@@ -321,7 +312,7 @@ pub async fn get_device_metadata(
}
/// Gets the most recent device list version for a user.
pub async fn get_devicelist_version(&self, user_id: &UserId) -> conduwuit::Result<u64> {
pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
self.db
.userid_devicelistversion
.get(user_id)
+14 -2
View File
@@ -4,20 +4,22 @@
pub(super) mod filters;
pub(super) mod keys;
pub(super) mod profile;
pub(super) mod remote;
use std::{mem, sync::Arc};
pub use account::AccessTokenStatus;
pub use account::{AccessTokenStatus, AccountStatus};
use conduwuit::{
Err, Error, Result, err,
utils::{self},
};
use database::Map;
pub use profile::ProfileFieldChange;
use ruma::{UserId, api::error::ErrorKind, encryption::CrossSigningKey, serde::Raw};
use serde::{Deserialize, Serialize};
use crate::{
Dep, account_data, admin, appservice, config, firstrun, globals, oauth,
Dep, account_data, admin, appservice, config, firstrun, globals, oauth, presence,
rooms::{self, alias, membership},
threepid,
};
@@ -61,9 +63,12 @@ struct Services {
globals: Dep<globals::Service>,
membership: Dep<membership::Service>,
oauth: Dep<oauth::Service>,
presence: Dep<presence::Service>,
state: Dep<rooms::state::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
threepid: Dep<threepid::Service>,
timeline: Dep<rooms::timeline::Service>,
}
struct Data {
@@ -75,11 +80,13 @@ struct Data {
logintoken_expiresatuserid: Arc<Map>,
todeviceid_events: Arc<Map>,
token_userdeviceid: Arc<Map>,
remoteuserid_remoteuser: Arc<Map>,
userdeviceid_tokenexpires: Arc<Map>,
userdeviceid_metadata: Arc<Map>,
userdeviceid_token: Arc<Map>,
userfilterid_filter: Arc<Map>,
userid_avatarurl: Arc<Map>,
userid_deactivated: Arc<Map>,
userid_dehydrateddevice: Arc<Map>,
userid_devicelistversion: Arc<Map>,
userid_displayname: Arc<Map>,
@@ -107,10 +114,13 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
globals: args.depend::<globals::Service>("globals"),
membership: args.depend::<membership::Service>("membership"),
oauth: args.depend::<oauth::Service>("oauth"),
presence: args.depend::<presence::Service>("presence"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
threepid: args.depend::<threepid::Service>("threepid"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data {
keychangeid_userid: args.db["keychangeid_userid"].clone(),
@@ -121,10 +131,12 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
logintoken_expiresatuserid: args.db["logintoken_expiresatuserid"].clone(),
todeviceid_events: args.db["todeviceid_events"].clone(),
token_userdeviceid: args.db["token_userdeviceid"].clone(),
remoteuserid_remoteuser: args.db["remoteuserid_remoteuser"].clone(),
userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(),
userdeviceid_token: args.db["userdeviceid_token"].clone(),
userfilterid_filter: args.db["userfilterid_filter"].clone(),
userid_avatarurl: args.db["userid_avatarurl"].clone(),
userid_deactivated: args.db["userid_deactivated"].clone(),
userid_dehydrateddevice: args.db["userid_dehydrateddevice"].clone(),
userid_devicelistversion: args.db["userid_devicelistversion"].clone(),
userid_displayname: args.db["userid_displayname"].clone(),
+269 -14
View File
@@ -1,17 +1,276 @@
use conduwuit::utils::{ReadyExt, stream::TryIgnore};
use std::{borrow::Cow, collections::BTreeMap};
use conduwuit::{
Err, Result,
pdu::PartialPdu,
utils::{ReadyExt, stream::TryIgnore, to_canonical_object},
};
use database::{Deserialized, Ignore, Interfix, Json};
use futures::{Stream, StreamExt};
use ruma::{OwnedMxcUri, UserId};
use ruma::{
OwnedMxcUri, UserId,
api::client::profile::PropagateTo,
events::room::member::MembershipState,
presence::PresenceState,
profile::{ProfileFieldName, ProfileFieldValue},
};
use serde_json::{Value, to_value};
pub enum ProfileFieldChange {
Set(ProfileFieldValue),
Delete(ProfileFieldName),
}
impl ProfileFieldChange {
fn field_name(&self) -> ProfileFieldName {
match self {
| &Self::Delete(ref name) => name.clone(),
| &Self::Set(ref value) => value.field_name(),
}
}
fn value(&self) -> Option<Cow<'_, Value>> {
if let Self::Set(value) = self {
Some(value.value())
} else {
None
}
}
}
impl super::Service {
pub async fn set_profile_field(
&self,
user_id: &UserId,
change: ProfileFieldChange,
propagate_to: PropagateTo,
) -> Result<()> {
const MAX_KEY_LENGTH_BYTES: usize = 255;
const MAX_PROFILE_LENGTH_BYTES: usize = 65536;
let field_name = change.field_name();
// TODO: The spec mentions special error codes (M_PROFILE_TOO_LARGE,
// M_KEY_TOO_LARGE) for profile field size limits, but they're not in its list
// of error codes and Ruma doesn't have them. Should we return those, or is
// M_TOO_LARGE okay?
if field_name.as_str().len() > MAX_KEY_LENGTH_BYTES {
return Err!(Request(TooLarge(
"Individual profile keys must not exceed {MAX_KEY_LENGTH_BYTES} bytes in length."
)));
}
// Serialize the entire profile as canonical JSON, including the new change,
// to check if it exceeds 64 KiB
{
let mut full_profile = self.get_local_profile(user_id).await;
match &change {
| ProfileFieldChange::Set(value) => {
full_profile.insert(
value.field_name().as_str().to_owned(),
value.value().clone().into_owned(),
);
},
| ProfileFieldChange::Delete(key) => {
full_profile.remove(key.as_str());
},
}
if let Ok(canonical_profile) = to_canonical_object(full_profile) {
if serde_json::to_string(&canonical_profile)
.expect("should be able to serialize to string")
.len() > MAX_PROFILE_LENGTH_BYTES
{
return Err!(
"Profile data must not exceed {MAX_PROFILE_LENGTH_BYTES} bytes in \
length."
);
}
} else {
return Err!(Request(BadJson("Failed to canonicalize profile.")));
}
}
// Check if this change would be a no-op
if self
.get_local_profile_field(user_id, field_name.clone())
.await
.is_some_and(|value| Some(value.value()) == change.value())
{
return Ok(());
}
// If the user is local and changed their displayname or avatar_url, update it
// in all their joined rooms. This is done before updating their profile data
// so we can check the old value of the field if `propagate_to` is `unchanged`.
if matches!(field_name, ProfileFieldName::AvatarUrl | ProfileFieldName::DisplayName)
&& matches!(propagate_to, PropagateTo::All | PropagateTo::Unchanged)
&& self.services.globals.user_is_local(user_id)
{
let current_displayname = self.displayname(user_id).await.ok();
let current_avatar_url = self.avatar_url(user_id).await.ok();
let mut all_joined_rooms = self.services.state_cache.rooms_joined(user_id);
while let Some(room_id) = all_joined_rooms.next().await {
// TODO: this clobbers any custom fields on the event content
let mut current_membership = self
.services
.state_accessor
.get_member(&room_id, user_id)
.await
.expect("should be able to fetch membership event for joined room");
assert_eq!(
current_membership.membership,
MembershipState::Join,
"user should be joined"
);
// If `propagate_to` is `unchanged`, and the current value of the field we're
// updating was changed from its global value in this room, skip it.
if matches!(propagate_to, PropagateTo::Unchanged) {
let field_changed_from_global = match field_name {
| ProfileFieldName::AvatarUrl =>
current_membership.avatar_url.as_ref() != current_avatar_url.as_ref(),
| ProfileFieldName::DisplayName =>
current_membership.displayname.as_ref()
!= current_displayname.as_ref(),
| _ => unreachable!(),
};
if field_changed_from_global {
continue;
}
}
let state_lock = self.services.state.mutex.lock(room_id.as_str()).await;
// Preserve keys in accordance with the key copying rules
current_membership.reason = None;
current_membership.join_authorized_via_users_server = None;
match &change {
| ProfileFieldChange::Set(ProfileFieldValue::AvatarUrl(avatar_url)) => {
current_membership.avatar_url = Some(avatar_url.clone());
},
| ProfileFieldChange::Set(ProfileFieldValue::DisplayName(displayname)) => {
current_membership.displayname = Some(displayname.clone());
},
| ProfileFieldChange::Delete(ProfileFieldName::AvatarUrl) => {
current_membership.avatar_url = None;
},
| ProfileFieldChange::Delete(ProfileFieldName::DisplayName) => {
current_membership.displayname = None;
},
| _ => unreachable!(),
}
let _ = self
.services
.timeline
.build_and_append_pdu(
PartialPdu::state(user_id.to_string(), &current_membership),
user_id,
Some(&room_id),
&state_lock,
)
.await;
}
if self.services.config.allow_local_presence {
// Send a presence EDU to indicate the profile changed
let _ = self
.services
.presence
.ping_presence(user_id, &PresenceState::Online)
.await;
}
}
match change {
| ProfileFieldChange::Set(ProfileFieldValue::DisplayName(displayname)) => {
self.set_displayname(user_id, Some(displayname).filter(|dn| !dn.is_empty()));
},
| ProfileFieldChange::Set(ProfileFieldValue::AvatarUrl(avatar_url)) => {
self.set_avatar_url(user_id, Some(avatar_url).filter(|av| av.is_valid()));
},
| ProfileFieldChange::Delete(ProfileFieldName::DisplayName) => {
self.set_displayname(user_id, None);
},
| ProfileFieldChange::Delete(ProfileFieldName::AvatarUrl) => {
self.set_avatar_url(user_id, None);
},
| other => self.set_profile_key(
user_id,
other.field_name().as_str(),
other.value().map(Cow::into_owned),
),
}
Ok(())
}
pub async fn get_local_profile(&self, user_id: &UserId) -> BTreeMap<String, Value> {
let mut profile = BTreeMap::new();
// Get displayname and avatar_url independently because `all_profile_keys`
// doesn't include them
for field in [ProfileFieldName::AvatarUrl, ProfileFieldName::DisplayName] {
let key = field.as_str().to_owned();
if let Some(value) = self.get_local_profile_field(user_id, field).await {
profile.insert(key, value.value().into_owned());
}
}
// Insert all other profile fields
let mut all_fields = self.all_profile_keys(user_id);
while let Some((key, value)) = all_fields.next().await {
profile.insert(key, value);
}
profile
}
pub async fn get_local_profile_field(
&self,
user_id: &UserId,
field: ProfileFieldName,
) -> Option<ProfileFieldValue> {
let value = match field.clone() {
| ProfileFieldName::AvatarUrl => self
.avatar_url(user_id)
.await
.ok()
.map(to_value)
.transpose()
.expect("converting avatar url to value should succeed"),
| ProfileFieldName::DisplayName => self
.displayname(user_id)
.await
.ok()
.map(to_value)
.transpose()
.expect("converting displayname to value should succeed"),
| other => self.profile_key(user_id, other.as_str()).await.ok(),
}?;
Some(
ProfileFieldValue::new(field.as_str(), value)
.expect("local profile field should be valid"),
)
}
/// Returns the displayname of a user on this homeserver.
pub async fn displayname(&self, user_id: &UserId) -> conduwuit::Result<String> {
pub async fn displayname(&self, user_id: &UserId) -> Result<String> {
self.db.userid_displayname.get(user_id).await.deserialized()
}
/// Sets a new displayname or removes it if displayname is None. You still
/// need to notify all rooms of this change.
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
if let Some(displayname) = displayname {
self.db.userid_displayname.insert(user_id, displayname);
} else {
@@ -20,12 +279,12 @@ pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
}
/// Get the `avatar_url` of a user.
pub async fn avatar_url(&self, user_id: &UserId) -> conduwuit::Result<OwnedMxcUri> {
pub async fn avatar_url(&self, user_id: &UserId) -> Result<OwnedMxcUri> {
self.db.userid_avatarurl.get(user_id).await.deserialized()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
match avatar_url {
| Some(avatar_url) => {
self.db.userid_avatarurl.insert(user_id, &avatar_url);
@@ -37,11 +296,7 @@ pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>)
}
/// Gets a specific user profile key
pub async fn profile_key(
&self,
user_id: &UserId,
profile_key: &str,
) -> conduwuit::Result<serde_json::Value> {
pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<Value> {
let key = (user_id, profile_key);
self.db
.useridprofilekey_value
@@ -54,7 +309,7 @@ pub async fn profile_key(
pub fn all_profile_keys<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
) -> impl Stream<Item = (String, Value)> + 'a + Send {
type KeyVal<'a> = ((Ignore, String), &'a [u8]);
let prefix = (user_id, Interfix);
@@ -67,11 +322,11 @@ pub fn all_profile_keys<'a>(
}
/// Sets a new profile key value, removes the key if value is None
pub fn set_profile_key(
fn set_profile_key(
&self,
user_id: &UserId,
profile_key: &str,
profile_key_value: Option<serde_json::Value>,
profile_key_value: Option<Value>,
) {
let key = (user_id, profile_key);
+17
View File
@@ -0,0 +1,17 @@
use conduwuit::utils::stream::TryIgnore;
use futures::Stream;
use ruma::{OwnedUserId, UserId};
impl super::Service {
/// Record the existence of a remote user.
pub fn record_remote_user(&self, user_id: &UserId) {
assert!(!self.services.globals.user_is_local(user_id), "user is not remote");
self.db.remoteuserid_remoteuser.insert(user_id, "");
}
/// Returns a stream over all remote users this server has ever seen.
pub fn stream_remote_users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db.remoteuserid_remoteuser.keys().ignore_err()
}
}
+1
View File
@@ -49,6 +49,7 @@ url.workspace = true
recaptcha-verify = { version = "0.2.0", default-features = false }
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
form_urlencoded = "1.2.2"
openidconnect.workspace = true
[build-dependencies]
memory-serve = "2.1.0"
+1
View File
@@ -133,6 +133,7 @@ pub fn build(services: &Services) -> Router<state::State> {
.nest("/account/", account::build())
.merge(debug::build())
.nest("/oauth2/", oauth::build())
.nest("/oidc/", oidc::build())
.merge(resources::build())
.merge(threepid::build())
.fallback(async || WebError::NotFound),
+49 -17
View File
@@ -7,12 +7,14 @@
routing::{get, on},
};
use conduwuit_api::client::handle_login;
use openidconnect::core::CoreAuthPrompt;
use ruma::{
OwnedUserId,
api::client::uiaa::{EmailUserIdentifier, MatrixUserIdentifier, UserIdentifier},
};
use serde::Deserialize;
use tower_sessions::Session;
use url::Url;
use crate::{
ROUTE_PREFIX, WebError,
@@ -21,9 +23,10 @@
GET_POST, Result, TemplateContext,
account::register::{TrustedFlowStatus, UntrustedFlowStatus, registration_flow_status},
components::UserCard,
oidc::{OIDC_SESSION_ID_KEY, OidcSession, OidcSessionState},
},
response,
session::{LoginQuery, LoginTarget, User, UserSession},
session::{LoginIntent, LoginQuery, LoginTarget, User, UserSession},
template,
};
@@ -36,6 +39,7 @@ pub(crate) fn build() -> Router<crate::State> {
template! {
struct Login use "login.html.j2" {
body: LoginBody,
login_type: LoginType,
login_error: Option<String>
}
}
@@ -44,7 +48,6 @@ struct Login use "login.html.j2" {
enum LoginBody {
Unauthenticated {
server_name: String,
registration_available: bool,
next: Option<LoginTarget>,
},
Authenticated {
@@ -52,6 +55,16 @@ enum LoginBody {
},
}
#[derive(Debug)]
enum LoginType {
Interactive {
registration_available: bool,
},
Oidc {
redirect_url: Url,
},
}
#[derive(Deserialize)]
struct LoginForm {
identifier: Option<String>,
@@ -61,27 +74,46 @@ struct LoginForm {
async fn route_login(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
Expect(Query(LoginQuery { next, reauthenticate })): Expect<Query<LoginQuery>>,
Expect(Query(LoginQuery { next, reauthenticate, intent })): Expect<Query<LoginQuery>>,
session_store: Session,
user: User<true>,
PostForm(form): PostForm<LoginForm>,
) -> Result {
let user_id = user.into_session().map(|session| session.user_id);
let login_type = if services.oidc.enabled() {
let (session, redirect_url) = services
.oidc
.begin_session(match intent {
| Some(LoginIntent::SwitchAccounts) => Some(CoreAuthPrompt::SelectAccount),
| _ if reauthenticate => Some(CoreAuthPrompt::Consent),
| _ => None,
})
.await;
session_store
.insert(OIDC_SESSION_ID_KEY, OidcSession {
next: next.clone().unwrap_or_default(),
state: OidcSessionState::CodeExchange { expected_user: user_id.clone(), session },
})
.await
.expect("should be able to serialize OIDC session");
LoginType::Oidc { redirect_url }
} else {
let (trusted_flow_status, untrusted_flow_status) =
registration_flow_status(&services).await;
let registration_available = matches!(trusted_flow_status, TrustedFlowStatus::Available)
|| matches!(untrusted_flow_status, UntrustedFlowStatus::Available { .. });
LoginType::Interactive { registration_available }
};
let body = match &user_id {
| None => {
let (trusted_flow_status, untrusted_flow_status) =
registration_flow_status(&services).await;
let registration_available =
matches!(trusted_flow_status, TrustedFlowStatus::Available)
|| matches!(untrusted_flow_status, UntrustedFlowStatus::Available { .. });
LoginBody::Unauthenticated {
server_name: services.globals.server_name().to_string(),
registration_available,
next: next.clone(),
}
| None => LoginBody::Unauthenticated {
server_name: services.globals.server_name().to_string(),
next: next.clone(),
},
| Some(user_id) => {
if !reauthenticate {
@@ -94,7 +126,7 @@ async fn route_login(
},
};
let mut template = Login::new(context, body, None);
let mut template = Login::new(context, body, login_type, None);
if let Some(form) = form {
let login_result = match (user_id, form.identifier) {
+2
View File
@@ -66,6 +66,7 @@ struct Account use "account.html.j2" {
enum AccountBody {
Unlocked {
suspended: bool,
oidc_enabled: bool,
email_requirement: EmailRequirement,
email: Option<String>,
devices: Vec<DeviceCard>,
@@ -128,6 +129,7 @@ async fn get_account(
response!(Account::new(context, user_card, AccountBody::Unlocked {
suspended,
oidc_enabled: services.oidc.enabled(),
email_requirement,
email,
devices: device_cards
+13 -6
View File
@@ -4,6 +4,7 @@
use validator::{Validate, ValidationError, ValidationErrors};
use crate::{
WebError,
extract::PostForm,
form,
pages::{
@@ -65,11 +66,17 @@ async fn route_change_password(
user: User,
PostForm(form): PostForm<ChangePasswordForm>,
) -> Result {
if services.oidc.enabled() {
return Err(WebError::BadRequest(
"Password changing is not available on this server".to_owned(),
));
}
let user_id = user.expect(LoginTarget::ChangePassword)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
let body = if let Some(form) = form {
match change_password(&services, &user_id, form).await {
match change_password(&services, &user_id, form).await? {
| Ok(()) => ChangePasswordBody::Success,
| Err(errors) =>
ChangePasswordBody::Form(ChangePasswordForm::with_errors(context.clone(), errors)),
@@ -85,7 +92,7 @@ async fn change_password(
services: &crate::State,
user_id: &UserId,
form: ChangePasswordForm,
) -> Result<(), ValidationErrors> {
) -> Result<Result<(), ValidationErrors>> {
form.validate()?;
if services
@@ -100,12 +107,12 @@ async fn change_password(
ValidationError::new("wrong").with_message("Incorrect password".into()),
);
return Err(errors);
return Ok(Err(errors));
}
match HashedPassword::new(&form.new_password) {
| Ok(hash) => {
services.users.set_password(user_id, Some(hash));
services.users.set_password(user_id, hash).await?;
},
| Err(err) => {
let mut errors = ValidationErrors::new();
@@ -114,9 +121,9 @@ async fn change_password(
ValidationError::new("malformed").with_message(err.message().into()),
);
return Err(errors);
return Ok(Err(errors));
},
}
Ok(())
Ok(Ok(()))
}
+8 -1
View File
@@ -65,6 +65,13 @@ async fn route_reset_password(
return response!(ResetPassword::new(context, ResetPasswordBody::Unavailable));
}
// Check if OIDC is enabled
if services.oidc.enabled() {
return Err(WebError::BadRequest(
"Password resets are not available on this server".to_owned(),
));
}
let Some(form) = form else {
// For GET requests return the reset request form
return response!(ResetPassword::new(
@@ -214,7 +221,7 @@ async fn route_reset_password_validate(
| Ok(hash) => {
let _ = session.consume();
services.users.set_password(&user_id, Some(hash));
services.users.set_password(&user_id, hash).await?;
ResetPasswordValidateBody::ResetSuccess { user_card }
},
+6 -6
View File
@@ -310,7 +310,7 @@ async fn get_register_email_validate(
response!(
complete_registration(&services, session_store, completed_registration, Some(email))
.await
.await?
)
}
@@ -502,7 +502,7 @@ async fn begin_registration(
} else {
// If email isn't required we can immediately complete registration
Ok(response!(
complete_registration(services, session_store, completed_registration, None).await
complete_registration(services, session_store, completed_registration, None).await?
))
}
}
@@ -517,11 +517,11 @@ async fn complete_registration(
next,
}: CompletedRegistration,
email: Option<Address>,
) -> Redirect {
) -> Result<Redirect> {
services
.users
.create_local_account(&user_id, password_hash, email)
.await;
.create_local_account(&user_id, Some(password_hash), email)
.await?;
if let Some(registration_token) = registration_token {
services
@@ -536,7 +536,7 @@ async fn complete_registration(
.await
.expect("should be able to serialize user session");
Redirect::to(&next.unwrap_or_default().target_path())
Ok(Redirect::to(&next.unwrap_or_default().target_path()))
}
pub(super) async fn registration_flow_status(
+1
View File
@@ -17,6 +17,7 @@
pub(super) mod debug;
pub(super) mod index;
pub(super) mod oauth;
pub(super) mod oidc;
pub(super) mod resources;
pub(super) mod threepid;
+3 -2
View File
@@ -17,7 +17,7 @@
components::{Avatar, AvatarType, ClientScopes},
},
response,
session::{LoginQuery, LoginTarget, User},
session::{LoginIntent, LoginQuery, LoginTarget, User},
template,
};
@@ -129,7 +129,8 @@ async fn route_authorization_code(
context,
serde_urlencoded::to_string(LoginQuery {
next: Some(LoginTarget::AuthorizationCode(query)),
reauthenticate: false,
intent: Some(LoginIntent::SwitchAccounts),
..Default::default()
})
.unwrap(),
user_id,
+196
View File
@@ -0,0 +1,196 @@
use std::time::SystemTime;
use axum::{
Extension, Router,
extract::{Query, State},
response::Redirect,
routing::on,
};
use conduwuit_service::{oauth::grant::AuthorizationCodeResponse, oidc::SessionCompletionStatus};
use futures::FutureExt;
use ruma::{OwnedServerName, UserId};
use serde::{Deserialize, de::IgnoredAny};
use tower_sessions::Session;
use crate::{
WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result, TemplateContext,
components::UserCard,
oidc::{OIDC_SESSION_ID_KEY, OidcSession, OidcSessionState},
},
response,
session::{User, UserSession},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/", on(GET_POST, route_complete))
}
template! {
struct OidcComplete use "oidc_complete.html.j2" {
body: OidcCompleteBody
}
}
#[derive(Debug)]
enum OidcCompleteBody {
UsernamePrompt {
server_name: OwnedServerName,
username_error: Option<String>,
},
PasswordPrompt {
username: String,
user_card: UserCard,
password_error: bool,
},
}
#[derive(Deserialize)]
struct LoginForm {
username: String,
password: Option<String>,
}
async fn route_complete(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
Expect(Query(query)): Expect<Query<AuthorizationCodeResponse>>,
session_store: Session,
user: User<true>,
PostForm(form): PostForm<LoginForm>,
) -> Result {
let user_id = user.into_session().map(|session| session.user_id);
let Some(session) = session_store
.get::<OidcSession>(OIDC_SESSION_ID_KEY)
.await
.expect("should be able to deserialize oidc session")
else {
return response!(WebError::BadRequest(
"No OIDC session found. What are you doing here?".to_owned()
));
};
let session_completion_status = match session.state {
| OidcSessionState::CodeExchange { expected_user, session: pending_session } => {
if let (Some(user_id), Some(expected_user)) = (&user_id, &expected_user)
&& user_id != expected_user
{
return response!(WebError::BadRequest(
"Identity mismatch. You may have switched accounts at your identity \
provider. Please log out and back in to continue."
.to_owned()
));
}
let claims = services
.oidc
.exchange_code(pending_session, query)
.boxed()
.await
.map_err(|err| WebError::BadRequest(err.to_owned()))?;
session_store
.insert(OIDC_SESSION_ID_KEY, OidcSession {
next: session.next.clone(),
state: OidcSessionState::Authorized { claims: Box::new(claims.clone()) },
})
.await
.expect("Should be able to serialize oidc session");
services.oidc.complete_session(&claims, None).await
},
| OidcSessionState::Authorized { claims } => {
let supplied_user_id = if let Some(form) = form {
if let Ok(user_id) = UserId::parse(format!(
"@{}:{}",
&form.username,
services.globals.server_name()
)) && services.users.status(&user_id).await.is_active()
{
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if let Some(password) = form.password {
if services
.users
.check_password(&user_id, &password)
.await
.is_ok()
{
Some(user_id)
} else {
return response!(OidcComplete::new(
context,
OidcCompleteBody::PasswordPrompt {
username: form.username,
user_card,
password_error: true
}
));
}
} else {
return response!(OidcComplete::new(
context,
OidcCompleteBody::PasswordPrompt {
username: form.username,
user_card,
password_error: false,
}
));
}
} else {
match services
.users
.determine_registration_user_id(Some(form.username), None, None)
.await
{
| Ok(user_id) => Some(user_id),
| Err(err) => {
return response!(OidcComplete::new(
context,
OidcCompleteBody::UsernamePrompt {
server_name: services.globals.server_name().to_owned(),
username_error: Some(err.message()),
}
));
},
}
}
} else {
None
};
services
.oidc
.complete_session(&claims, supplied_user_id)
.await
},
}
.map_err(|err| WebError::BadRequest(err.to_owned()))?;
match session_completion_status {
| SessionCompletionStatus::Complete(user_id) => {
let _ = session_store
.remove::<IgnoredAny>(OIDC_SESSION_ID_KEY)
.await;
let user_session = UserSession { user_id, last_login: SystemTime::now() };
session_store
.insert(User::KEY, user_session)
.await
.expect("should be able to serialize user session");
response!(Redirect::to(&session.next.target_path()))
},
| SessionCompletionStatus::NeedsUserId => {
response!(OidcComplete::new(context, OidcCompleteBody::UsernamePrompt {
server_name: services.globals.server_name().to_owned(),
username_error: None
}))
},
}
}
+34
View File
@@ -0,0 +1,34 @@
use axum::Router;
use conduwuit_service::oidc::{self, Claims};
use ruma::OwnedUserId;
use serde::{Deserialize, Serialize};
use crate::session::LoginTarget;
mod complete;
pub(crate) const OIDC_SESSION_ID_KEY: &str = "oidc_session";
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct OidcSession {
pub next: LoginTarget,
pub state: OidcSessionState,
}
#[derive(Debug, Deserialize, Serialize)]
pub(crate) enum OidcSessionState {
CodeExchange {
expected_user: Option<OwnedUserId>,
session: oidc::PendingSession,
},
Authorized {
claims: Box<Claims>,
},
}
pub(crate) fn build() -> Router<crate::State> {
#[allow(clippy::wildcard_imports)]
use self::*;
Router::new().nest("/complete", complete::build())
}
+1
View File
@@ -33,6 +33,7 @@
padding: 0.5em;
margin-bottom: 0.5em;
line-height: 1;
align-items: baseline;
border-radius: var(--border-radius-sm);
border: 2px solid var(--secondary);
+6 -4
View File
@@ -9,7 +9,7 @@ Your account
<h1>Manage your account</h1>
{{ user_card }}
{% match body %}
{% when AccountBody::Unlocked { suspended, email_requirement, email, devices } %}
{% when AccountBody::Unlocked { suspended, email_requirement, email, devices, oidc_enabled } %}
{% if suspended %}
<p class="card danger">
⚠️ Your account has been suspended by your homeserver's administrator.
@@ -27,9 +27,11 @@ Your account
<a href="email/change/">Change your email</a>
</p>
{% endif %}
<p>
<a href="password/change">Change your password</a>
</p>
{% if !oidc_enabled %}
<p>
<a href="password/change">Change your password</a>
</p>
{% endif %}
</section>
<section>
+46 -31
View File
@@ -9,9 +9,14 @@ Log in
{%- endblock -%}
{%- block content -%}
{% match login_type %}
{% when LoginType::Interactive { .. } %}
<div class="panel narrow">
{% when LoginType::Oidc { .. } %}
<div class="panel narrow middle"/>
{% endmatch %}
{% match body %}
{% when LoginBody::Unauthenticated { server_name, registration_available, next } %}
{% when LoginBody::Unauthenticated { server_name, next } %}
<h1 class="with-matrix-icon">
{% if next.is_some() %}
Log in to continue
@@ -25,39 +30,49 @@ Log in
<p>
You're about to log in to your account on <em>{{ server_name }}</em>
</p>
<hr>
<form method="post">
<p>
<label for="identifier">Username or email address</label>
<input type="text" id="identifier" name="identifier" autocomplete="username">
</p>
<p>
<label for="password">Password</label>
<input type="password" id="password" name="password" autocomplete="current-password">
</p>
<button type="submit">Log in</button>
</form>
<div class="centered-links">
{% if registration_available %}
{% let query = next.as_ref().map(serde_urlencoded::to_string).transpose().unwrap().unwrap_or_default() %}
<a href="{{ crate::ROUTE_PREFIX }}/account/register/?{{ query }}">Sign up</a>
{% endif %}
<a href="{{ crate::ROUTE_PREFIX }}/account/password/reset/">Forgot your password?</a>
</div>
{% match login_type %}
{% when LoginType::Interactive { registration_available } %}
<hr>
<form method="post">
<p>
<label for="identifier">Username or email address</label>
<input type="text" id="identifier" name="identifier" autocomplete="username">
</p>
<p>
<label for="password">Password</label>
<input type="password" id="password" name="password" autocomplete="current-password">
</p>
<button type="submit">Log in</button>
</form>
<div class="centered-links">
{% if registration_available %}
{% let query = next.as_ref().map(serde_urlencoded::to_string).transpose().unwrap().unwrap_or_default() %}
<a href="{{ crate::ROUTE_PREFIX }}/account/register/?{{ query }}">Sign up</a>
{% endif %}
<a href="{{ crate::ROUTE_PREFIX }}/account/password/reset/">Forgot your password?</a>
</div>
{% when LoginType::Oidc { redirect_url } %}
<a class="button" href="{{ redirect_url }}">Continue</a>
{% endmatch %}
{% when LoginBody::Authenticated { user_card } %}
<h1>Confirm your identity</h1>
{{ user_card }}
<p>Enter your password to continue.</p>
<form method="post">
<p>
<label for="password">Password</label>
<input type="password" id="password" name="password" autocomplete="current-password">
</p>
<button type="submit">Continue</button>
</form>
<div class="centered-links">
<a href="{{ crate::ROUTE_PREFIX }}/account/password/reset/">Forgot your password?</a>
</div>
{% match login_type %}
{% when LoginType::Interactive { .. } %}
<p>Enter your password to continue.</p>
<form method="post">
<p>
<label for="password">Password</label>
<input type="password" id="password" name="password" autocomplete="current-password">
</p>
<button type="submit">Continue</button>
</form>
<div class="centered-links">
<a href="{{ crate::ROUTE_PREFIX }}/account/password/reset/">Forgot your password?</a>
</div>
{% when LoginType::Oidc { redirect_url } %}
<a class="button" href="{{ redirect_url }}">Continue</a>
{% endmatch %}
{% endmatch %}
{% if let Some(error) = login_error %}
<small class="error">{{ error }}</small>
@@ -0,0 +1,59 @@
{% extends "_layout.html.j2" %}
{% import "_components/form.html.j2" as form %}
{%- block head -%}
<link rel="stylesheet" href="{{ crate::ROUTE_PREFIX }}/resources/login.css">
{%- endblock -%}
{%- block title -%}
Link your account
{%- endblock -%}
{%- block content -%}
<div class="panel narrow">
<h1 class="with-matrix-icon">
Link your account
<a href="https://matrix.org" target="_blank" noreferer>
<img class="matrix-icon" alt="Matrix logo" aria-ignore src="{{ crate::ROUTE_PREFIX }}/resources/matrix-icon.svg">
</a>
</h1>
{% match body %}
{% when OidcCompleteBody::UsernamePrompt { server_name, username_error } %}
<form method="post">
<p>
To finish linking your account to Matrix, choose a username.
<br>If you have an existing Matrix account, enter its user ID to link it.
</p>
<p>
<label for="username">Username</label>
<span class="username-input">
<span>@</span>
<input type="text" id="username" name="username" autocomplete="username" required>
<span>:{{ server_name }}</span>
</span>
{% if let Some(username_error) = username_error %}
<small class="error">
{{ username_error }}
</small>
{% endif %}
<small>Your username cannot be changed after you link your account.</small>
</p>
<button type="submit">Continue</button>
</form>
{% when OidcCompleteBody::PasswordPrompt { username, user_card, password_error } %}
{{ user_card }}
<form method="post">
<p>To link this legacy account, enter the password you used to use when logging in.</p>
<p>
<label for="password">Password</label>
<input type="password" id="password" name="password" autocomplete="current-password">
{% if password_error %}
<small class="error">Incorrect password</small>
{% endif %}
</p>
<input type="hidden" name="username" value="{{ username }}">
<button type="submit">Continue</button>
</form>
{% endmatch %}
</div>
{% endblock %}
+13 -6
View File
@@ -24,6 +24,8 @@ pub(crate) struct LoginQuery {
pub next: Option<LoginTarget>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub reauthenticate: bool,
#[serde(default)]
pub intent: Option<LoginIntent>,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
@@ -65,6 +67,12 @@ pub(crate) fn target_path(&self) -> String {
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum LoginIntent {
SwitchAccounts,
}
/// An extractor that fetches the authenticated user.
pub(crate) struct User<const ALLOW_LOCKED: bool = false>(Option<UserSession>);
@@ -107,7 +115,7 @@ pub(crate) fn expect(self, or_else: LoginTarget) -> Result<OwnedUserId, WebError
} else {
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: false,
..Default::default()
}))
}
}
@@ -122,12 +130,13 @@ pub(crate) fn expect_recent(self, or_else: LoginTarget) -> Result<OwnedUserId, W
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: true,
..Default::default()
}))
}
} else {
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: false,
..Default::default()
}))
}
}
@@ -162,10 +171,8 @@ pub(crate) async fn require_active(
user_id: &UserId,
allow_locked: bool,
) -> Result<(), Response> {
if !services.users.is_active(user_id).await {
return Err(
WebError::Forbidden("Your account is deactivated.".to_owned()).into_response()
);
if let Err(err) = services.users.status(user_id).await.ensure_active() {
return Err(WebError::Forbidden(err.message()).into_response());
}
if !allow_locked