Compare commits

...

13 Commits

Author SHA1 Message Date
Jacob Taylor b89c6bc2ef fix: Re-Add query mutual rooms advertisement 2026-06-27 06:43:55 -07:00
Jacob Taylor 7a0f710843 feat: Merge nex/feat/advertise-spec118 2026-06-27 06:42:35 -07:00
timedout ca8147db1b perf: Throttle dummy events to prevent stampeding 2026-06-27 06:42:35 -07:00
Jacob Taylor c3d193fb38 feat: More vibrant communication 2026-06-27 06:42:35 -07:00
Jacob Taylor 208d156261 perf: Merge nex/perf/get-missing-events 2026-06-27 06:42:35 -07:00
Jacob Taylor 0eee5089ea feat: Merge ginger/oauth 2026-06-27 06:42:35 -07:00
Jacob Taylor 14a288cc0a fix: Delete silly 2026-06-27 06:42:35 -07:00
Jacob Taylor d7b5938fd9 fix: Pre-Commit Lint Compliance Maneuver 2026-06-27 06:42:35 -07:00
Jacob Taylor 2e2b725f59 feat: Bump one cache a bit 2026-06-27 06:42:35 -07:00
Jacob Taylor fb51bd3cf3 upgrade some logs to info 2026-06-27 06:42:35 -07:00
Jacob Taylor 6bf6d4eaff exponential backoff is now just bees. did you want bees? no? well you have them now. congrats 2026-06-27 06:42:35 -07:00
Jacob Taylor 75b1aec33f enable converged 6g at the edge in continuwuity
sender_workers scaling. this time, with feeling!
2026-06-27 06:42:35 -07:00
Jacob Taylor bb04a81394 bump the number of allowed immutable memtables by 1, to allow for greater flood protection
this should probably not be applied if you have rocksdb_atomic_flush = false (the default)
2026-06-27 06:42:35 -07:00
149 changed files with 8136 additions and 2008 deletions
+1 -1
View File
@@ -51,7 +51,7 @@ repos:
hooks:
- id: cargo-clippy
name: cargo clippy
entry: cargo clippy -- -D warnings
entry: cargo clippy --
language: system
pass_filenames: false
types: [rust]
Generated
+69
View File
@@ -1088,6 +1088,7 @@ dependencies = [
"serde",
"serde-saphyr",
"serde_json",
"serde_urlencoded",
"sha2 0.11.0",
"termimad",
"tokio",
@@ -1107,18 +1108,29 @@ dependencies = [
"axum",
"axum-extra",
"base64 0.22.1",
"conduwuit_api",
"conduwuit_build_metadata",
"conduwuit_core",
"conduwuit_database",
"conduwuit_service",
"form_urlencoded",
"futures",
"lettre",
"memory-serve",
"rand 0.10.1",
"recaptcha-verify",
"reqwest 0.12.28",
"ruma",
"serde",
"serde_json",
"serde_urlencoded",
"thiserror",
"tower-http 0.7.0",
"tower-sec-fetch",
"tower-sessions",
"tower-sessions-core",
"tracing",
"url",
"validator",
]
@@ -1534,6 +1546,9 @@ name = "deranged"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c"
dependencies = [
"serde_core",
]
[[package]]
name = "derive_more"
@@ -5616,6 +5631,22 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-cookies"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "151b5a3e3c45df17466454bb74e9ecedecc955269bdedbf4d150dfa393b55a36"
dependencies = [
"axum-core",
"cookie",
"futures-util",
"http",
"parking_lot",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-http"
version = "0.6.11"
@@ -5687,6 +5718,44 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tower-sessions"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "518dca34b74a17cadfcee06e616a09d2bd0c3984eff1769e1e76d58df978fc78"
dependencies = [
"async-trait",
"http",
"time",
"tokio",
"tower-cookies",
"tower-layer",
"tower-service",
"tower-sessions-core",
"tracing",
]
[[package]]
name = "tower-sessions-core"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "568531ec3dfcf3ffe493de1958ae5662a0284ac5d767476ecdb6a34ff8c6b06c"
dependencies = [
"async-trait",
"axum-core",
"base64 0.22.1",
"futures",
"http",
"parking_lot",
"rand 0.9.4",
"serde",
"serde_json",
"thiserror",
"time",
"tokio",
"tracing",
]
[[package]]
name = "tracing"
version = "0.1.44"
+3
View File
@@ -559,6 +559,9 @@ features = ["std"]
[workspace.dependencies.nonzero_ext]
version = "0.3.0"
[workspace.dependencies.serde_urlencoded]
version = "0.7.1"
#
# Patches
#
+1
View File
@@ -0,0 +1 @@
Users may now be forbidden from deactivating their own accounts with the new `allow_deactivation` config option. Contributed by @ginger.
+1
View File
@@ -0,0 +1 @@
Added support for authenticating clients using the new OAuth 2.0 login API. Contributed by @ginger.
+2
View File
@@ -0,0 +1,2 @@
Improved the performance and reliability of fetching missing events, improving network partition recovery. Contributed
by @nex.
+48 -14
View File
@@ -297,7 +297,7 @@
# This item is undocumented. Please contribute documentation for it.
#
#max_fetch_prev_events = 192
#max_fetch_prev_events = 1024
# How many incoming federation transactions the server is willing to be
# processing at any given time before it becomes overloaded and starts
@@ -521,17 +521,15 @@
#
#recaptcha_private_site_key =
# Policy documents, such as terms and conditions or a privacy policy,
# which users must agree to when registering an account.
# Controls whether users are allowed to deactivate their own accounts
# through the account management panel or their Matrix clients. Server
# admins can always deactivate users using the relevant admin commands.
#
# Example:
# ```ignore
# [global.registration_terms.privacy_policy]
# en = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
# es = { name = "Política de Privacidad", url = "https://homeserver.example/es/privacy_policy.html" }
# ```
# Note that, in some jurisdictions, you may be legally required to honor
# users who request to deactivate their accounts if you set this option
# to `false`.
#
#registration_terms = {}
#allow_deactivation = true
# Controls whether encrypted rooms and events are allowed.
#
@@ -645,6 +643,14 @@
#
#default_room_acl_deny =
# The number of forward extremities to tolerate in a room before
# attempting to manually squash them with a "dummy event". Setting this
# above 20 will hinder its efficacy, and setting it below 5 will cause
# more dummy events to be sent than necessary (which increases federation
# traffic).
#
#dummy_event_threshold = 10
# Enable OpenTelemetry OTLP tracing export. This replaces the deprecated
# Jaeger exporter. Traces will be sent via OTLP to a collector (such as
# Jaeger) that supports the OpenTelemetry Protocol.
@@ -1795,11 +1801,9 @@
#stream_amplification = 1024
# Number of sender task workers; determines sender parallelism. Default is
# '0' which means the value is determined internally, likely matching the
# number of tokio worker-threads or number of cores, etc. Override by
# setting a non-zero value.
# core count. Override by setting a different value.
#
#sender_workers = 0
#sender_workers = core count
# Enables listener sockets; can be set to false to disable listening. This
# option is intended for developer/diagnostic purposes only.
@@ -1987,3 +1991,33 @@
# `require_email_for_registration`.
#
#require_email_for_token_registration = false
#[global.registration_terms]
# The language code to provide to clients along with the policy documents.
#
#language = "en"
# Policy documents, such as terms and conditions or a privacy policy,
# which users must agree to when registering an account.
#
# Example:
# ```ignore
# [global.registration_terms.documents]
# privacy_policy = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
# ```
#
#documents =
#[global.oauth]
# The compatibility mode to use for OAuth.
#
# - "disabled": OAuth will be unavailable. Users will only be able to log
# in using legacy authentication.
# - "hybrid": OAuth and legacy authentication will both be available. Some
# clients may only use one or the other.
# - "exclusive": Only OAuth will be available. Clients which require
# legacy authentication will be unable to log in.
#
#compatibility_mode = "hybrid"
+1 -1
View File
@@ -16,7 +16,7 @@
};
#[derive(Debug, Parser)]
#[command(name = conduwuit_core::name(), version = conduwuit_core::version())]
#[command(name = conduwuit_core::BRANDING, version = conduwuit_core::version())]
pub enum AdminCommand {
#[command(subcommand)]
/// Commands for managing appservices
+51 -1
View File
@@ -31,7 +31,7 @@
};
use tracing_subscriber::EnvFilter;
use crate::admin_command;
use crate::{PAGE_SIZE, admin_command};
#[admin_command]
pub(super) async fn echo(&self, message: Vec<String>) -> Result {
@@ -1176,3 +1176,53 @@ pub(super) async fn send_test_email(&self) -> Result {
Ok(())
}
#[admin_command]
pub(super) async fn rooms_by_extremity_count(&self, page: Option<usize>) -> Result {
let page = page.unwrap_or(1);
// My Giant Chain:tm:
let mapped: HashMap<OwnedRoomId, u64> = self
.services
.rooms
.state
.all_forward_extremities()
.ready_fold(HashMap::new(), move |mut map, (room_id, _)| {
let count: u64 = map.get(&room_id).copied().unwrap_or(0);
map.insert(room_id, count.saturating_add(1));
map
})
.await
.into_iter()
.filter_map(|(room_id, count)| (count >= 2).then_some((room_id, count)))
.collect();
if mapped.is_empty() {
return Err!("No more rooms.");
}
let mut rooms = mapped.keys().collect::<Vec<_>>();
rooms.sort_by_key(|room_id| {
mapped
.get(*room_id)
.copied()
.expect("keys must have values")
});
rooms.reverse();
let body = rooms
.into_iter()
.stream()
.skip(page.saturating_sub(1).saturating_mul(PAGE_SIZE))
.take(PAGE_SIZE)
.map(|room_id| {
format!("{room_id}: {}", mapped.get(room_id).copied().expect("keys must have values"))
})
.collect::<Vec<_>>()
.await;
self.write_str(&format!(
"Rooms by extremity count ({}):\n```\n{}\n```",
body.len(),
body.join("\n")
))
.await
}
+5
View File
@@ -245,6 +245,11 @@ pub enum DebugCommand {
/// Send a test email to the invoking admin's email address
SendTestEmail,
/// Lists room IDs by forward extremity count in descending order
RoomsByExtremityCount {
page: Option<usize>,
},
/// Developer test stubs
#[command(subcommand)]
#[allow(non_snake_case)]
+25 -2
View File
@@ -30,14 +30,37 @@ pub(super) async fn issue_token(&self, expires: super::TokenExpires) -> Result {
.issue_token(self.sender_or_service_user().into(), expires);
self.write_str(&format!(
"New registration token issued: `{token}`. {}.",
"New registration token issued: `{token}` . {}.",
if let Some(expires) = info.expires {
format!("{expires}")
} else {
"Never expires".to_owned()
}
))
.await
.await?;
if self
.services
.config
.oauth
.compatibility_mode
.oauth_available()
{
self.write_str(&format!(
"\nInvite link using this token: {}",
self.services
.config
.get_client_domain()
.join(&format!(
"{}/account/register/?flow=trusted&token={token}",
conduwuit::ROUTE_PREFIX
))
.unwrap()
))
.await?;
}
Ok(())
}
#[admin_command]
+13 -149
View File
@@ -1,13 +1,10 @@
use std::{
collections::{BTreeMap, HashSet},
fmt::Write as _,
};
use std::collections::{BTreeMap, HashSet};
use api::client::{
full_user_deactivate, leave_room, recreate_push_rules_and_return, remote_leave_room,
};
use conduwuit::{
Err, Result, debug_warn, error, info,
Err, Result, debug_warn, info,
matrix::{Event, pdu::PartialPdu},
utils::{self, ReadyExt},
warn,
@@ -53,130 +50,22 @@ pub(super) async fn list_users(&self) -> Result {
#[admin_command]
pub(super) async fn create_user(&self, username: String, password: Option<String>) -> Result {
// Validate user id
let user_id = parse_local_user_id(self.services, &username)?;
if let Err(e) = user_id.validate_strict() {
if self.services.config.emergency_password.is_none() {
return Err!("Username {user_id} contains disallowed characters or spaces: {e}");
}
}
if self.services.users.exists(&user_id).await {
return Err!("User {user_id} already exists");
}
let password = password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH));
// Create user
self.services
.users
.create(&user_id, Some(HashedPassword::new(&password)?))
.await?;
// Default to pretty displayname
let mut displayname = user_id.localpart().to_owned();
// If `new_user_displayname_suffix` is set, registration will push whatever
// content is set to the user's display name with a space before it
if !self
let user_id = self
.services
.server
.config
.new_user_displayname_suffix
.is_empty()
{
write!(displayname, " {}", self.services.server.config.new_user_displayname_suffix)?;
}
.users
.determine_registration_user_id(Some(username), None, None)
.await?;
let password = HashedPassword::new(
&password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)),
)?;
self.services
.users
.set_displayname(&user_id, Some(displayname));
.create_local_account(&user_id, password, None)
.await;
// Initial account data
self.services
.account_data
.update(
None,
&user_id,
ruma::events::GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent::new(
ruma::events::push_rules::PushRulesEventContent::new(
ruma::push::Ruleset::server_default(&user_id),
),
))
.unwrap(),
)
.await?;
if !self.services.server.config.auto_join_rooms.is_empty() {
for room in &self.services.server.config.auto_join_rooms {
let Ok(room_id) = self.services.rooms.alias.resolve(room).await else {
error!(
%user_id,
"Failed to resolve room alias to room ID when attempting to auto join {room}, skipping"
);
continue;
};
if !self
.services
.rooms
.state_cache
.server_in_room(self.services.globals.server_name(), &room_id)
.await
{
warn!(
"Skipping room {room} to automatically join as we have never joined before."
);
continue;
}
if let Some(room_server_name) = room.server_name() {
match self
.services
.rooms
.membership
.join_room(
&user_id,
&room_id,
Some("Automatically joining this room upon registration".to_owned()),
&[
self.services.globals.server_name().to_owned(),
room_server_name.to_owned(),
],
)
.await
{
| Ok(_response) => {
info!("Automatically joined room {room} for user {user_id}");
},
| Err(e) => {
// don't return this error so we don't fail registrations
error!(
"Failed to automatically join room {room} for user {user_id}: {e}"
);
self.services
.admin
.send_text(&format!(
"Failed to automatically join room {room} for user {user_id}: \
{e}"
))
.await;
},
}
}
}
}
// we dont add a device since we're not the user, just the creator
// Make the first user to register an administrator and disable first-run mode.
self.services.firstrun.empower_first_user(&user_id).await?;
self.write_str(&format!("Created user with user_id: {user_id} and password: `{password}`"))
.await
self.write_str(&format!("Created user {user_id}")).await
}
#[admin_command]
@@ -302,31 +191,6 @@ pub(super) async fn reset_password(
Ok(())
}
#[admin_command]
pub(super) async fn issue_password_reset_link(&self, username: String) -> Result {
use conduwuit_service::password_reset::{PASSWORD_RESET_PATH, RESET_TOKEN_QUERY_PARAM};
self.bail_restricted()?;
let mut reset_url = self
.services
.config
.get_client_domain()
.join(PASSWORD_RESET_PATH)
.unwrap();
let user_id = parse_local_user_id(self.services, &username)?;
let token = self.services.password_reset.issue_token(user_id).await?;
reset_url
.query_pairs_mut()
.append_pair(RESET_TOKEN_QUERY_PARAM, &token.token);
self.write_str(&format!("Password reset link issued for {username}: {reset_url}"))
.await?;
Ok(())
}
#[admin_command]
pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> Result {
if self.body.len() < 2
-6
View File
@@ -29,12 +29,6 @@ pub enum UserCommand {
password: Option<String>,
},
/// Issue a self-service password reset link for a user.
IssuePasswordResetLink {
/// Username of the user who may use the link
username: String,
},
/// Get a user's associated email address.
GetEmail {
user_id: String,
+28 -45
View File
@@ -24,7 +24,7 @@
power_levels::RoomPowerLevelsEventContent,
},
};
use service::{mailer::messages, uiaa::Identity, users::HashedPassword};
use service::{mailer::messages, uiaa::UiaaInitiator, users::HashedPassword};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{Ruma, router::ClientIdentity};
@@ -49,39 +49,16 @@ pub(crate) async fn get_register_available_route(
ClientIp(client): ClientIp,
body: Ruma<get_username_availability::v3::Request>,
) -> Result<get_username_availability::v3::Response> {
// Validate user id
let user_id =
match UserId::parse_with_server_name(&body.username, services.globals.server_name()) {
| Ok(user_id) => {
if let Err(e) = user_id.validate_strict() {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {} contains disallowed characters or spaces: {e}",
body.username
))));
}
user_id
},
| Err(e) => {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {} is not valid: {e}",
body.username
))));
},
};
// Check if username is creative enough
if services.users.exists(&user_id).await {
return Err!(Request(UserInUse("User ID is not available.")));
}
if let Some(ClientIdentity::Appservice { appservice_info, .. }) = &body.identity
&& !appservice_info.is_user_match(&user_id)
{
return Err!(Request(Exclusive("Username is not in an appservice namespace.")));
} else if services.appservice.is_exclusive_user_id(&user_id).await {
return Err!(Request(Exclusive("Username is reserved by an appservice.")));
}
let _ = services
.users
.determine_registration_user_id(
Some(body.username.clone()),
None,
body.identity
.as_ref()
.and_then(ClientIdentity::appservice_info),
)
.await?;
Ok(get_username_availability::v3::Response::new(true))
}
@@ -109,12 +86,7 @@ pub(crate) async fn change_password_route(
ClientIp(client): ClientIp,
body: Ruma<change_password::v3::Request>,
) -> Result<change_password::v3::Response> {
let identity = if let Some(user_id) = body
.identity
.as_ref()
.map(ClientIdentity::expect_sender_user)
.transpose()?
{
let identity = if let Some(identity) = body.identity.as_ref() {
// A signed-in user is trying to change their password, prompt them for their
// existing one
@@ -124,7 +96,10 @@ pub(crate) async fn change_password_route(
&body.auth,
vec![AuthFlow::new(vec![AuthType::Password])],
Box::default(),
Some(Identity::from_user_id(user_id)),
Some(UiaaInitiator::new(
identity.expect_sender_user()?,
identity.sender_device(),
)),
)
.await?
} else {
@@ -280,16 +255,24 @@ pub(crate) async fn deactivate_route(
) -> Result<deactivate::v3::Response> {
// Authentication for this endpoint is technically optional,
// but we require the user to be logged in
let sender_user = body
let identity = body
.identity
.as_ref()
.map(ClientIdentity::expect_sender_user)
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))??;
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?;
let sender_user = identity.expect_sender_user()?;
if !services.config.allow_deactivation {
return Err!(Request(Unauthorized(
"You may not deactivate your own account. Contact your server's administrator for \
assistance."
)));
}
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, identity.sender_device(), None)
.await?;
// Remove profile pictures and display name
+56 -292
View File
@@ -1,17 +1,15 @@
use std::{collections::HashMap, fmt::Write};
use std::collections::HashMap;
use axum::extract::State;
use axum_client_ip::ClientIp;
use conduwuit::{
Err, Result, debug_info, error, info,
Err, Result, debug_info, info,
utils::{self},
warn,
};
use conduwuit_service::Services;
use futures::{FutureExt, StreamExt};
use futures::StreamExt;
use lettre::{Address, message::Mailbox};
use ruma::{
OwnedUserId, UserId,
api::client::{
account::{
register::{self, LoginType, RegistrationKind},
@@ -20,11 +18,6 @@
uiaa::{AuthFlow, AuthType},
},
assign,
events::{
GlobalAccountDataEventType, push_rules::PushRulesEvent,
room::message::RoomMessageEventContent,
},
push,
};
use serde_json::value::RawValue;
use service::{mailer::messages, users::HashedPassword};
@@ -32,8 +25,6 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::Ruma;
const RANDOM_USER_ID_LENGTH: usize = 10;
/// # `POST /_matrix/client/v3/register`
///
/// Register an account on this homeserver.
@@ -52,8 +43,6 @@ pub(crate) async fn register_route(
return Err!(Request(GuestAccessForbidden("Guests may not register on this server.")));
}
let emergency_mode_enabled = services.config.emergency_password.is_some();
// Allow registration if it's enabled in the config file or if this is the first
// run (so the first user account can be created)
let allow_registration =
@@ -71,99 +60,59 @@ pub(crate) async fn register_route(
)));
}
let identity = if body.identity.is_some() {
// Appservices can skip auth
None
let user_id = if body.body.login_type == Some(LoginType::ApplicationService) {
let Some(appservice_info) = &body.identity else {
return Err!(Request(Forbidden(
"Only appservices can use the appservice login type."
)));
};
let user_id = services
.users
.determine_registration_user_id(body.username.clone(), None, Some(appservice_info))
.await?;
services.users.create(&user_id, None).await?;
user_id
} else {
// Perform UIAA to determine the user's identity
let (flows, params) = create_registration_uiaa_session(&services).await?;
Some(
services
.uiaa
.authenticate(&body.auth, flows, params, None)
.await?,
)
};
// If the user didn't supply a username but did supply an email, use
// the email's user as their initial localpart to avoid falling back to
// a randomly generated localpart
let supplied_username = body.username.clone().or_else(|| {
if let Some(identity) = &identity
&& let Some(email) = &identity.email
{
Some(email.user().to_owned())
} else {
None
}
});
let user_id =
determine_registration_user_id(&services, supplied_username, emergency_mode_enabled)
let identity = services
.uiaa
.authenticate(&body.auth, flows, params, None)
.await?;
if body.body.login_type == Some(LoginType::ApplicationService) {
// For appservice logins, make sure that the user ID is in the appservice's
// namespace
let password = if let Some(password) = &body.password {
HashedPassword::new(password)?
} else {
return Err!(Request(InvalidParam("A password must be provided.")));
};
match body.identity {
| Some(ref info) =>
if !info.is_user_match(&user_id) && !emergency_mode_enabled {
return Err!(Request(Exclusive(
"Username is not in an appservice namespace."
)));
},
| _ => {
return Err!(Request(MissingToken("Missing appservice token.")));
},
}
} else if services.appservice.is_exclusive_user_id(&user_id).await && !emergency_mode_enabled
{
// For non-appservice logins, ban user IDs which are in an appservice's
// namespace (unless emergency mode is enabled)
return Err!(Request(Exclusive("Username is reserved by an appservice.")));
}
let user_id = services
.users
.determine_registration_user_id(body.username.clone(), identity.email.as_ref(), None)
.await?;
let password = if body.identity.is_some() {
None
} else if let Some(password) = body.password.as_deref() {
Some(HashedPassword::new(password)?)
} else {
return Err!(Request(InvalidParam("A password must be provided")));
services
.users
.create_local_account(&user_id, password, identity.email)
.await;
user_id
};
// Create user
services.users.create(&user_id, password).await?;
// Set an initial display name
let mut displayname = user_id.localpart().to_owned();
// Apply the new user displayname suffix, if it's set
if !services.globals.new_user_displayname_suffix().is_empty() && body.identity.is_none() {
write!(displayname, " {}", services.server.config.new_user_displayname_suffix)?;
}
services
.users
.set_displayname(&user_id, Some(displayname.clone()));
// Initial account data
services
.account_data
.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(PushRulesEvent::new(
push::Ruleset::server_default(&user_id).into(),
))
.expect("should be able to serialize push rules"),
)
.await?;
// Generate new device id if the user didn't specify one
let (token, device) = if !body.inhibit_login {
// If UIAA is disabled, we can't create a device. In that case only appservices
// can reach this point in the first place, so we return an error for them.
if !services.config.oauth.compatibility_mode.uiaa_available() {
return Err!(Request(AppserviceLoginUnsupported(
"User-interactive appservice registration is not available on this server."
)));
}
// Generate new device id if the user didn't specify one
let device_id = body
.device_id
.clone()
@@ -179,6 +128,7 @@ pub(crate) async fn register_route(
&user_id,
&device_id,
&new_token,
None,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
@@ -189,118 +139,7 @@ pub(crate) async fn register_route(
(None, None)
};
debug_info!(%user_id, ?device, "User account was created");
// If the user registered with an email, associate it with their account.
if let Some(identity) = identity
&& let Some(email) = identity.email
{
// This may fail if the email is already in use, but we already check for that
// in `/requestToken`, so ignoring the error is acceptable here in the rare case
// that an email is sniped by another user between the `/requestToken` request
// and the `/register` request.
let _ = services
.threepid
.associate_localpart_email(user_id.localpart(), &email)
.await;
}
let device_display_name = body.initial_device_display_name.as_deref().unwrap_or("");
if body.identity.is_none() {
if !device_display_name.is_empty() {
let notice = format!(
"New user \"{user_id}\" registered on this server from IP {client} and device \
display name \"{device_display_name}\""
);
info!("{notice}");
if services.server.config.admin_room_notices {
services.admin.notice(&notice).await;
}
} else {
let notice = format!("New user \"{user_id}\" registered on this server.");
info!("{notice}");
if services.server.config.admin_room_notices {
services.admin.notice(&notice).await;
}
}
}
// Make the first user to register an administrator and disable first-run mode.
let was_first_user = services.firstrun.empower_first_user(&user_id).await?;
// If the registering user was not the first and we're suspending users on
// register, suspend them.
if !was_first_user && services.config.suspend_on_register {
// Note that we can still do auto joins for suspended users
services
.users
.suspend_account(&user_id, &services.globals.server_user)
.await;
// And send an @room notice to the admin room, to prompt admins to review the
// new user and ideally unsuspend them if deemed appropriate.
if services.server.config.admin_room_notices {
services
.admin
.send_loud_message(RoomMessageEventContent::text_plain(format!(
"User {user_id} has been suspended as they are not the first user on this \
server. Please review and unsuspend them if appropriate."
)))
.await
.ok();
}
}
if body.identity.is_none() && !services.server.config.auto_join_rooms.is_empty() {
for room in &services.server.config.auto_join_rooms {
let Ok(room_id) = services.rooms.alias.resolve(room).await else {
error!(
"Failed to resolve room alias to room ID when attempting to auto join \
{room}, skipping"
);
continue;
};
if !services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), &room_id)
.await
{
warn!(
"Skipping room {room} to automatically join as we have never joined before."
);
continue;
}
if let Some(room_server_name) = room.server_name() {
match services
.rooms
.membership
.join_room(
&user_id,
&room_id,
Some("Automatically joining this room upon registration".to_owned()),
&[services.globals.server_name().to_owned(), room_server_name.to_owned()],
)
.boxed()
.await
{
| Err(e) => {
// don't return this error so we don't fail registrations
error!(
"Failed to automatically join room {room} for user {user_id}: {e}"
);
},
| _ => {
info!("Automatically joined room {room} for user {user_id}");
},
}
}
}
}
debug_info!(%user_id, ?device, "New account created via legacy registration");
Ok(assign!(register::v3::Response::new(user_id), {
access_token: token,
@@ -372,21 +211,21 @@ async fn create_registration_uiaa_session(
// Require all users to agree to the terms and conditions, if configured
let terms = &services.config.registration_terms;
if !terms.is_empty() {
let mut terms =
serde_json::to_value(terms.clone()).expect("failed to serialize terms");
if !terms.documents.is_empty() {
let mut terms_map = HashMap::new();
// Insert a dummy `version` field
for (_, documents) in terms.as_object_mut().unwrap() {
let documents = documents.as_object_mut().unwrap();
documents.insert("version".to_owned(), "latest".into());
for (id, document) in &terms.documents {
terms_map.insert(id.to_owned(), serde_json::json!({
terms.language.clone(): serde_json::to_value(document).expect("should be able to serialize document")
}));
}
terms_map.insert("version".to_owned(), "latest".into());
params.insert(
AuthType::Terms.as_str().to_owned(),
serde_json::json!({
"policies": terms,
"policies": terms_map,
}),
);
@@ -419,81 +258,6 @@ async fn create_registration_uiaa_session(
Ok((flows, params))
}
async fn determine_registration_user_id(
services: &Services,
supplied_username: Option<String>,
emergency_mode_enabled: bool,
) -> Result<OwnedUserId> {
if let Some(supplied_username) = supplied_username {
// The user gets to pick their username. Do some validation to make sure it's
// acceptable.
// Don't allow registration with forbidden usernames.
if services
.globals
.forbidden_usernames()
.is_match(&supplied_username)
&& !emergency_mode_enabled
{
return Err!(Request(Forbidden("Username is forbidden")));
}
// Create and validate the user ID
let user_id = match UserId::parse_with_server_name(
&supplied_username,
services.globals.server_name(),
) {
| Ok(user_id) => {
if let Err(e) = user_id.validate_strict() {
// Unless we are in emergency mode, we should follow synapse's behaviour on
// not allowing things like spaces and UTF-8 characters in usernames
if !emergency_mode_enabled {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} contains disallowed characters or \
spaces: {e}"
))));
}
}
// Don't allow registration with user IDs that aren't local
if !services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {supplied_username} is not local to this server"
)));
}
user_id
},
| Err(e) => {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} is not valid: {e}"
))));
},
};
if services.users.exists(&user_id).await {
return Err!(Request(UserInUse("User ID is not available.")));
}
Ok(user_id)
} else {
// The user didn't specify a username. Generate a username for
// them.
loop {
let user_id = UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
services.globals.server_name(),
)
.unwrap();
if !services.users.exists(&user_id).await {
break Ok(user_id);
}
}
}
}
/// # `POST /_matrix/client/v3/register/email/requestToken`
///
/// Requests a validation email for the purpose of registering a new account.
+7 -4
View File
@@ -11,7 +11,7 @@
},
thirdparty::{Medium, ThirdPartyIdentifierInit},
};
use service::{mailer::messages, uiaa::Identity};
use service::mailer::messages;
use crate::{Ruma, router::ClientIdentity};
@@ -124,15 +124,18 @@ pub(crate) async fn add_3pid_route(
.uiaa
.authenticate_password(
&body.auth,
Some(Identity::from_user_id(body.identity.expect_sender_user()?)),
body.identity.expect_sender_user()?,
body.identity.sender_device(),
None,
)
.await?;
let email = services
.threepid
.consume_valid_session(&body.sid, &body.client_secret)
.get_valid_session(&body.sid, &body.client_secret)
.await
.map_err(|message| err!(Request(ThreepidAuthFailed("{message}"))))?;
.map_err(|message| err!(Request(ThreepidAuthFailed("{message}"))))?
.consume();
services
.threepid
+90
View File
@@ -0,0 +1,90 @@
use axum::extract::State;
use conduwuit::{Err, Result};
use futures::future::{join, join3};
use ruma::api::client::admin::{is_user_locked, lock_user};
use crate::Ruma;
/// # `GET /_matrix/client/v1/admin/lock/{userId}`
///
/// Check the account lock status of a target user
pub(crate) async fn get_lock_status(
State(services): State<crate::State>,
body: Ruma<is_user_locked::v1::Request>,
) -> Result<is_user_locked::v1::Response> {
let (admin, active) = join(
services.users.is_admin(body.identity.expect_sender_user()?),
services.users.is_active(&body.user_id),
)
.await;
if !admin {
return Err!(Request(Forbidden("Only server administrators can use this endpoint")));
}
if !services.globals.user_is_local(&body.user_id) {
return Err!(Request(InvalidParam("Can only check the lock status of local users")));
}
if !active {
return Err!(Request(NotFound("Unknown user")));
}
Ok(is_user_locked::v1::Response::new(
services.users.is_locked(&body.user_id).await?,
))
}
/// # `PUT /_matrix/client/v1/admin/lock/{userId}`
///
/// Set the account lock status of a target user
pub(crate) async fn put_lock_status(
State(services): State<crate::State>,
body: Ruma<lock_user::v1::Request>,
) -> Result<lock_user::v1::Response> {
let sender_user = body.identity.expect_sender_user()?;
let (sender_admin, active, target_admin) = join3(
services.users.is_admin(sender_user),
services.users.is_active(&body.user_id),
services.users.is_admin(&body.user_id),
)
.await;
if !sender_admin {
return Err!(Request(Forbidden("Only server administrators can use this endpoint")));
}
if !services.globals.user_is_local(&body.user_id) {
return Err!(Request(InvalidParam("Can only set the lock status of local users")));
}
if !active {
return Err!(Request(NotFound("Unknown user")));
}
if body.user_id == *sender_user {
return Err!(Request(Forbidden("You cannot lock yourself")));
}
if target_admin {
return Err!(Request(Forbidden("You cannot lock another server administrator")));
}
if services.users.is_locked(&body.user_id).await? == body.locked {
// No change
return Ok(lock_user::v1::Response::new(body.locked));
}
let action = if body.locked {
services
.users
.suspend_account(&body.user_id, sender_user)
.await;
"locked"
} else {
services.users.unsuspend_account(&body.user_id).await;
"unlocked"
};
if services.config.admin_room_notices {
// Notify the admin room that an account has been un/suspended
services
.admin
.send_text(&format!("{} has been {} by {}.", body.user_id, action, sender_user))
.await;
}
Ok(lock_user::v1::Response::new(body.locked))
}
+2 -1
View File
@@ -1,3 +1,4 @@
mod lock;
mod suspend;
pub(crate) use self::suspend::*;
pub(crate) use self::{lock::*, suspend::*};
+8 -8
View File
@@ -1,7 +1,7 @@
use axum::extract::State;
use conduwuit::{Err, Result};
use futures::future::{join, join3};
use ruminuwuity::admin::{get_suspended, set_suspended};
use ruma::api::client::admin::{is_user_suspended, suspend_user};
use crate::Ruma;
@@ -10,8 +10,8 @@
/// Check the suspension status of a target user
pub(crate) async fn get_suspended_status(
State(services): State<crate::State>,
body: Ruma<get_suspended::v1::Request>,
) -> Result<get_suspended::v1::Response> {
body: Ruma<is_user_suspended::v1::Request>,
) -> Result<is_user_suspended::v1::Response> {
let (admin, active) = join(
services.users.is_admin(body.identity.expect_sender_user()?),
services.users.is_active(&body.user_id),
@@ -26,7 +26,7 @@ pub(crate) async fn get_suspended_status(
if !active {
return Err!(Request(NotFound("Unknown user")));
}
Ok(get_suspended::v1::Response::new(
Ok(is_user_suspended::v1::Response::new(
services.users.is_suspended(&body.user_id).await?,
))
}
@@ -36,8 +36,8 @@ pub(crate) async fn get_suspended_status(
/// Set the suspension status of a target user
pub(crate) async fn put_suspended_status(
State(services): State<crate::State>,
body: Ruma<set_suspended::v1::Request>,
) -> Result<set_suspended::v1::Response> {
body: Ruma<suspend_user::v1::Request>,
) -> Result<suspend_user::v1::Response> {
let sender_user = body.identity.expect_sender_user()?;
let (sender_admin, active, target_admin) = join3(
@@ -64,7 +64,7 @@ pub(crate) async fn put_suspended_status(
}
if services.users.is_suspended(&body.user_id).await? == body.suspended {
// No change
return Ok(set_suspended::v1::Response::new(body.suspended));
return Ok(suspend_user::v1::Response::new(body.suspended));
}
let action = if body.suspended {
@@ -86,5 +86,5 @@ pub(crate) async fn put_suspended_status(
.await;
}
Ok(set_suspended::v1::Response::new(body.suspended))
Ok(suspend_user::v1::Response::new(body.suspended))
}
+3 -10
View File
@@ -12,7 +12,6 @@
},
},
};
use serde_json::json;
use crate::Ruma;
@@ -40,21 +39,15 @@ pub(crate) async fn get_capabilities_route(
capabilities.get_login_token =
GetLoginTokenCapability::new(services.server.config.login_via_existing_session);
// MSC4133 capability
capabilities.set("uk.tcpip.msc4133.profile_fields", json!({"enabled": true}))?;
capabilities.set(
"org.matrix.msc4267.forget_forced_upon_leave",
json!({"enabled": services.config.forget_forced_upon_leave}),
)?;
capabilities.forget_forced_upon_leave.enabled = true;
if services
.users
.is_admin(body.identity.expect_sender_user()?)
.await
{
// Advertise suspension API
capabilities.set("uk.timedout.msc4323", json!({"suspend": true, "lock": false}))?;
capabilities.account_moderation.lock = true;
capabilities.account_moderation.suspend = true;
}
Ok(get_capabilities::v3::Response::new(capabilities))
+5 -7
View File
@@ -8,7 +8,6 @@
self, delete_device, delete_devices, get_device, get_devices, update_device,
},
};
use service::uiaa::Identity;
use crate::{Ruma, client::DEVICE_ID_LENGTH};
@@ -95,6 +94,7 @@ pub(crate) async fn update_device_route(
&device_id,
&appservice.registration.as_token,
None,
None,
Some(client.to_string()),
)
.await?;
@@ -119,14 +119,13 @@ pub(crate) async fn delete_device_route(
body: Ruma<delete_device::v3::Request>,
) -> Result<delete_device::v3::Response> {
let sender_user = body.identity.expect_sender_user()?;
let appservice = body.identity.appservice_info();
// Appservices get to skip UIAA for this endpoint
if appservice.is_none() {
if let Some(sender_device) = body.identity.sender_device() {
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, Some(sender_device), None)
.await?;
}
@@ -155,14 +154,13 @@ pub(crate) async fn delete_devices_route(
body: Ruma<delete_devices::v3::Request>,
) -> Result<delete_devices::v3::Response> {
let sender_user = body.identity.expect_sender_user()?;
let appservice = body.identity.appservice_info();
// Appservices get to skip UIAA for this endpoint
if appservice.is_none() {
if let Some(sender_device) = body.identity.sender_device() {
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, Some(sender_device), None)
.await?;
}
+7 -2
View File
@@ -26,7 +26,7 @@
serde::Raw,
};
use serde_json::json;
use service::uiaa::Identity;
use service::oauth::OAuthTicket;
use crate::Ruma;
@@ -205,7 +205,12 @@ pub(crate) async fn upload_signing_keys_route(
{
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(
&body.auth,
sender_user,
body.identity.sender_device(),
Some(OAuthTicket::CrossSigningReset),
)
.await?;
}
+3
View File
@@ -16,6 +16,7 @@
pub(super) mod membership;
pub(super) mod message;
pub(super) mod mutual_rooms;
pub(super) mod oauth;
pub(super) mod openid;
pub(super) mod presence;
pub(super) mod profile;
@@ -61,6 +62,7 @@
pub use membership::{leave_all_rooms, leave_room, remote_leave_room};
pub(super) use message::*;
pub(super) use mutual_rooms::*;
pub(super) use oauth::*;
pub(super) use openid::*;
pub(super) use presence::*;
pub(super) use profile::*;
@@ -73,6 +75,7 @@
pub(super) use room::*;
pub(super) use search::*;
pub(super) use send::*;
pub use session::handle_login;
pub(super) use session::*;
pub(super) use space::*;
pub(super) use state::*;
+56
View File
@@ -0,0 +1,56 @@
use axum::{
Json, Router,
extract::{Request, State},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::method_routing::{get, post},
};
use const_str::concat;
use http::StatusCode;
use serde_json::json;
pub(crate) use server_metadata::*;
mod register_client;
mod server_metadata;
mod token;
const BASE_PATH: &str = concat!(conduwuit_core::ROUTE_PREFIX, "/oauth2/");
const AUTH_CODE_PATH: &str = "grant/authorization_code";
const JWKS_URI_PATH: &str = "client/keys.json";
const CLIENT_REGISTER_PATH: &str = "client/register";
const TOKEN_REVOKE_PATH: &str = "client/revoke";
const TOKEN_PATH: &str = "grant/token";
const ACCOUNT_MANAGEMENT_PATH: &str = concat!(conduwuit_core::ROUTE_PREFIX, "/account/deeplink");
pub(crate) fn router(state: crate::State) -> Router<crate::State> {
Router::new()
.nest(BASE_PATH, oauth_router())
.route(
"/.well-known/openid-configuration",
get(
// TODO(unspecced): used by old versions of the matrix-js-sdk
async |State(services): State<crate::State>| {
Json(authorization_server_metadata(&services).await)
},
),
)
.layer(middleware::from_fn_with_state(
state,
async |State(state): State<crate::State>, request: Request, next: Next| -> Response {
if state.config.oauth.compatibility_mode.oauth_available() {
next.run(request).await
} else {
(StatusCode::NOT_FOUND, "OAuth is unavailable on this server").into_response()
}
},
))
}
fn oauth_router() -> Router<crate::State> {
Router::new()
.route(concat!("/", CLIENT_REGISTER_PATH), post(register_client::register_client_route))
// TODO(unspecced): used by old versions of the matrix-js-sdk
.route(concat!("/", JWKS_URI_PATH), get(async || Json(json!({"keys": []}))))
.route(concat!("/", TOKEN_PATH), post(token::token_route))
.route(concat!("/", TOKEN_REVOKE_PATH), post(token::revoke_token_route))
}
+28
View File
@@ -0,0 +1,28 @@
use axum::{
Json,
extract::State,
response::{IntoResponse, Response},
};
use http::StatusCode;
use serde::Serialize;
use service::oauth::client_metadata::ClientMetadata;
#[derive(Serialize)]
struct RegisteredClient {
client_id: String,
#[serde(flatten)]
metadata: ClientMetadata,
}
pub(crate) async fn register_client_route(
State(services): State<crate::State>,
Json(metadata): Json<ClientMetadata>,
) -> Result<Response, Response> {
let client_id = services
.oauth
.register_client(&metadata)
.await
.map_err(|err| (StatusCode::BAD_REQUEST, Json(err)).into_response())?;
Ok(Json(RegisteredClient { client_id, metadata }).into_response())
}
+62
View File
@@ -0,0 +1,62 @@
use axum::extract::State;
use conduwuit::{Err, Result};
use ruma::{
api::client::discovery::get_authorization_server_metadata::{
self, v1::AccountManagementAction,
},
serde::Raw,
};
use serde_json::{Value, json};
use service::Services;
use crate::{
Ruma,
client::oauth::{
ACCOUNT_MANAGEMENT_PATH, AUTH_CODE_PATH, CLIENT_REGISTER_PATH, JWKS_URI_PATH, TOKEN_PATH,
TOKEN_REVOKE_PATH,
},
};
pub(crate) async fn get_authorization_server_metadata_route(
State(services): State<crate::State>,
_body: Ruma<get_authorization_server_metadata::v1::Request>,
) -> Result<get_authorization_server_metadata::v1::Response> {
if !services.config.oauth.compatibility_mode.oauth_available() {
return Err!(Request(Unrecognized("OAuth is unavailable on this server")));
}
let metadata = Raw::new(&authorization_server_metadata(&services).await).unwrap();
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
}
pub(crate) async fn authorization_server_metadata(services: &Services) -> Value {
let endpoint_base = services
.config
.get_client_domain()
.join(super::BASE_PATH)
.unwrap();
json!({
"account_management_uri": endpoint_base.join(ACCOUNT_MANAGEMENT_PATH).unwrap(),
"account_management_actions_supported": [
AccountManagementAction::AccountDeactivate,
AccountManagementAction::CrossSigningReset,
AccountManagementAction::DeviceDelete,
AccountManagementAction::DeviceView,
AccountManagementAction::DevicesList,
AccountManagementAction::Profile,
],
"authorization_endpoint": endpoint_base.join(AUTH_CODE_PATH).unwrap(),
"code_challenge_methods_supported": ["S256"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"issuer": services.config.get_client_domain(),
"jwks_uri": endpoint_base.join(JWKS_URI_PATH).unwrap(),
"prompt_values_supported": ["create"],
"registration_endpoint": endpoint_base.join(CLIENT_REGISTER_PATH).unwrap(),
"response_modes_supported": ["query", "fragment"],
"response_types_supported": ["code"],
"revocation_endpoint": endpoint_base.join(TOKEN_REVOKE_PATH).unwrap(),
"token_endpoint": endpoint_base.join(TOKEN_PATH).unwrap(),
})
}
+23
View File
@@ -0,0 +1,23 @@
use axum::{Form, Json, extract::State, response::IntoResponse};
use http::StatusCode;
use service::oauth::grant::{RevokeTokenRequest, TokenRequest};
pub(crate) async fn token_route(
State(services): State<crate::State>,
Form(request): Form<TokenRequest>,
) -> impl IntoResponse {
match services.oauth.issue_token(request).await {
| Ok(response) => Ok(Json(response)),
| Err(err) => Err((StatusCode::BAD_REQUEST, Json(err))),
}
}
pub(crate) async fn revoke_token_route(
State(services): State<crate::State>,
Form(request): Form<RevokeTokenRequest>,
) -> impl IntoResponse {
match services.oauth.revoke_token(request.token).await {
| Ok(()) => Ok(StatusCode::OK),
| Err(err) => Err((StatusCode::BAD_REQUEST, Json(err))),
}
}
+31 -24
View File
@@ -21,7 +21,7 @@
},
login::{
self,
v3::{DiscoveryInfo, HomeserverInfo},
v3::{DiscoveryInfo, HomeserverInfo, LoginInfo},
},
logout, logout_all,
},
@@ -29,7 +29,6 @@
},
assign,
};
use service::uiaa::Identity;
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::Ruma;
@@ -44,6 +43,12 @@ pub(crate) async fn get_login_types_route(
ClientIp(client): ClientIp,
_body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> {
if !services.config.oauth.compatibility_mode.uiaa_available() {
return Err!(Request(Unrecognized(
"User-interactive authentication is not available on this server."
)));
}
Ok(get_login_types::v3::Response::new(vec![
get_login_types::v3::LoginType::Password(PasswordLoginType::default()),
get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()),
@@ -53,7 +58,7 @@ pub(crate) async fn get_login_types_route(
]))
}
pub(crate) async fn handle_login(
pub async fn handle_login(
services: &Services,
identifier: Option<&UserIdentifier>,
password: &str,
@@ -87,14 +92,6 @@ pub(crate) async fn handle_login(
return Err!(Request(InvalidParam("User ID does not belong to this homeserver")));
}
if services.users.is_deactivated(&user_id).await? {
return Err!(Request(UserDeactivated("This account has been deactivated.")));
}
if services.users.is_locked(&user_id).await? {
return Err!(Request(UserLocked("This account has been locked.")));
}
if services.users.is_login_disabled(&user_id).await {
warn!(%user_id, "user attempted to log in with a login-disabled account");
return Err!(Request(Forbidden("This account is not permitted to log in.")));
@@ -123,19 +120,29 @@ pub(crate) async fn login_route(
ClientIp(client): ClientIp,
body: Ruma<login::v3::Request>,
) -> Result<login::v3::Response> {
if !services.config.oauth.compatibility_mode.uiaa_available() {
return match body.login_info {
| LoginInfo::ApplicationService(_) => {
Err!(Request(AppserviceLoginUnsupported(
"User-interactive appservice login is not available on this server."
)))
},
| _ => {
Err!(Request(Unrecognized(
"User-interactive authentication is not available on this server."
)))
},
};
}
let emergency_mode_enabled = services.config.emergency_password.is_some();
// Validate login method
// TODO: Other login methods
let user_id = match &body.login_info {
#[allow(deprecated)]
| login::v3::LoginInfo::Password(login::v3::Password {
identifier,
password,
user,
..
}) => handle_login(&services, identifier.as_ref(), password, user.as_ref()).await?,
| login::v3::LoginInfo::Token(login::v3::Token { token, .. }) => {
| LoginInfo::Password(login::v3::Password { identifier, password, user, .. }) =>
handle_login(&services, identifier.as_ref(), password, user.as_ref()).await?,
| LoginInfo::Token(login::v3::Token { token, .. }) => {
debug!("Got token login type");
if !services.server.config.login_via_existing_session {
return Err!(Request(Unknown("Token login is not enabled.")));
@@ -143,7 +150,7 @@ pub(crate) async fn login_route(
services.users.find_from_login_token(token).await?
},
#[allow(deprecated)]
| login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
| LoginInfo::ApplicationService(login::v3::ApplicationService {
identifier,
user,
..
@@ -177,7 +184,6 @@ pub(crate) async fn login_route(
user_id
},
| _ => {
debug!("/login json_body: {:?}", &body.json_body);
return Err!(Request(Unknown(
debug_warn!(?body.login_info, "Invalid or unsupported login type")
)));
@@ -207,7 +213,7 @@ pub(crate) async fn login_route(
if device_exists {
services
.users
.set_token(&user_id, &device_id, &token)
.set_token(&user_id, &device_id, &token, None)
.await?;
} else {
services
@@ -216,6 +222,7 @@ pub(crate) async fn login_route(
&user_id,
&device_id,
&token,
None,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
@@ -254,7 +261,7 @@ pub(crate) async fn login_token_route(
ClientIp(client): ClientIp,
body: Ruma<get_login_token::v1::Request>,
) -> Result<get_login_token::v1::Response> {
if !services.server.config.login_via_existing_session {
if !services.config.login_via_existing_session {
return Err!(Request(Forbidden("Login via an existing session is not enabled")));
}
@@ -263,7 +270,7 @@ pub(crate) async fn login_token_route(
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, body.identity.sender_device(), None)
.await?;
let login_token = utils::random_string(TOKEN_LENGTH);
-1
View File
@@ -70,7 +70,6 @@ pub(crate) async fn sync_events_v5_route(
ClientIp(client_ip): ClientIp,
body: Ruma<sync_events::v5::Request>,
) -> Result<sync_events::v5::Response> {
debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted");
let sender_user = body.identity.expect_sender_user()?;
let sender_device = body.identity.expect_sender_device()?;
+2 -2
View File
@@ -35,8 +35,8 @@ pub(crate) async fn get_supported_versions_route(
/// `/_matrix/federation/v1/version`
pub(crate) async fn conduwuit_server_version() -> Result<impl IntoResponse> {
Ok(Json(serde_json::json!({
"name": conduwuit::version::name(),
"version": conduwuit::version::version(),
"name": conduwuit::BRANDING,
"version": conduwuit::version(),
})))
}
+2 -42
View File
@@ -3,8 +3,7 @@
use ruma::{
api::client::discovery::{
discover_homeserver::{self, HomeserverInfo},
discover_policy_server,
discover_support::{self, Contact, ContactRole},
discover_policy_server, discover_support,
},
assign,
};
@@ -67,46 +66,7 @@ pub(crate) async fn well_known_support(
.as_ref()
.map(ToString::to_string);
let email_address = services.config.well_known.support_email.clone();
let matrix_id = services.config.well_known.support_mxid.clone();
let pgp_key = services.config.well_known.support_pgp_key.clone();
// TODO: support defining multiple contacts in the config
let mut contacts: Vec<Contact> = vec![];
let role = services
.config
.well_known
.support_role
.clone()
.unwrap_or(ContactRole::Admin);
// Add configured contact if at least one contact method is specified
let configured_contact = match (matrix_id, email_address) {
| (Some(matrix_id), email_address) =>
Some(assign!(Contact::with_matrix_id(role, matrix_id), { email_address })),
| (None, Some(email_address)) => Some(Contact::with_email_address(role, email_address)),
| (None, None) => None,
};
if let Some(mut configured_contact) = configured_contact {
configured_contact.pgp_key = pgp_key;
contacts.push(configured_contact);
}
// Try to add admin users as contacts if no contacts are configured
if contacts.is_empty() {
let admin_users = services.admin.get_admins().await;
for user_id in &admin_users {
if *user_id == services.globals.server_user {
continue;
}
contacts.push(Contact::with_matrix_id(ContactRole::Admin, user_id.to_owned()));
}
}
let contacts = services.admin.get_support_contacts().await;
if contacts.is_empty() && support_page.is_none() {
// No admin room, no configured contacts, and no support page
+1
View File
@@ -1,4 +1,5 @@
#![type_length_limit = "16384"] //TODO: reduce me
#![recursion_limit = "256"] // My Giant Async Function
#![allow(clippy::toplevel_ref_arg)]
extern crate conduwuit_core as conduwuit;
+7 -3
View File
@@ -10,7 +10,7 @@
response::{IntoResponse, Redirect},
routing::{any, get, post},
};
use conduwuit::{Server, err};
use conduwuit::err;
pub(super) use conduwuit_service::state::State;
use http::{Uri, uri};
@@ -18,8 +18,8 @@
pub(super) use self::{args::Args as Ruma, auth::ClientIdentity, response::RumaResponse};
use crate::{admin, client, server};
pub fn build(router: Router<State>, server: &Server) -> Router<State> {
let config = &server.config;
pub fn build(router: Router<State>, state: State) -> Router<State> {
let config = &state.server.config;
let mut router = router
.ruma_route(&client::appservice_ping)
.ruma_route(&client::get_supported_versions_route)
@@ -181,11 +181,15 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::get_room_summary)
.ruma_route(&client::get_suspended_status)
.ruma_route(&client::put_suspended_status)
.ruma_route(&client::get_lock_status)
.ruma_route(&client::put_lock_status)
.ruma_route(&client::well_known_support)
.ruma_route(&client::well_known_client)
.ruma_route(&client::well_known_policy_server)
.ruma_route(&client::get_rtc_transports)
.ruma_route(&client::room_initial_sync_route)
.ruma_route(&client::get_authorization_server_metadata_route)
.merge(client::oauth::router(state))
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
.route("/_continuwuity/server_version", get(client::conduwuit_server_version))
.ruma_route(&admin::rooms::ban::ban_room)
+25 -10
View File
@@ -1,6 +1,7 @@
use std::any::{Any, TypeId};
use conduwuit::{Err, Result, err};
use conduwuit::{Err, Error, Result, err};
use http::StatusCode;
use ruma::{
DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
api::{
@@ -10,12 +11,15 @@
AuthScheme, NoAccessToken, NoAuthentication,
},
client,
error::{ErrorKind, UnknownTokenErrorData},
federation::authentication::ServerSignatures,
},
assign,
};
use service::{
Services,
server_keys::{PubKeyMap, PubKeys},
users::AccessTokenStatus,
};
use crate::{router::args::AuthQueryParams, service::appservice::RegistrationInfo};
@@ -162,10 +166,20 @@ async fn verify<B: AsRef<[u8]> + Sync>(
query: AuthQueryParams,
route: TypeId,
) -> Result<Self::Identity> {
if output.is_empty() {
return Err!(Request(Unauthorized("Missing access token.")));
}
if let Ok((sender_user, sender_device)) = services.users.find_from_token(&output).await {
if let Some((sender_user, sender_device, status)) =
services.users.find_from_token(&output).await
{
// If the token is expired we return a soft logout
if matches!(status, AccessTokenStatus::Expired) {
return Err(Error::Request(
ErrorKind::UnknownToken(
assign!(UnknownTokenErrorData::new(), { soft_logout: true }),
),
"This token has expired".into(),
StatusCode::UNAUTHORIZED,
));
}
// Locked users can only use /logout and /logout/all
if services
.users
@@ -176,7 +190,7 @@ async fn verify<B: AsRef<[u8]> + Sync>(
if !(route == TypeId::of::<client::session::logout::v3::Request>()
|| route == TypeId::of::<client::session::logout_all::v3::Request>())
{
return Err!(Request(Unauthorized("Your account is locked.")));
return Err!(Request(UserLocked("Your account is locked.")));
}
}
@@ -227,7 +241,11 @@ async fn verify<B: AsRef<[u8]> + Sync>(
appservice_info: Box::new(appservice_info),
})
} else {
Err!(Request(Unauthorized("Invalid access token.")))
Err(Error::Request(
ErrorKind::UnknownToken(UnknownTokenErrorData::new()),
"Invalid token".into(),
StatusCode::UNAUTHORIZED,
))
}
}
}
@@ -262,9 +280,6 @@ async fn verify<B: AsRef<[u8]> + Sync>(
_query: AuthQueryParams,
_route: TypeId,
) -> Result<Self::Identity> {
if output.is_empty() {
return Err!(Request(Unauthorized("Missing access token.")));
}
let Ok(appservice_info) = services.appservice.find_from_token(&output).await else {
return Err!(Request(Unauthorized("Invalid appservice token.")));
};
+3 -7
View File
@@ -4,15 +4,11 @@
use conduwuit::{Err, Event, Result, debug, info, trace, utils::to_canonical_object, warn};
use ruma::{OwnedEventId, api::federation::event::get_missing_events};
use serde_json::{json, value::RawValue};
use service::rooms::event_handler::GET_MISSING_EVENTS_MAX_BATCH_SIZE;
use super::AccessCheck;
use crate::Ruma;
/// arbitrary number but synapse's is 20 and we can handle lots of these anyways
const LIMIT_MAX: usize = 50;
/// spec says default is 10
const LIMIT_DEFAULT: usize = 10;
/// # `POST /_matrix/federation/v1/get_missing_events/{roomId}`
///
/// Retrieves events that the sender is missing.
@@ -45,8 +41,8 @@ pub(crate) async fn get_missing_events_route(
let limit = body
.limit
.try_into()
.unwrap_or(LIMIT_DEFAULT)
.min(LIMIT_MAX);
.unwrap_or(10)
.min(GET_MISSING_EVENTS_MAX_BATCH_SIZE);
let room_version = services.rooms.state.get_room_version(&body.room_id).await?;
+11 -3
View File
@@ -7,7 +7,7 @@
use axum::extract::State;
use axum_client_ip::ClientIp;
use conduwuit::{
Err, Error, Result, debug, debug_warn, err, error,
Err, Error, Result, debug, debug_error, debug_warn, err, error,
result::LogErr,
state_res::lexicographical_topological_sort,
trace,
@@ -133,6 +133,7 @@ async fn wait_for_result(
}
#[instrument(
name="transaction"
skip_all,
fields(
id = ?body.transaction_id.as_str(),
@@ -174,8 +175,14 @@ async fn process_inbound_transaction(
for (id, result) in &results {
if let Err(e) = result {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
debug_warn!("Incoming PDU failed {id}: {e:?}");
match e {
| Error::BadRequest(
ErrorKind::Forbidden | ErrorKind::InvalidParam | ErrorKind::BadJson,
..,
) => {
debug_warn!("Incoming PDU {id} failed: {e:?}");
},
| _ => debug_error!("Incoming PDU {id} failed: {e:?}"),
}
}
}
@@ -381,6 +388,7 @@ async fn handle_room(
.rooms
.event_handler
.handle_incoming_pdu(origin, room_id, &event_id, value, true)
.boxed()
.await
.map(|_| ());
results.push((event_id, result));
+2 -2
View File
@@ -11,8 +11,8 @@ pub(crate) async fn get_server_version_route(
) -> Result<get_server_version::v1::Response> {
Ok(assign!(get_server_version::v1::Response::new(), {
server: Some(assign!(get_server_version::v1::Server::new(), {
name: Some(conduwuit::version::name().into()),
version: Some(conduwuit::version::version().into()),
name: Some(conduwuit::BRANDING.into()),
version: Some(conduwuit::version().into()),
})),
}))
}
+132 -39
View File
@@ -4,7 +4,7 @@
pub mod proxy;
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
path::PathBuf,
};
@@ -375,7 +375,7 @@ pub struct Config {
#[serde(default = "default_max_request_size")]
pub max_request_size: usize,
/// default: 192
/// default: 1024
#[serde(default = "default_max_fetch_prev_events")]
pub max_fetch_prev_events: u16,
@@ -655,19 +655,25 @@ pub struct Config {
/// even if `recaptcha_site_key` is set.
pub recaptcha_private_site_key: Option<String>,
/// Policy documents, such as terms and conditions or a privacy policy,
/// which users must agree to when registering an account.
///
/// Example:
/// ```ignore
/// [global.registration_terms.privacy_policy]
/// en = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
/// es = { name = "Política de Privacidad", url = "https://homeserver.example/es/privacy_policy.html" }
/// ```
///
/// default: {}
/// display: nested
#[serde(default)]
pub registration_terms: HashMap<String, HashMap<String, TermsDocument>>,
pub registration_terms: RegistrationTerms,
/// display: nested
#[serde(default)]
pub oauth: OauthConfig,
/// Controls whether users are allowed to deactivate their own accounts
/// through the account management panel or their Matrix clients. Server
/// admins can always deactivate users using the relevant admin commands.
///
/// Note that, in some jurisdictions, you may be legally required to honor
/// users who request to deactivate their accounts if you set this option
/// to `false`.
///
/// default: true
#[serde(default = "true_fn")]
pub allow_deactivation: bool,
/// Controls whether encrypted rooms and events are allowed.
#[serde(default = "true_fn")]
@@ -781,6 +787,16 @@ pub struct Config {
/// a substitute for moderation bots.
pub default_room_acl_deny: Option<Vec<String>>,
/// The number of forward extremities to tolerate in a room before
/// attempting to manually squash them with a "dummy event". Setting this
/// above 20 will hinder its efficacy, and setting it below 5 will cause
/// more dummy events to be sent than necessary (which increases federation
/// traffic).
///
/// default: 10
#[serde(default = "default_extremity_threshold")]
pub dummy_event_threshold: u8,
/// display: nested
#[serde(default)]
pub well_known: WellKnownConfig,
@@ -2071,12 +2087,10 @@ pub struct Config {
pub stream_amplification: usize,
/// Number of sender task workers; determines sender parallelism. Default is
/// '0' which means the value is determined internally, likely matching the
/// number of tokio worker-threads or number of cores, etc. Override by
/// setting a non-zero value.
/// core count. Override by setting a different value.
///
/// default: 0
#[serde(default)]
/// default: core count
#[serde(default = "default_sender_workers")]
pub sender_workers: usize,
/// Enables listener sockets; can be set to false to disable listening. This
@@ -2351,6 +2365,29 @@ pub struct SmtpConfig {
pub require_email_for_token_registration: bool,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[config_example_generator(
filename = "conduwuit-example.toml",
section = "global.registration_terms",
optional = "true"
)]
pub struct RegistrationTerms {
/// The language code to provide to clients along with the policy documents.
///
/// default: "en"
#[serde(default = "default_terms_language")]
pub language: String,
/// Policy documents, such as terms and conditions or a privacy policy,
/// which users must agree to when registering an account.
///
/// Example:
/// ```ignore
/// [global.registration_terms.documents]
/// privacy_policy = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
/// ```
pub documents: BTreeMap<String, TermsDocument>,
}
/// A policy document for use with a m.login.terms stage.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TermsDocument {
@@ -2358,6 +2395,43 @@ pub struct TermsDocument {
pub url: String,
}
#[derive(Clone, Debug, Default, Deserialize)]
#[config_example_generator(
filename = "conduwuit-example.toml",
section = "global.oauth",
optional = "true"
)]
pub struct OauthConfig {
/// The compatibility mode to use for OAuth.
///
/// - "disabled": OAuth will be unavailable. Users will only be able to log
/// in using legacy authentication.
/// - "hybrid": OAuth and legacy authentication will both be available. Some
/// clients may only use one or the other.
/// - "exclusive": Only OAuth will be available. Clients which require
/// legacy authentication will be unable to log in.
///
/// default: "hybrid"
pub compatibility_mode: OAuthMode,
}
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OAuthMode {
Disabled,
#[default]
Hybrid,
Exclusive,
}
impl OAuthMode {
#[must_use]
pub fn uiaa_available(&self) -> bool { matches!(self, Self::Disabled | Self::Hybrid) }
#[must_use]
pub fn oauth_available(&self) -> bool { matches!(self, Self::Hybrid | Self::Exclusive) }
}
const DEPRECATED_KEYS: &[&str] = &[
"cache_capacity",
"conduit_cache_capacity_modifier",
@@ -2455,45 +2529,47 @@ fn default_database_backups_to_keep() -> i16 { 1 }
fn default_db_write_buffer_capacity_mb() -> f64 { 48.0 + parallelism_scaled_f64(4.0) }
fn default_db_cache_capacity_mb() -> f64 { 128.0 + parallelism_scaled_f64(64.0) }
fn default_db_cache_capacity_mb() -> f64 { 512.0 + parallelism_scaled_f64(512.0) }
fn default_pdu_cache_capacity() -> u32 { parallelism_scaled_u32(10_000).saturating_add(100_000) }
fn default_pdu_cache_capacity() -> u32 { parallelism_scaled_u32(50_000).saturating_add(100_000) }
fn default_cache_capacity_modifier() -> f64 { 1.0 }
fn default_auth_chain_cache_capacity() -> u32 {
parallelism_scaled_u32(10_000).saturating_add(100_000)
}
fn default_shorteventid_cache_capacity() -> u32 {
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_shorteventid_cache_capacity() -> u32 {
parallelism_scaled_u32(100_000).saturating_add(100_000)
}
fn default_eventidshort_cache_capacity() -> u32 {
parallelism_scaled_u32(25_000).saturating_add(100_000)
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_eventid_pdu_cache_capacity() -> u32 {
parallelism_scaled_u32(25_000).saturating_add(100_000)
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_shortstatekey_cache_capacity() -> u32 {
parallelism_scaled_u32(10_000).saturating_add(100_000)
parallelism_scaled_u32(100_000).saturating_add(100_000)
}
fn default_statekeyshort_cache_capacity() -> u32 {
parallelism_scaled_u32(10_000).saturating_add(100_000)
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_servernameevent_data_cache_capacity() -> u32 {
parallelism_scaled_u32(100_000).saturating_add(500_000)
parallelism_scaled_u32(100_000).saturating_add(100_000)
}
fn default_stateinfo_cache_capacity() -> u32 { parallelism_scaled_u32(100) }
fn default_stateinfo_cache_capacity() -> u32 { parallelism_scaled_u32(500).clamp(100, 12000) }
fn default_roomid_spacehierarchy_cache_capacity() -> u32 { parallelism_scaled_u32(1000) }
fn default_roomid_spacehierarchy_cache_capacity() -> u32 {
parallelism_scaled_u32(500).clamp(100, 12000)
}
fn default_dns_cache_entries() -> u32 { 32768 }
fn default_dns_cache_entries() -> u32 { 327_680 }
fn default_dns_min_ttl() -> u64 { 60 * 180 }
@@ -2549,7 +2625,7 @@ fn default_pusher_timeout() -> u64 { 60 }
fn default_pusher_idle_timeout() -> u64 { 15 }
fn default_max_fetch_prev_events() -> u16 { 192_u16 }
fn default_max_fetch_prev_events() -> u16 { 1024 }
fn default_max_concurrent_inbound_transactions() -> usize { 150 }
@@ -2652,6 +2728,8 @@ fn default_rocksdb_stats_level() -> u8 { 1 }
#[inline]
pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V12 }
fn default_extremity_threshold() -> u8 { 10 }
fn default_ip_range_denylist() -> Vec<String> {
vec![
"127.0.0.0/8".to_owned(),
@@ -2701,15 +2779,26 @@ fn default_admin_log_capture() -> String {
fn default_admin_room_tag() -> String { "m.server_notice".to_owned() }
#[must_use]
#[allow(clippy::as_conversions, clippy::cast_precision_loss)]
fn parallelism_scaled_f64(val: f64) -> f64 { val * (sys::available_parallelism() as f64) }
pub fn parallelism_scaled_f64(val: f64) -> f64 { val * (sys::available_parallelism() as f64) }
fn parallelism_scaled_u32(val: u32) -> u32 {
let val = val.try_into().expect("failed to cast u32 to usize");
parallelism_scaled(val).try_into().unwrap_or(u32::MAX)
#[must_use]
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
pub fn parallelism_scaled_u32(val: u32) -> u32 {
val.saturating_mul(sys::available_parallelism() as u32)
}
fn parallelism_scaled(val: usize) -> usize { val.saturating_mul(sys::available_parallelism()) }
#[must_use]
#[allow(clippy::as_conversions, clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
pub fn parallelism_scaled_i32(val: i32) -> i32 {
val.saturating_mul(sys::available_parallelism() as i32)
}
#[must_use]
pub fn parallelism_scaled(val: usize) -> usize {
val.saturating_mul(sys::available_parallelism())
}
fn default_trusted_server_batch_size() -> usize { 256 }
@@ -2729,6 +2818,8 @@ fn default_stream_width_scale() -> f32 { 1.0 }
fn default_stream_amplification() -> usize { 1024 }
fn default_sender_workers() -> usize { parallelism_scaled(1) }
fn default_client_receive_timeout() -> u64 { 75 }
fn default_client_request_timeout() -> u64 { 180 }
@@ -2738,3 +2829,5 @@ fn default_client_response_timeout() -> u64 { 120 }
fn default_client_shutdown_timeout() -> u64 { 15 }
fn default_sender_shutdown_timeout() -> u64 { 5 }
fn default_terms_language() -> String { "en".to_owned() }
+1
View File
@@ -161,6 +161,7 @@ pub fn message(&self) -> String {
match self {
| Self::Federation(origin, error) => format!("Answer from {origin}: {error}"),
| Self::Ruma(error) => response::ruma_error_message(error),
| Self::Request(_, message, _) => message.clone().into_owned(),
| _ => format!("{self}"),
}
}
+1 -4
View File
@@ -73,11 +73,8 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode {
// 413
| TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
// 405
| Unrecognized => StatusCode::METHOD_NOT_ALLOWED,
// 404
| NotFound => StatusCode::NOT_FOUND,
| Unrecognized | NotFound => StatusCode::NOT_FOUND,
// 403
| GuestAccessForbidden
+6 -9
View File
@@ -7,19 +7,16 @@
use std::sync::OnceLock;
static BRANDING: &str = "continuwuity";
static WEBSITE: &str = "https://continuwuity.org";
static SEMANTIC: &str = env!("CARGO_PKG_VERSION");
pub const BRANDING: &str = "continuwuity";
pub const ROUTE_PREFIX: &str = "/_continuwuity";
pub const WEBSITE: &str = "https://continuwuity.org";
pub const SEMANTIC: &str = env!("CARGO_PKG_VERSION");
static VERSION: OnceLock<String> = OnceLock::new();
static VERSION_UA: OnceLock<String> = OnceLock::new();
static USER_AGENT: OnceLock<String> = OnceLock::new();
static USER_AGENT_MEDIA: OnceLock<String> = OnceLock::new();
#[inline]
#[must_use]
pub fn name() -> &'static str { BRANDING }
#[inline]
pub fn version() -> &'static str { VERSION.get_or_init(init_version) }
@@ -32,10 +29,10 @@ pub fn user_agent() -> &'static str { USER_AGENT.get_or_init(init_user_agent) }
#[inline]
pub fn user_agent_media() -> &'static str { USER_AGENT_MEDIA.get_or_init(init_user_agent_media) }
fn init_user_agent() -> String { format!("{}/{} (bot; +{WEBSITE})", name(), version_ua()) }
fn init_user_agent() -> String { format!("{BRANDING}/{} (bot; +{WEBSITE})", version_ua()) }
fn init_user_agent_media() -> String {
format!("{}/{} (embedbot; facebookexternalhit/1.1; +{WEBSITE})", name(), version_ua())
format!("{BRANDING}/{} (embedbot; facebookexternalhit/1.1; +{WEBSITE})", version_ua())
}
fn init_version_ua() -> String {
+1 -1
View File
@@ -62,7 +62,7 @@ impl Default for PartialPdu {
fn default() -> Self {
Self {
event_type: "m.room.message".into(),
content: Box::<RawJsonValue>::default(),
content: to_raw_value("{}").unwrap(),
unsigned: None,
state_key: None,
redacts: None,
+4 -13
View File
@@ -22,28 +22,19 @@ pub fn versions() -> Vec<String> {
"v1.13".to_owned(),
"v1.14".to_owned(),
"v1.16".to_owned(),
// "v1.17".to_owned(),
// v1.17 requires: MSC4326 (AS device masquerading), MSC4312 (m.auth), MSC4190 (AS oauth
// user registration).
"v1.18".to_owned(),
]
}
#[must_use]
pub fn unstable_features() -> BTreeMap<String, bool> {
BTreeMap::from_iter([
("org.matrix.e2e_cross_signing".to_owned(), true),
("org.matrix.msc2285.stable".to_owned(), true), /* private read receipts (https://github.com/matrix-org/matrix-spec-proposals/pull/2285) */
("uk.half-shot.msc2666.query_mutual_rooms".to_owned(), true), /* query mutual rooms (https://github.com/matrix-org/matrix-spec-proposals/pull/2666) */
("org.matrix.msc2836".to_owned(), true), /* threading/threads (https://github.com/matrix-org/matrix-spec-proposals/pull/2836) */
("org.matrix.msc2946".to_owned(), true), /* spaces/hierarchy summaries (https://github.com/matrix-org/matrix-spec-proposals/pull/2946) */
("org.matrix.msc3026.busy_presence".to_owned(), true), /* busy presence status (https://github.com/matrix-org/matrix-spec-proposals/pull/3026) */
("org.matrix.msc3827".to_owned(), true), /* filtering of /publicRooms by room type (https://github.com/matrix-org/matrix-spec-proposals/pull/3827) */
("org.matrix.msc3952_intentional_mentions".to_owned(), true), /* intentional mentions (https://github.com/matrix-org/matrix-spec-proposals/pull/3952) */
("org.matrix.msc3916.stable".to_owned(), true), /* authenticated media (https://github.com/matrix-org/matrix-spec-proposals/pull/3916) */
("org.matrix.msc4180".to_owned(), true), /* stable flag for 3916 (https://github.com/matrix-org/matrix-spec-proposals/pull/4180) */
("uk.tcpip.msc4133".to_owned(), true), /* Extending User Profile API with Key:Value Pairs (https://github.com/matrix-org/matrix-spec-proposals/pull/4133) */
("us.cloke.msc4175".to_owned(), true), /* Profile field for user time zone (https://github.com/matrix-org/matrix-spec-proposals/pull/4175) */
("org.matrix.simplified_msc3575".to_owned(), true), /* Simplified Sliding sync (https://github.com/matrix-org/matrix-spec-proposals/pull/4186) */
("uk.timedout.msc4323".to_owned(), true), /* agnostic suspend (https://github.com/matrix-org/matrix-spec-proposals/pull/4323) */
("org.matrix.msc4155".to_owned(), true), /* invite filtering (https://github.com/matrix-org/matrix-spec-proposals/pull/4155) */
("computer.gingershaped.msc4466".to_owned(), true), /* profile change propagation (https://github.com/matrix-org/matrix-spec-proposals/pull/4466) */
("org.matrix.msc4380.stable".to_owned(), true),
])
}
+1 -4
View File
@@ -34,10 +34,7 @@ macro_rules! mod_dtor {
pub use conduwuit_build_metadata as build_metadata;
pub use config::Config;
pub use error::Error;
pub use info::{
version,
version::{name, version},
};
pub use info::version::*;
pub use matrix::{Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, pdu, state_res};
pub use parking_lot::{Mutex as SyncMutex, RwLock as SyncRwLock};
pub use server::Server;
+3
View File
@@ -5,6 +5,7 @@
/// Sha256 hash (input gather joined by 0xFF bytes)
#[must_use]
#[tracing::instrument(skip(inputs), level = "trace")]
#[allow(clippy::unnecessary_fallible_conversions)]
pub fn delimited<'a, T, I>(mut inputs: I) -> DigestOut
where
I: Iterator<Item = T> + 'a,
@@ -25,6 +26,7 @@ pub fn delimited<'a, T, I>(mut inputs: I) -> DigestOut
/// Sha256 hash (input gather)
#[must_use]
#[tracing::instrument(skip(inputs), level = "trace")]
#[allow(clippy::unnecessary_fallible_conversions)]
pub fn concat<'a, T, I>(inputs: I) -> DigestOut
where
I: Iterator<Item = T> + 'a,
@@ -43,6 +45,7 @@ pub fn concat<'a, T, I>(inputs: I) -> DigestOut
#[inline]
#[must_use]
#[tracing::instrument(skip(input), level = "trace")]
#[allow(clippy::unnecessary_fallible_conversions)]
pub fn hash<T>(input: T) -> DigestOut
where
T: AsRef<[u8]>,
+16 -10
View File
@@ -61,17 +61,23 @@ pub fn format(ts: SystemTime, str: &str) -> String {
pub fn pretty(d: Duration) -> String {
use Unit::*;
let fmt = |w, f, u| format!("{w}.{f} {u}");
let gen64 = |w, f, u| fmt(w, (f * 100.0) as u32, u);
let gen128 = |w, f, u| gen64(u64::try_from(w).expect("u128 to u64"), f, u);
let fmt = |w, u| {
if w == 1 {
format!("{w} {u}")
} else {
format!("{w} {u}s")
}
};
let gen64 = |w, u| fmt(w, u);
let gen128 = |w, u| gen64(u64::try_from(w).expect("u128 to u64"), u);
match whole_and_frac(d) {
| (Days(whole), frac) => gen64(whole, frac, "days"),
| (Hours(whole), frac) => gen64(whole, frac, "hours"),
| (Mins(whole), frac) => gen64(whole, frac, "minutes"),
| (Secs(whole), frac) => gen64(whole, frac, "seconds"),
| (Millis(whole), frac) => gen128(whole, frac, "milliseconds"),
| (Micros(whole), frac) => gen128(whole, frac, "microseconds"),
| (Nanos(whole), frac) => gen128(whole, frac, "nanoseconds"),
| (Days(whole), _) => gen64(whole, "day"),
| (Hours(whole), _) => gen64(whole, "hour"),
| (Mins(whole), _) => gen64(whole, "minute"),
| (Secs(whole), _) => gen64(whole, "second"),
| (Millis(whole), _) => gen128(whole, "millisecond"),
| (Micros(whole), _) => gen128(whole, "microsecond"),
| (Nanos(whole), _) => gen128(whole, "nanosecond"),
}
}
+1 -1
View File
@@ -29,7 +29,7 @@ fn descriptor_cf_options(
set_table_options(&mut opts, &desc, cache)?;
opts.set_min_write_buffer_number(1);
opts.set_max_write_buffer_number(2);
opts.set_max_write_buffer_number(3);
opts.set_write_buffer_size(desc.write_size);
opts.set_target_file_size_base(desc.file_size);
+24
View File
@@ -49,6 +49,10 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "bannedroomids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "clientid_clientmetadata",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "disabledroomids",
..descriptor::RANDOM_SMALL
@@ -157,6 +161,10 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "referencedevents",
..descriptor::RANDOM
},
Descriptor {
name: "refreshtoken_refreshtokeninfo",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "registrationtoken_info",
..descriptor::RANDOM_SMALL
@@ -187,6 +195,10 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
val_size_hint: Some(8),
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomid_mindepth",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomserverids",
..descriptor::RANDOM_SMALL
@@ -366,6 +378,14 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "userdevicetxnid_response",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_oauthsessioninfo",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_tokenexpires",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userfilterid_filter",
..descriptor::RANDOM_SMALL
@@ -470,4 +490,8 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "userroomid_invitesender",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "websessionid_session",
..descriptor::RANDOM_SMALL
},
];
+1 -1
View File
@@ -15,7 +15,7 @@
#[clap(
about,
long_about = None,
name = conduwuit_core::name(),
name = conduwuit_core::BRANDING,
version = conduwuit_core::version(),
)]
pub struct Args {
+1 -1
View File
@@ -110,7 +110,7 @@ pub(crate) fn init(
.with_batch_exporter(exporter)
.build();
let tracer = provider.tracer(conduwuit_core::name());
let tracer = provider.tracer(conduwuit_core::BRANDING);
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
+1 -1
View File
@@ -47,7 +47,7 @@ fn options(config: &Config) -> ClientOptions {
traces_sample_rate: config.sentry_traces_sample_rate,
debug: cfg!(debug_assertions),
release: release_name(),
user_agent: conduwuit_core::version::user_agent().into(),
user_agent: conduwuit_core::user_agent().into(),
attach_stacktrace: config.sentry_attach_stacktrace,
before_send: Some(Arc::new(before_send)),
before_breadcrumb: Some(Arc::new(before_breadcrumb)),
+7 -5
View File
@@ -8,7 +8,7 @@
extract::State,
response::{IntoResponse, Response},
};
use conduwuit::{Result, debug, debug_error, debug_warn, err, error, trace};
use conduwuit::{Result, debug_warn, err, error, info, trace};
use conduwuit_service::Services;
use futures::FutureExt;
use http::{Method, StatusCode, Uri};
@@ -102,17 +102,19 @@ fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result<Respons
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
if status.is_server_error() {
error!(%method, %uri, "{code} {reason}");
info!(%method, %uri, "{code} {reason}");
} else if status.is_client_error() {
debug_error!(%method, %uri, "{code} {reason}");
info!(%method, %uri, "{code} {reason}");
} else if status.is_redirection() {
debug!(%method, %uri, "{code} {reason}");
trace!(%method, %uri, "{code} {reason}");
} else {
trace!(%method, %uri, "{code} {reason}");
}
if status == StatusCode::METHOD_NOT_ALLOWED {
return Ok(err!(Request(Unrecognized("Method Not Allowed"))).into_response());
return Ok(
err!(Request(Unrecognized("Method not allowed"), METHOD_NOT_ALLOWED)).into_response()
);
}
Ok(result)
+2 -2
View File
@@ -9,8 +9,8 @@
pub(crate) fn build(services: &Arc<Services>) -> (Router, Guard) {
let router = Router::<state::State>::new();
let (state, guard) = state::create(services.clone());
let router = conduwuit_api::router::build(router, &services.server)
.merge(conduwuit_web::build())
let router = conduwuit_api::router::build(router, state)
.merge(conduwuit_web::build(services))
.fallback(not_found)
.with_state(state);
+1
View File
@@ -119,6 +119,7 @@ recaptcha-verify = { version = "0.2.0", default-features = false }
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
yansi.workspace = true
lettre.workspace = true
serde_urlencoded.workspace = true
[target.'cfg(all(unix, target_os = "linux"))'.dependencies]
sd-notify.workspace = true
+53 -1
View File
@@ -18,6 +18,8 @@
use loole::{Receiver, Sender};
use ruma::{
OwnedEventId, OwnedMxcUri, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId,
api::client::discovery::discover_support::{Contact, ContactRole},
assign,
events::{
Mentions,
room::message::{
@@ -28,7 +30,7 @@
use tokio::sync::RwLock;
use crate::{
Dep, account_data, globals,
Dep, account_data, config, globals,
media::{MXC_LENGTH, mxc::Mxc},
rooms::{self, state::RoomMutexGuard},
};
@@ -44,6 +46,7 @@ pub struct Service {
struct Services {
server: Arc<Server>,
config: Dep<config::Service>,
globals: Dep<globals::Service>,
alias: Dep<rooms::alias::Service>,
timeline: Dep<rooms::timeline::Service>,
@@ -115,6 +118,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
server: args.server.clone(),
config: args.depend::<config::Service>("config"),
globals: args.depend::<globals::Service>("globals"),
alias: args.depend::<rooms::alias::Service>("rooms::alias"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
@@ -623,4 +627,52 @@ pub(super) fn set_services(&self, services: Option<&Arc<crate::Services>>) {
let weak = services.map(Arc::downgrade);
*receiver = weak;
}
/// Get the server's configured support contacts.
pub async fn get_support_contacts(&self) -> Vec<Contact> {
let email_address = self.services.config.well_known.support_email.clone();
let matrix_id = self.services.config.well_known.support_mxid.clone();
let pgp_key = self.services.config.well_known.support_pgp_key.clone();
// TODO: support defining multiple contacts in the config
let mut contacts: Vec<Contact> = vec![];
let role = self
.services
.config
.well_known
.support_role
.clone()
.unwrap_or(ContactRole::Admin);
// Add configured contact if at least one contact method is specified
let configured_contact = match (matrix_id, email_address) {
| (Some(matrix_id), email_address) =>
Some(assign!(Contact::with_matrix_id(role, matrix_id), { email_address })),
| (None, Some(email_address)) =>
Some(Contact::with_email_address(role, email_address)),
| (None, None) => None,
};
if let Some(mut configured_contact) = configured_contact {
configured_contact.pgp_key = pgp_key;
contacts.push(configured_contact);
}
// Try to add admin users as contacts if no contacts are configured
if contacts.is_empty() {
let admin_users = self.get_admins().await;
for user_id in &admin_users {
if *user_id == self.services.globals.server_user {
continue;
}
contacts.push(Contact::with_matrix_id(ContactRole::Admin, user_id.to_owned()));
}
}
contacts
}
}
+2 -2
View File
@@ -67,7 +67,7 @@ async fn worker(self: Arc<Self>) -> Result {
for (id, registration) in appservices {
// During startup, resolve any token collisions in favour of appservices
// by logging out conflicting user devices
if let Ok((user_id, device_id)) = self
if let Some((user_id, device_id, _)) = self
.services
.users
.find_from_token(&registration.as_token)
@@ -158,7 +158,7 @@ pub async fn register_appservice(
.users
.find_from_token(&registration.as_token)
.await
.is_ok()
.is_some()
{
return Err(err!(Request(InvalidParam(
"Cannot register appservice: The provided token is already in use by a user \
+2 -2
View File
@@ -39,7 +39,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let url_preview_user_agent = config
.url_preview_user_agent
.clone()
.unwrap_or_else(|| conduwuit::version::user_agent_media().to_owned());
.unwrap_or_else(|| conduwuit::user_agent_media().to_owned());
Ok(Arc::new(Self {
default: base(config)?
@@ -149,7 +149,7 @@ fn base(config: &Config) -> Result<reqwest::ClientBuilder> {
.timeout(Duration::from_secs(config.request_total_timeout))
.pool_idle_timeout(Duration::from_secs(config.request_idle_timeout))
.pool_max_idle_per_host(config.request_idle_per_host.into())
.user_agent(conduwuit::version::user_agent())
.user_agent(conduwuit::user_agent())
.redirect(redirect::Policy::limited(6))
.danger_accept_invalid_certs(config.allow_invalid_tls_certificates_yes_i_know_what_the_fuck_i_am_doing_with_this_and_i_know_this_is_insecure)
.connection_verbose(cfg!(debug_assertions));
+12 -7
View File
@@ -6,7 +6,7 @@
use askama::Template;
use async_trait::async_trait;
use conduwuit::{Result, info, utils::ReadyExt};
use futures::{FutureExt, StreamExt};
use futures::StreamExt;
use ruma::{UserId, events::room::message::RoomMessageEventContent};
use crate::{
@@ -120,7 +120,7 @@ fn disable_first_run(&self) -> bool {
///
/// Returns Ok(true) if the specified user was the first user, and Ok(false)
/// if they were not.
pub async fn empower_first_user(&self, user: &UserId) -> Result<bool> {
pub async fn empower_first_user(&self, user: &UserId) -> bool {
#[derive(Template)]
#[template(path = "welcome.md")]
struct WelcomeMessage<'a> {
@@ -130,10 +130,14 @@ struct WelcomeMessage<'a> {
// If first run mode isn't active, do nothing.
if !self.disable_first_run() {
return Ok(false);
return false;
}
self.services.admin.make_user_admin(user).boxed().await?;
self.services
.admin
.make_user_admin(user)
.await
.expect("should have been able to empower the first user");
// Send the welcome message
let welcome_message = WelcomeMessage {
@@ -146,11 +150,12 @@ struct WelcomeMessage<'a> {
self.services
.admin
.send_loud_message(RoomMessageEventContent::text_markdown(welcome_message))
.await?;
.await
.expect("should have been able to send welcome message");
info!("{user} has been invited to the admin room as the first user.");
Ok(true)
true
}
/// Get the single-use registration token which may be used to create the
@@ -181,7 +186,7 @@ pub fn print_first_run_banner(&self) {
eprintln!(
"Welcome to {} {}!",
"Continuwuity".bold().bright_magenta(),
conduwuit::version::version().bold()
conduwuit::version().bold()
);
eprintln!();
eprintln!(
+4 -2
View File
@@ -92,8 +92,8 @@ pub async fn send<Template: MessageTemplate>(
let message = MessageBuilder::new()
.from(self.sender.clone())
.to(recipient)
.subject(subject)
.to(recipient.clone())
.subject(subject.clone())
.date_now()
.header(ContentType::TEXT_PLAIN)
.body(body)
@@ -104,6 +104,8 @@ pub async fn send<Template: MessageTemplate>(
.await
.map_err(|err: TransportError| err!("Failed to send message: {err}"))?;
info!(recipient = recipient.to_string(), ?subject, "Email sent");
Ok(())
}
}
+1 -1
View File
@@ -27,7 +27,7 @@
pub mod mailer;
pub mod media;
pub mod moderation;
pub mod password_reset;
pub mod oauth;
pub mod presence;
pub mod pusher;
pub mod registration_tokens;
+196
View File
@@ -0,0 +1,196 @@
use std::{collections::BTreeSet, hash::Hash};
use itertools::Itertools;
use serde::{Deserialize, Deserializer, Serialize};
use url::Url;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[non_exhaustive]
pub struct ClientMetadata {
#[serde(default)]
pub application_type: ApplicationType,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_name: Option<String>,
pub client_uri: Url,
#[serde(default, deserialize_with = "btreeset_skip_err")]
pub grant_types: BTreeSet<GrantType>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub logo_uri: Option<Url>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub policy_uri: Option<Url>,
#[serde(default)]
pub redirect_uris: Vec<Url>,
#[serde(default, deserialize_with = "btreeset_skip_err")]
pub response_types: BTreeSet<ResponseType>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_method: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tos_uri: Option<Url>,
}
impl ClientMetadata {
pub(super) const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
pub(super) fn validate(&self) -> Result<(), &'static str> {
let Some(client_domain) = self.client_uri.domain() else {
return Err("Client URI must have a domain.");
};
if self.client_uri.scheme() != "https" {
return Err("Client URI must be HTTPS.");
}
if !self.client_uri.username().is_empty() || self.client_uri.password().is_some() {
return Err("Client URI must not include credentials.");
}
for uri in [&self.logo_uri, &self.policy_uri, &self.tos_uri]
.iter()
.filter_map(|uri| uri.as_ref())
{
if uri.scheme() != "https" {
return Err("All metadata URIs must be HTTPS.");
}
if !uri.username().is_empty() || uri.password().is_some() {
return Err("All metadata URIs must not include credentials.");
}
if !uri
.domain()
.is_some_and(|domain| is_subdomain(domain, client_domain))
{
return Err("All metadata URIs must be subdomains of the client URI.");
}
}
for uri in &self.redirect_uris {
match uri.scheme() {
| "https" => {
// HTTPS URIs are okay for native and web clients
if !uri.username().is_empty() || uri.password().is_some() {
return Err("HTTPS redirect URIs must not contain credentials.");
}
},
| "http" if self.application_type == ApplicationType::Native => {
if uri
.host_str()
.is_none_or(|host| !Self::ACCEPTABLE_LOCALHOSTS.contains(&host))
{
return Err("HTTP redirect URIs for native applications must only \
refer to localhost.");
}
if uri.port().is_some() {
return Err("HTTP redirect URIs for native applications do not need to \
specify a port. All ports will be accepted during \
authorization.");
}
},
| private_scheme if self.application_type == ApplicationType::Native => {
let rdns_client_uri = client_domain.split('.').rev().join(".");
if !private_scheme.starts_with(&rdns_client_uri) {
return Err("Private-use scheme URIs for native applications must \
begin with the application's client URI domain in \
reverse-DNS notation.");
}
if uri.has_authority() {
return Err("Private-use scheme URIs for native applications must not \
have an authority.");
}
},
| _ =>
return Err("A redirect URI's scheme is not valid for this application type."),
}
}
Ok(())
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ApplicationType {
#[default]
Web,
Native,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
AuthorizationCode,
RefreshToken,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ResponseType {
Code,
}
/// Deserialize a BTreeSet from a sequence, skipping items which fail to
/// deserialize. This is used as a deserialize helper for ClientMetadata to
/// ignore unknown enum variants in a few fields.
fn btreeset_skip_err<'de, D, V>(de: D) -> Result<BTreeSet<V>, D::Error>
where
D: Deserializer<'de>,
V: Deserialize<'de> + Hash + Eq + Ord,
{
use std::marker::PhantomData;
use serde::de::{SeqAccess, Visitor};
struct BTreeSetVisitor<V> {
item: PhantomData<V>,
}
impl<'de, V> Visitor<'de> for BTreeSetVisitor<V>
where
V: Deserialize<'de> + Hash + Eq + Ord,
{
type Value = BTreeSet<V>;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut set = BTreeSet::new();
while let Some(element) = seq.next_element().transpose() {
if let Ok(element) = element {
set.insert(element);
}
}
Ok(set)
}
}
de.deserialize_seq(BTreeSetVisitor { item: PhantomData })
}
fn is_subdomain(subdomain: &str, domain: &str) -> bool {
if subdomain == domain {
return true;
}
subdomain.ends_with(&format!(".{domain}"))
}
+211
View File
@@ -0,0 +1,211 @@
use std::{
borrow::Cow,
collections::BTreeSet,
error::Error,
fmt::{Debug, Display},
hash::Hash,
mem::discriminant,
};
use regex::Regex;
use ruma::OwnedDeviceId;
use serde::{Deserialize, Serialize};
use url::Url;
use super::client_metadata::ResponseType;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AuthorizationCodeQuery {
pub response_type: ResponseType,
pub client_id: String,
pub redirect_uri: Url,
pub scope: RawScopes,
pub state: String,
#[serde(default)]
pub response_mode: ResponseMode,
pub code_challenge: String,
pub code_challenge_method: CodeChallengeMethod,
#[serde(default)]
pub prompt: Option<Prompt>,
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ResponseMode {
#[default]
// default for `code` response type, see https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#:~:text=Client%2E-,For,encoding%2E,-See
Query,
Fragment,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[non_exhaustive]
pub enum CodeChallengeMethod {
S256,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Prompt {
Create,
#[serde(other)]
Unknown,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)]
pub enum Scope {
Device(OwnedDeviceId),
ClientApi,
}
impl PartialEq for Scope {
fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) }
}
impl Eq for Scope {}
impl Hash for Scope {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { discriminant(self).hash(state); }
}
impl Display for Scope {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let urn = match self {
| Self::ClientApi => "urn:matrix:client:api:*".to_owned(),
| Self::Device(device_id) => format!("urn:matrix:client:device:{device_id}"),
};
f.write_str(&urn)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RawScopes(String);
impl RawScopes {
pub fn to_scopes(&self) -> Result<BTreeSet<Scope>, String> {
let client_api_token_regex =
Regex::new(r"urn:matrix:(client|org.matrix.msc2967.client):api:\*").unwrap();
let device_token_regex = Regex::new(
r"urn:matrix:(client|org.matrix.msc2967.client):device:([a-zA-Z0-9-._~]{5,})",
)
.unwrap();
let mut scopes = BTreeSet::new();
for token in self.0.split(' ') {
let scope_was_new = {
if client_api_token_regex.is_match(token) {
scopes.insert(Scope::ClientApi)
} else if let Some(captures) = device_token_regex.captures(token) {
scopes.insert(Scope::Device(captures.get(2).unwrap().as_str().into()))
} else if token == "openid" {
// TODO(unspecced): Element sets this scope but doesn't use it for anything
true
} else {
return Err(format!("Invalid scope: {token}"));
}
};
if !scope_was_new {
return Err("Scope was specified more than once".to_owned());
}
}
Ok(scopes)
}
}
#[derive(Serialize, Debug, Clone)]
pub struct OAuthError {
pub error: ErrorCode,
pub error_description: Cow<'static, str>,
}
impl OAuthError {
#[must_use]
pub const fn invalid_request(error_description: &'static str) -> Self {
Self {
error: ErrorCode::InvalidRequest,
error_description: Cow::Borrowed(error_description),
}
}
#[must_use]
pub const fn invalid_grant(error_description: &'static str) -> Self {
Self {
error: ErrorCode::InvalidGrant,
error_description: Cow::Borrowed(error_description),
}
}
}
impl Display for OAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "OAuth error {:?}: {}", self.error, self.error_description)
}
}
impl Error for OAuthError {}
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ErrorCode {
InvalidRequest,
AccessDenied,
InvalidScope,
InvalidGrant,
InvalidClientMetadata,
}
#[derive(Serialize)]
pub struct AuthorizationCodeResponse {
pub state: String,
pub code: String,
}
#[derive(Deserialize)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum TokenRequest {
AuthorizationCode {
code: String,
redirect_uri: Url,
client_id: String,
code_verifier: String,
},
RefreshToken {
client_id: String,
refresh_token: String,
},
}
impl TokenRequest {
#[must_use]
pub fn client_id(&self) -> &str {
match self {
| Self::AuthorizationCode { client_id, .. }
| Self::RefreshToken { client_id, .. } => client_id,
}
}
}
#[derive(Serialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: TokenType,
pub expires_in: u64,
pub refresh_token: String,
pub scope: String,
}
#[derive(Serialize)]
pub enum TokenType {
Bearer,
}
#[derive(Deserialize)]
pub struct RevokeTokenRequest {
pub token: String,
}
+528
View File
@@ -0,0 +1,528 @@
use std::{
collections::{BTreeSet, HashMap},
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use base64::Engine;
use conduwuit::{
Result, info,
utils::{self, hash::sha256},
};
use database::{Deserialized, Json, Map};
use itertools::Itertools;
use lru_cache::LruCache;
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use url::Url;
use crate::{
Dep,
oauth::{
client_metadata::{ApplicationType, ClientMetadata, ResponseType},
grant::{
AuthorizationCodeQuery, AuthorizationCodeResponse, CodeChallengeMethod, ErrorCode,
OAuthError, ResponseMode, Scope, TokenRequest, TokenResponse, TokenType,
},
},
users,
};
pub mod client_metadata;
pub mod grant;
pub struct Service {
services: Services,
db: Data,
tickets: Mutex<HashMap<String, HashMap<OAuthTicket, SystemTime>>>,
pending_code_grants: tokio::sync::Mutex<LruCache<String, PendingCodeGrant>>,
}
struct Data {
clientid_clientmetadata: Arc<Map>,
userdeviceid_oauthsessioninfo: Arc<Map>,
refreshtoken_refreshtokeninfo: Arc<Map>,
}
struct Services {
users: Dep<users::Service>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct SessionInfo {
pub client_id: String,
pub scopes: BTreeSet<Scope>,
current_refresh_token: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct RefreshTokenInfo {
client_id: String,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
}
struct PendingCodeGrant {
authorizing_user: OwnedUserId,
requested_scopes: BTreeSet<Scope>,
client_name: Option<String>,
expected_client_id: String,
expected_redirect_uri: Url,
code_challenge: String,
requested_at: SystemTime,
}
impl PendingCodeGrant {
const MAX_AGE: Duration = Duration::from_mins(1);
const RANDOM_CODE_LENGTH: usize = 32;
#[must_use]
pub(crate) fn generate_code() -> String { utils::random_string(Self::RANDOM_CODE_LENGTH) }
#[must_use]
pub(crate) fn is_valid_for(&self, client_id: &str) -> bool {
let now = SystemTime::now();
self.expected_client_id == client_id
&& now
.duration_since(self.requested_at)
.is_ok_and(|age| age < Self::MAX_AGE)
}
}
/// A time-limited grant for a client to perform some sensitive action.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OAuthTicket {
CrossSigningReset,
}
impl OAuthTicket {
const MAX_AGE: Duration = Duration::from_mins(10);
#[must_use]
pub fn ticket_issue_path(&self) -> &'static str {
match self {
| Self::CrossSigningReset => "/account/cross_signing_reset",
}
}
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
users: args.depend::<users::Service>("users"),
},
db: Data {
clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(),
userdeviceid_oauthsessioninfo: args.db["userdeviceid_oauthsessioninfo"].clone(),
refreshtoken_refreshtokeninfo: args.db["refreshtoken_refreshtokeninfo"].clone(),
},
tickets: Mutex::default(),
pending_code_grants: tokio::sync::Mutex::new(LruCache::new(
Self::MAX_PENDING_CODE_GRANTS,
)),
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
const ACCESS_TOKEN_MAX_AGE: Duration = Duration::from_hours(1);
// Maximum number of pending code grants which will be held in memory at once,
// to prevent unbounded memory use if someone decides to repeatedly reload the
// grant page.
const MAX_PENDING_CODE_GRANTS: usize = 100;
const RANDOM_TOKEN_LENGTH: usize = 32;
fn generate_token() -> String { utils::random_string(Self::RANDOM_TOKEN_LENGTH) }
pub async fn register_client(&self, metadata: &ClientMetadata) -> Result<String, OAuthError> {
metadata.validate().map_err(|error| OAuthError {
error: ErrorCode::InvalidClientMetadata,
error_description: error.into(),
})?;
let client_id = base64::prelude::BASE64_STANDARD
.encode(sha256::hash(serde_json::to_string(metadata).unwrap().as_bytes()));
if self
.db
.clientid_clientmetadata
.exists(&client_id)
.await
.is_err()
{
self.db
.clientid_clientmetadata
.raw_put(&client_id, Json(metadata.clone()));
}
Ok(client_id)
}
pub async fn get_client_metadata(&self, client_id: &str) -> Option<ClientMetadata> {
self.db
.clientid_clientmetadata
.get(client_id)
.await
.deserialized()
.ok()
}
pub async fn get_session_info_for_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Option<SessionInfo> {
self.db
.userdeviceid_oauthsessioninfo
.qry(&(user_id, device_id))
.await
.deserialized::<SessionInfo>()
.ok()
}
pub async fn request_authorization_code(
&self,
authorizing_user: OwnedUserId,
query: AuthorizationCodeQuery,
) -> Result<String, String> {
let Some(client_metadata) = self.get_client_metadata(&query.client_id).await else {
return Err("Invalid client ID".to_owned());
};
if !(client_metadata
.response_types
.contains(&query.response_type)
&& matches!(query.response_type, ResponseType::Code))
{
return Err("Invalid response type".to_owned());
}
if !matches!(query.code_challenge_method, CodeChallengeMethod::S256) {
return Err("Invalid code challenge type".to_owned());
}
{
let mut stripped_uri = query.redirect_uri.clone();
if client_metadata.application_type == ApplicationType::Native
&& query
.redirect_uri
.host_str()
.is_some_and(|host| ClientMetadata::ACCEPTABLE_LOCALHOSTS.contains(&host))
{
// Remove the port from localhost redirect URIs for native applications when
// checking if it's valid
stripped_uri.set_port(None).unwrap();
}
if !client_metadata.redirect_uris.contains(&stripped_uri) {
return Err("Invalid redirect URI".to_owned());
}
}
let requested_scopes = query.scope.to_scopes()?;
let redirect_uri_query_separator = match query.response_mode {
| ResponseMode::Fragment => '#',
| ResponseMode::Query => '?',
};
let code = PendingCodeGrant::generate_code();
info!(
client_id = &query.client_id,
client_name = &client_metadata.client_name,
?requested_scopes,
?authorizing_user,
"Issuing oauth authorization code"
);
let redirect_uri = format!(
"{}{}{}",
query.redirect_uri,
redirect_uri_query_separator,
serde_urlencoded::to_string(AuthorizationCodeResponse {
state: query.state,
code: code.clone(),
})
.unwrap(),
);
let pending_grant = PendingCodeGrant {
authorizing_user,
requested_scopes,
client_name: client_metadata.client_name,
expected_client_id: query.client_id,
expected_redirect_uri: query.redirect_uri,
code_challenge: query.code_challenge,
requested_at: SystemTime::now(),
};
self.pending_code_grants
.lock()
.await
.insert(code, pending_grant);
Ok(redirect_uri)
}
pub async fn issue_token(&self, request: TokenRequest) -> Result<TokenResponse, OAuthError> {
match request {
| TokenRequest::AuthorizationCode {
code,
redirect_uri,
client_id,
code_verifier,
} => {
let mut pending_grants = self.pending_code_grants.lock().await;
let Some(pending_grant) = pending_grants
.remove(&code)
.filter(|grant| grant.is_valid_for(&client_id))
else {
return Err(OAuthError::invalid_grant("Invalid authorization code"));
};
if redirect_uri != pending_grant.expected_redirect_uri {
return Err(OAuthError::invalid_grant("Invalid redirect URI"));
}
let expected_code_challenge =
base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sha256::hash(&code_verifier));
if expected_code_challenge != pending_grant.code_challenge {
return Err(OAuthError::invalid_grant("Invalid code challenge"));
}
self.create_session(
pending_grant.authorizing_user,
pending_grant.requested_scopes,
pending_grant.client_name,
client_id,
)
.await
},
| TokenRequest::RefreshToken { client_id, refresh_token } =>
self.refresh_session(client_id, refresh_token).await,
}
}
pub async fn revoke_token(&self, token: String) -> Result<(), OAuthError> {
let (user_id, device_id) = if let Ok(refresh_token_info) = self
.db
.refreshtoken_refreshtokeninfo
.get(&token)
.await
.deserialized::<RefreshTokenInfo>()
{
(refresh_token_info.user_id, refresh_token_info.device_id)
} else if let Some((user_id, device_id, _)) =
self.services.users.find_from_token(&token).await
{
(user_id, device_id)
} else {
return Err(OAuthError::invalid_grant("Invalid access or refersh token"));
};
// This will also call [`Self::remove_session`]
self.services
.users
.remove_device(&user_id, &device_id)
.await;
Ok(())
}
async fn create_session(
&self,
authorizing_user: OwnedUserId,
requested_scopes: BTreeSet<Scope>,
client_name: Option<String>,
client_id: String,
) -> Result<TokenResponse, OAuthError> {
let access_token = Self::generate_token();
let refresh_token = Self::generate_token();
let device_id = requested_scopes
.iter()
.find_map(|scope| {
if let Scope::Device(device_id) = scope {
Some(device_id)
} else {
None
}
})
.ok_or_else(|| OAuthError::invalid_grant("No device ID scope supplied"))?;
if self
.services
.users
.get_device_metadata(&authorizing_user, device_id)
.await
.is_ok()
{
return Err(OAuthError {
error: ErrorCode::InvalidScope,
error_description: "A device with the supplied ID already exists for this user"
.into(),
});
}
self.services
.users
.create_device(
&authorizing_user,
device_id,
&access_token,
Some(Self::ACCESS_TOKEN_MAX_AGE),
client_name,
None,
)
.await
// This can only panic if the authorizing user suffered a spontaneous existence
// failure during authentication, which should(?) be impossible(?)
.expect("failed to create device");
self.db.userdeviceid_oauthsessioninfo.put(
(&authorizing_user, device_id),
Json(SessionInfo {
client_id: client_id.clone(),
current_refresh_token: refresh_token.clone(),
scopes: requested_scopes.clone(),
}),
);
self.db.refreshtoken_refreshtokeninfo.raw_put(
&refresh_token,
Json(RefreshTokenInfo {
client_id: client_id.clone(),
user_id: authorizing_user.clone(),
device_id: device_id.to_owned(),
}),
);
info!(
?client_id,
?authorizing_user,
?device_id,
?requested_scopes,
"Created new oauth session"
);
Ok(TokenResponse {
access_token,
token_type: TokenType::Bearer,
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
scope: requested_scopes.iter().join(" "),
refresh_token,
})
}
async fn refresh_session(
&self,
client_id: String,
refresh_token: String,
) -> Result<TokenResponse, OAuthError> {
let Some(refresh_token_info) = self
.db
.refreshtoken_refreshtokeninfo
.get(&refresh_token)
.await
.deserialized::<RefreshTokenInfo>()
.ok()
else {
return Err(OAuthError::invalid_grant("Invalid refresh token"));
};
assert_eq!(&client_id, &refresh_token_info.client_id, "refresh token client id mismatch");
let mut session_info = self
.get_session_info_for_device(
&refresh_token_info.user_id,
&refresh_token_info.device_id,
)
.await
.expect("session info should exist");
assert_eq!(&client_id, &session_info.client_id, "session info client id mismatch");
let new_access_token = Self::generate_token();
let new_refresh_token = Self::generate_token();
let scope = session_info.scopes.iter().join(" ");
session_info
.current_refresh_token
.clone_from(&new_refresh_token);
self.services
.users
.set_token(
&refresh_token_info.user_id,
&refresh_token_info.device_id,
&new_access_token,
Some(Self::ACCESS_TOKEN_MAX_AGE),
)
.await
.expect("should be able to set token");
self.db.userdeviceid_oauthsessioninfo.put(
(&refresh_token_info.user_id, &refresh_token_info.device_id),
Json(session_info),
);
self.db.refreshtoken_refreshtokeninfo.remove(&refresh_token);
drop(refresh_token);
self.db
.refreshtoken_refreshtokeninfo
.raw_put(&new_refresh_token, Json(refresh_token_info));
Ok(TokenResponse {
access_token: new_access_token,
token_type: TokenType::Bearer,
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
scope,
refresh_token: new_refresh_token,
})
}
pub async fn remove_session(&self, user_id: &UserId, device_id: &DeviceId) {
let session_info = self.get_session_info_for_device(user_id, device_id).await;
if let Some(session_info) = session_info {
self.db
.refreshtoken_refreshtokeninfo
.remove(&session_info.current_refresh_token);
self.db
.userdeviceid_oauthsessioninfo
.del((user_id, device_id));
info!(?user_id, ?device_id, "Removed OAuth session");
}
}
/// Issue a ticket for `localpart` to perform some action.
pub fn issue_ticket(&self, localpart: String, ticket: OAuthTicket) {
self.tickets
.lock()
.unwrap()
.entry(localpart)
.or_default()
.insert(ticket, SystemTime::now());
}
/// Try to consume an unexpired ticket for `localpart`.
pub fn try_consume_ticket(&self, localpart: &str, ticket: OAuthTicket) -> bool {
let now = SystemTime::now();
self.tickets
.lock()
.unwrap()
.get_mut(localpart)
.and_then(|tickets| tickets.remove(&ticket))
.is_some_and(|issued| {
now.duration_since(issued)
.is_ok_and(|duration| duration < OAuthTicket::MAX_AGE)
})
}
}
-68
View File
@@ -1,68 +0,0 @@
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use conduwuit::utils::{ReadyExt, stream::TryExpect};
use database::{Database, Deserialized, Json, Map};
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
pub(super) struct Data {
passwordresettoken_info: Arc<Map>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ResetTokenInfo {
pub user: OwnedUserId,
pub issued_at: SystemTime,
}
impl ResetTokenInfo {
// one hour
const MAX_TOKEN_AGE: Duration = Duration::from_hours(1);
pub fn is_valid(&self) -> bool {
let now = SystemTime::now();
now.duration_since(self.issued_at)
.is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE)
}
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
passwordresettoken_info: db["passwordresettoken_info"].clone(),
}
}
/// Associate a reset token with its info in the database.
pub(super) fn save_token(&self, token: &str, info: &ResetTokenInfo) {
self.passwordresettoken_info.raw_put(token, Json(info));
}
/// Lookup the info for a reset token.
pub(super) async fn lookup_token_info(&self, token: &str) -> Option<ResetTokenInfo> {
self.passwordresettoken_info
.get(token)
.await
.deserialized()
.ok()
}
/// Find a user's existing reset token, if any.
pub(super) async fn find_token_for_user(
&self,
user: &UserId,
) -> Option<(String, ResetTokenInfo)> {
self.passwordresettoken_info
.stream::<'_, String, ResetTokenInfo>()
.expect_ok()
.ready_find(|(_, info)| info.user == user)
.await
}
/// Remove a reset token.
pub(super) fn remove_token(&self, token: &str) { self.passwordresettoken_info.remove(token); }
}
-111
View File
@@ -1,111 +0,0 @@
mod data;
use std::{sync::Arc, time::SystemTime};
use conduwuit::{Err, Result, utils};
use data::{Data, ResetTokenInfo};
use ruma::OwnedUserId;
use crate::{
Dep, globals,
users::{self, HashedPassword},
};
pub const PASSWORD_RESET_PATH: &str = "/_continuwuity/account/reset_password";
pub const RESET_TOKEN_QUERY_PARAM: &str = "token";
const RESET_TOKEN_LENGTH: usize = 32;
pub struct Service {
db: Data,
services: Services,
}
struct Services {
users: Dep<users::Service>,
globals: Dep<globals::Service>,
}
#[derive(Debug)]
pub struct ValidResetToken {
pub token: String,
pub info: ResetTokenInfo,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
users: args.depend::<users::Service>("users"),
globals: args.depend::<globals::Service>("globals"),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Generate a random string suitable to be used as a password reset token.
#[must_use]
pub fn generate_token_string() -> String { utils::random_string(RESET_TOKEN_LENGTH) }
/// Issue a password reset token for `user`, who must be a local user with
/// the `password` origin.
pub async fn issue_token(&self, user_id: OwnedUserId) -> Result<ValidResetToken> {
if !self.services.globals.user_is_local(&user_id) {
return Err!("Cannot issue a password reset token for remote user {user_id}");
}
if user_id == self.services.globals.server_user {
return Err!("Cannot issue a password reset token for the server user");
}
if self.services.users.is_deactivated(&user_id).await? {
return Err!("Cannot issue a password reset token for deactivated user {user_id}");
}
if let Some((existing_token, _)) = self.db.find_token_for_user(&user_id).await {
self.db.remove_token(&existing_token);
}
let token = Self::generate_token_string();
let info = ResetTokenInfo {
user: user_id,
issued_at: SystemTime::now(),
};
self.db.save_token(&token, &info);
Ok(ValidResetToken { token, info })
}
/// Check if `token` represents a valid, non-expired password reset token.
pub async fn check_token(&self, token: &str) -> Option<ValidResetToken> {
self.db.lookup_token_info(token).await.and_then(|info| {
if info.is_valid() {
Some(ValidResetToken { token: token.to_owned(), info })
} else {
self.db.remove_token(token);
None
}
})
}
/// Consume the supplied valid token, using it to change its user's password
/// to `new_password`.
pub async fn consume_token(
&self,
ValidResetToken { token, info }: ValidResetToken,
new_password: &str,
) -> Result<()> {
if info.is_valid() {
self.db.remove_token(&token);
self.services
.users
.set_password(&info.user, Some(HashedPassword::new(new_password)?));
}
Ok(())
}
}
+1 -1
View File
@@ -100,7 +100,7 @@ pub async fn get_presence(&self, user_id: &UserId) -> Result<PresenceEvent> {
/// Pings the presence of the given user in the given room, setting the
/// specified state.
pub async fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> {
const REFRESH_TIMEOUT: u64 = 60 * 1000;
const REFRESH_TIMEOUT: u64 = 60 * 1000 * 4;
let last_presence = self.db.get_presence(user_id).await;
let state_changed = match last_presence {
+3 -2
View File
@@ -10,6 +10,7 @@
stream::{iter, once},
};
use ruma::OwnedUserId;
use serde::{Deserialize, Serialize};
use crate::{Dep, config, firstrun};
@@ -27,7 +28,7 @@ struct Services {
}
/// A validated registration token which may be used to create an account.
#[derive(Debug)]
#[derive(Debug, Deserialize, Serialize)]
pub struct ValidToken {
pub token: String,
pub source: ValidTokenSource,
@@ -44,7 +45,7 @@ fn eq(&self, other: &str) -> bool { self.token == other }
}
/// The source of a valid database token.
#[derive(Debug)]
#[derive(Debug, Deserialize, Serialize)]
pub enum ValidTokenSource {
/// The static token set in the homeserver's config file.
Config,
@@ -1,233 +1,667 @@
use std::{
collections::{BTreeMap, HashSet, VecDeque, hash_map},
collections::{HashMap, HashSet, VecDeque},
time::Instant,
};
use assign::assign;
#[cfg(debug_assertions)]
use conduwuit::error;
use conduwuit::{
Event, PduEvent, debug, debug_warn, implement, matrix::event::gen_event_id_canonical_json,
trace, utils::continue_exponential_backoff_secs, warn,
Err, Event, PduEvent, debug, debug_error, debug_info, debug_warn, err,
state_res::lexicographical_topological_sort,
trace,
utils::{IterStream, math::Expected, stream::BroadbandExt},
warn,
};
use futures::{StreamExt, future::select_ok};
use ruma::{
CanonicalJsonValue, EventId, OwnedEventId, RoomId, ServerName,
api::federation::event::get_event,
CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId,
OwnedServerName, RoomId, ServerName, UInt,
api::federation::event::{get_event, get_missing_events},
int,
room_version_rules::RoomVersionRules,
};
use super::get_room_version_rules;
use crate::rooms::event_handler::parse_incoming_pdu::expect_event_id_array;
/// Find the event and auth it. Once the event is validated (steps 1 - 8)
/// it is appended to the outliers Tree.
pub const GET_MISSING_EVENTS_MAX_BATCH_SIZE: usize = 50;
/// Attempts to build a localised directed acyclic graph out of the given PDUs,
/// returning them in a topologically sorted order.
///
/// Returns pdu and if we fetched it over federation the raw json.
///
/// a. Look in the main timeline (pduid_pdu tree)
/// b. Look at outlier pdu tree
/// c. Ask origin server over federation
/// d. TODO: Ask other servers over federation?
#[implement(super::Service)]
pub(super) async fn fetch_and_handle_outliers<'a, Pdu, Events>(
&self,
origin: &'a ServerName,
events: Events,
create_event: &'a Pdu,
room_id: &'a RoomId,
) -> Vec<(PduEvent, Option<BTreeMap<String, CanonicalJsonValue>>)>
where
Pdu: Event + Send + Sync,
Events: Iterator<Item = &'a EventId> + Clone + Send,
{
let back_off = |id| match self
.services
.globals
.bad_event_ratelimiter
.write()
.entry(id)
{
| hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
| hash_map::Entry::Occupied(mut e) => {
*e.get_mut() = (Instant::now(), e.get().1.saturating_add(1));
},
};
/// This is used to attempt to process PDUs in an order that respects their
/// dependencies, however it is ultimately the sender's responsibility to send
/// them in a processable order, so this is just a best effort attempt. It does
/// not account for power levels or other tie breaks.
pub async fn build_local_dag<S: std::hash::BuildHasher + Send + Sync>(
pdu_map: &HashMap<OwnedEventId, &CanonicalJsonObject, S>,
) -> conduwuit::Result<Vec<OwnedEventId>> {
debug_assert!(pdu_map.len() >= 2, "needless call to build_local_dag with less than 2 PDUs");
let mut dag: HashMap<OwnedEventId, HashSet<OwnedEventId>> =
HashMap::with_capacity(pdu_map.len());
let mut id_origin_ts: HashMap<OwnedEventId, _> = HashMap::with_capacity(pdu_map.len());
let mut events_with_auth_events = Vec::with_capacity(events.clone().count());
trace!("Fetching {} outlier pdus", events.clone().count());
for (event_id, value) in pdu_map {
// We already checked that these properties are correct in parse_incoming_pdu,
// so it's safe to unwrap here.
// We also filter to remove any prev_events that are not in this pdu_map, as we
// need to have at least one event with zero out degrees for the lexico-topo
// sort below. If there are multiple events with omitted prevs, they will be
// ordered by timestamp, then event ID. At that point though, it's unlikely to
// matter.
let prev_events = value
.get("prev_events")
.unwrap()
.as_array()
.unwrap()
.iter()
.map(|v| EventId::parse(v.as_str().unwrap()).unwrap())
.filter(|id| pdu_map.contains_key(id))
.collect();
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await {
trace!("Found {id} in main timeline or outlier tree");
events_with_auth_events.push((id.to_owned(), Some(local_pdu), vec![]));
continue;
}
dag.insert(event_id.clone(), prev_events);
let origin_server_ts = value
.get("origin_server_ts")
.and_then(CanonicalJsonValue::as_integer)
.unwrap_or_default();
id_origin_ts.insert(event_id.clone(), origin_server_ts);
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events: VecDeque<_> = [id.to_owned()].into();
let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len());
debug!(count = dag.len(), "Sorting incoming events with partial graph");
lexicographical_topological_sort(&dag, &async |node_id| {
// Note: we don't bother fetching power levels because that would massively slow
// this function down. This is a best-effort attempt to order events correctly
// for processing, however ultimately that should be the sender's job.
let ts = id_origin_ts
.get(&node_id)
.copied()
.unwrap_or_else(|| int!(0))
.to_string()
.parse::<u64>()
.ok()
.and_then(UInt::new)
.unwrap_or_default();
Ok((int!(0), MilliSecondsSinceUnixEpoch(ts)))
})
.await
.inspect(|sorted| {
debug_assert_eq!(
sorted.len(),
pdu_map.len(),
"Sorted graph was not the same size as the input graph"
);
})
.map_err(|e| err!("failed to resolve local graph: {e}"))
}
let mut events_all = HashSet::with_capacity(todo_auth_events.len());
while let Some(next_id) = todo_auth_events.pop_front() {
if let Some((time, tries)) = self
.services
.globals
.bad_event_ratelimiter
.read()
.get(&*next_id)
{
// Exponential backoff
const MIN_DURATION: u64 = 60 * 2;
const MAX_DURATION: u64 = 60 * 60 * 8;
if continue_exponential_backoff_secs(
MIN_DURATION,
MAX_DURATION,
time.elapsed(),
*tries,
) {
debug_warn!(
tried = ?*tries,
elapsed = ?time.elapsed(),
"Backing off from {next_id}",
);
continue;
}
}
impl super::Service {
/// Uses `POST /_matrix/federation/v1/get_missing_events/{room_id}` to fill
/// gaps in the DAG.
///
/// This function walks backwards from `head`, fetching incrementally (by a
/// factor of 10) more events until the remote we're fetching from either
/// stops returning new events, or the min_depth is reached.
///
/// This function does not persist the events, but does validate them. The
/// caller is responsible for passing them through handle_incoming_pdu or
/// related functions.
///
/// Only the one `via` is asked for missing events, as multiplexing remotes
/// may result in the event tree being walked in a gappy or disordered
/// manner.
///
/// ## Parameters
///
/// - `room_id`: The room's ID.
/// - `head`: The event we are potentially missing prev_events for.
/// - `tail`: The most recently known events in the graph (typically forward
/// extremities).
/// - `via`: The server to ask for missing events.
/// - `min_depth`: Don't process events with a `depth` lower than this
/// value. Not massively useful, but can help short-circuit infinite loops
/// and weird edge paths.
#[tracing::instrument(name = "get_missing_events_bulk", skip_all)]
pub async fn get_missing_events(
&self,
room_id: &RoomId,
head: &PduEvent,
tail: Vec<OwnedEventId>,
via: &ServerName,
min_depth: UInt,
) -> conduwuit::Result<HashMap<OwnedEventId, PduEvent>> {
let start = Instant::now();
#[cfg(debug_assertions)]
{
let missing_count = head
.prev_events()
.stream()
.fold(0_u8, |i, event_id| async move {
if self.services.timeline.pdu_exists(event_id).await {
i.expected_add(1)
} else {
i
}
})
.await;
debug_assert_ne!(
missing_count, 0,
"event passed to get_missing_events is not missing any events (wasteful call)"
);
};
assert!(!tail.is_empty(), "empty tail");
assert_ne!(via, self.services.globals.server_name(), "cannot ask ourselves for events");
if events_all.contains(&next_id) {
continue;
}
// The iteration limit is in place to ensure that if the remote server leaves us
// in a state of infinite recursion (as old versions of continuwuity and
// predecessors would), we give up. However, get_missing_events doesn't return
// that many events per-request. Synapse returns 20, and conduwuit+ return 50.
// This means with a hard iteration limit, we might give up too early, before
// we get a chance to even come close to max_fetch_prev_events. As such, we'll
// calculate the limit based on that config option and the aforementioned
// averages.
let max_fetch = self.services.server.config.max_fetch_prev_events;
let iteration_limit = max_fetch.saturating_div(20).max(10);
if self.services.timeline.pdu_exists(&next_id).await {
trace!("Found {next_id} in db");
continue;
}
debug!("Fetching {next_id} over federation from {origin}.");
match self
let mut discovered = HashMap::with_capacity(head.prev_events.len());
let mut latest_events: Vec<OwnedEventId> = vec![head.event_id().to_owned()];
debug!(elapsed=?start.elapsed(),
%room_id,
event_id=%head.event_id(),
%iteration_limit,
"Fetching any missing events for head event",
);
for iteration in 0..iteration_limit {
let limit = iteration
.expected_add(1)
.saturating_mul(10)
.min(GET_MISSING_EVENTS_MAX_BATCH_SIZE.try_into().expect(
"GET_MISSING_EVENTS_MAX_BATCH_SIZE (usize) should fit in u16 (<=65536)",
))
.max(
// This max call ensures we fetch *at least* all the prev events the
// head has.
u16::try_from(head.prev_events.len())
.expect("cannot have more than 20 prev events, which fits in u16"),
);
debug_info!(elapsed=?start.elapsed(),
%limit,
%via,
%iteration,
%iteration_limit,
discovered=discovered.len(),
%min_depth,
"Attempting to gap fill missing events"
);
let response: get_missing_events::v1::Response = self
.services
.sending
.send_federation_request(
origin,
get_event::v1::Request::new((*next_id).to_owned()),
via,
assign!(
get_missing_events::v1::Request::new(
room_id.to_owned(),
tail.clone(),
latest_events.clone()
),
{limit: limit.into(), min_depth}
),
)
.await
{
| Ok(res) => {
debug!("Got {next_id} over federation from {origin}");
let Ok(room_version_rules) = get_room_version_rules(create_event) else {
back_off((*next_id).to_owned());
continue;
};
.await?;
let Ok((calculated_event_id, value)) =
gen_event_id_canonical_json(&res.pdu, &room_version_rules)
else {
back_off((*next_id).to_owned());
continue;
};
if calculated_event_id != *next_id {
warn!(
"Server didn't return event id we requested: requested: {next_id}, \
we got {calculated_event_id}. Event: {:?}",
&res.pdu
);
}
if let Some(auth_events) = value
.get("auth_events")
.and_then(CanonicalJsonValue::as_array)
{
for auth_event in auth_events {
match serde_json::from_value::<OwnedEventId>(
auth_event.clone().into(),
) {
| Ok(auth_event) => {
trace!(
"Found auth event id {auth_event} for event {next_id}"
);
todo_auth_events.push_back(auth_event);
},
| _ => {
warn!("Auth event id is not valid");
},
}
}
} else {
warn!("Auth event list invalid");
}
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
},
| Err(e) => {
warn!("Failed to fetch auth event {next_id} from {origin}: {e}");
back_off((*next_id).to_owned());
},
if response.events.is_empty() {
debug_info!(
elapsed=?start.elapsed(),
%via,
"Finished gap filling missing events (remote returned no more events)."
);
break;
}
}
debug_info!(
elapsed=?start.elapsed(),
"Got {} events back from remote",
response.events.len()
);
events_with_auth_events.push((id.to_owned(), None, events_in_reverse_order));
}
let mut pdus = Vec::with_capacity(events_with_auth_events.len());
for (id, local_pdu, events_in_reverse_order) in events_with_auth_events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Some(local_pdu) = local_pdu {
trace!("Found {id} in main timeline or outlier tree");
pdus.push((local_pdu.clone(), None));
}
for (next_id, value) in events_in_reverse_order.into_iter().rev() {
if let Some((time, tries)) = self
.services
.globals
.bad_event_ratelimiter
.read()
.get(&*next_id)
{
// Exponential backoff
const MIN_DURATION: u64 = 5 * 60;
const MAX_DURATION: u64 = 60 * 60 * 24;
if continue_exponential_backoff_secs(
MIN_DURATION,
MAX_DURATION,
time.elapsed(),
*tries,
) {
debug!("Backing off from {next_id}");
latest_events.clear();
for raw_event in response.events {
let (_, event_id, pdu_json) = self.parse_incoming_pdu(&raw_event).await?;
let pdu = PduEvent::from_id_val(&event_id, pdu_json).map_err(|e| {
err!(Request(BadJson("Failed to parse gapfilled event {event_id}: {e}")))
})?;
if discovered.contains_key(&event_id) {
// We already received this event.
trace!("Already received {event_id}");
continue;
}
if self
.services
.timeline
.non_outlier_pdu_exists(&event_id)
.await
{
// NOTE: we explicitly check for *non*-outlier events here, as if we end
// up discovering outlier events, we will be able to upgrade them
// immediately.
trace!("Already have {event_id} as a timeline PDU");
continue;
}
if pdu.depth < min_depth {
debug_warn!(
elapsed=?start.elapsed(),
"Received PDU with depth {} below min_depth {}",
pdu.depth,
min_depth
);
discovered.insert(event_id.clone(), pdu);
continue;
}
for prev_event_id in pdu.prev_events() {
if discovered.contains_key(prev_event_id) {
// We already received this event.
trace!("Already received prev event {prev_event_id}");
continue;
}
if self
.services
.timeline
.non_outlier_pdu_exists(prev_event_id)
.await
{
// NOTE: we explicitly check for *non*-outlier events here, as if we end
// up discovering outlier events, we will be able to upgrade them
// immediately.
trace!("Already have prev event {prev_event_id} as a timeline PDU");
continue;
}
if let Ok(outlier) = self.services.timeline.get_pdu(prev_event_id).await {
// We already have this PDU as an outlier, don't ask for
// it. However, if we are missing any prev events for it, add it to the
// latest events anyway.
let outlier_missing_prevs = outlier
.prev_events()
.stream()
.fold(0_u8, |i, event_id| async move {
if self.services.timeline.pdu_exists(event_id).await {
i.expected_add(1)
} else {
i
}
})
.await;
if outlier_missing_prevs > 0 {
trace!("Missing {outlier_missing_prevs} PDU(s) for prev event");
latest_events.push(prev_event_id.to_owned());
}
trace!("Had {prev_event_id} as an outlier already, skipping discovery");
discovered.insert(prev_event_id.to_owned(), outlier);
continue;
}
trace!("Missing prev {prev_event_id} of {event_id}");
latest_events.push(prev_event_id.to_owned());
}
trace!("Discovered {event_id}");
discovered.insert(event_id.clone(), pdu);
}
trace!("Handling outlier {next_id}");
if latest_events.is_empty() {
debug!(elapsed=?start.elapsed(),
%limit,
%via,
%iteration,
discovered=discovered.len(),
"No more events to fetch."
);
break;
}
if discovered.len() >= self.services.server.config.max_fetch_prev_events.into() {
// Stupid hack, debug_error!() drops the log to a DEBUG when not in debug mode,
// which is bad because this should at least produce a warning. It's an error in
// debug mode because this can be important, but typically not much can be done
// about it as a user.
#[cfg(debug_assertions)]
error!(elapsed=?start.elapsed(),
discovered=discovered.len(),
max_fetch_prev_events=self.services.server.config.max_fetch_prev_events,
%iteration,
%iteration_limit,
%via,
event_id=%head.event_id(),
%room_id,
"Encountered a gap too large to fill, giving up"
);
#[cfg(not(debug_assertions))]
warn!(elapsed=?start.elapsed(),
discovered=discovered.len(),
max_fetch_prev_events=self.services.server.config.max_fetch_prev_events,
%iteration,
%iteration_limit,
%via,
event_id=%head.event_id(),
%room_id,
"Encountered a gap too large to fill"
);
break;
}
}
trace!(elapsed=?start.elapsed(), "Finished get_missing_events");
Ok(discovered)
}
/// Sends a `GET /_matrix/federation/v1/event/{event_id}` request to the
/// target `remote`, parses the resulting PDU, and ensures the remote
/// returned the correct event.
/// Allows `fetch_and_handle_missing_events` to atomically fetch events from
/// multiple remotes in parallel.
async fn fetch_event_via(
&self,
remote: OwnedServerName,
event_id: OwnedEventId,
room_version_rules: &RoomVersionRules,
) -> conduwuit::Result<(OwnedEventId, CanonicalJsonObject)> {
let res = self
.services
.sending
.send_federation_request(&remote, get_event::v1::Request::new(event_id.clone()))
.await?;
let (calculated_event_id, value) = self
.parse_incoming_pdu_with_known_room(&res.pdu, room_version_rules)
.await?;
if calculated_event_id != event_id {
Err!(Request(BadJson(warn!(
expected=%event_id,
received=%calculated_event_id,
"Server didn't return event id we requested",
))))
} else {
Ok((event_id, value))
}
}
async fn fetch_event_vias(
&self,
candidates: impl Iterator<Item = &OwnedServerName>,
event_id: &EventId,
room_version_rules: &RoomVersionRules,
) -> conduwuit::Result<(OwnedEventId, CanonicalJsonObject)> {
if let Ok(pdu_json) = self.services.timeline.get_pdu_json(event_id).await {
return Ok((event_id.to_owned(), pdu_json));
}
let futures = candidates
.map(|remote| {
Box::pin(self.fetch_event_via(
remote.to_owned(),
event_id.to_owned(),
room_version_rules,
))
})
.collect::<Vec<_>>();
select_ok(futures).await.map(|(res, _)| res)
}
/// Asks remote servers for any individual events that are missing, also
/// known as "atomic fetch". Should only be used for fetching missing auth
/// events or resolving missing events from state_ids. For all other uses,
/// use get_missing_events.
///
/// This function manually walks auth_events trees in a breadth-first
/// search, and persists all fetched events as outliers when all the
/// backwards extremities have been resolved.
#[tracing::instrument(name = "get_missing_auth_events_atomic", skip_all)]
pub(super) async fn fetch_and_handle_auth_events<Pdu>(
&self,
origin: &ServerName,
events: Vec<OwnedEventId>,
create_event: &Pdu,
room_id: &RoomId,
) -> HashMap<OwnedEventId, PduEvent>
where
Pdu: Event + Send + Sync,
{
let start = Instant::now();
let room_version_rules =
&get_room_version_rules(create_event).unwrap_or(RoomVersionRules::V1);
let mut candidates = self
.services
.timeline
.candidate_backfill_servers(room_id)
.await;
candidates.insert(origin.to_owned());
assert!(!candidates.is_empty(), "no candidates to fetch missing events from");
let mut discovered_events =
HashMap::with_capacity(events.len().saturating_add(events.len().saturating_mul(3)));
trace!(
elapsed=?start.elapsed(),
"Fetching {} unknown PDUs on demand from {} candidates",
events.len(),
candidates.len()
);
let mut seen: HashMap<OwnedEventId, u8> = HashMap::new();
for apex_event_id in &events {
let mut todo: VecDeque<OwnedEventId> = [apex_event_id.to_owned()].into();
while let Some(target_id) = todo.pop_front() {
if discovered_events.contains_key(&target_id) {
continue;
}
if let Ok(local_pdu) = self.services.timeline.get_pdu(&target_id).await {
trace!(elapsed=?start.elapsed(), "Found {target_id} in db");
let mut obj = local_pdu.into_canonical_object();
obj.remove("event_id");
discovered_events.insert(target_id.clone(), obj);
continue;
}
let attempts = seen.get(&*target_id).copied().unwrap_or_default();
if attempts >= 5 {
debug_error!(
elapsed=?start.elapsed(),
%attempts,
%target_id,
"Could not fetch missing event after 5 attempts, giving up"
);
continue;
}
debug!(elapsed=?start.elapsed(),"Fetching {target_id} over federation");
let value = match self
.fetch_event_vias(candidates.iter(), &target_id, room_version_rules)
.await
{
| Ok((_, x)) => x,
| Err(e) => {
warn!(elapsed=?start.elapsed(),"failed to fetch missing event {target_id} from any candidate: {e}");
continue;
},
};
let auth_events =
match expect_event_id_array(&value, "auth_events").map_err(|e| {
err!(Request(BadJson(warn!(
elapsed=?start.elapsed(),
event_id=%target_id,
"Failed to parse event fetched from remote: {e}"
))))
}) {
| Ok(auth_events) => auth_events,
| Err(e) => {
warn!(
elapsed=?start.elapsed(),
?e,
"event {target_id} is malformed (bad auth_events), skipping"
);
continue;
},
};
let mut have_all_auth = true;
for auth_event_id in auth_events {
if let Ok(local_pdu) = self.services.timeline.get_pdu(&auth_event_id).await {
trace!(elapsed=?start.elapsed(),"Found auth event {auth_event_id} in db");
let mut obj = local_pdu.into_canonical_object();
obj.remove("event_id");
discovered_events.insert(auth_event_id.clone(), obj);
continue;
}
if discovered_events.contains_key(&auth_event_id) {
trace!(elapsed=?start.elapsed(),%auth_event_id, "Already found auth event");
continue;
}
debug!(elapsed=?start.elapsed(),"Missing auth event {auth_event_id} for event {target_id}");
seen.insert(
auth_event_id.clone(),
seen.get(&auth_event_id)
.copied()
.unwrap_or_default()
.saturating_add(1),
);
todo.push_back(auth_event_id);
have_all_auth = false;
}
// Insert this PDU back at the end of the queue so that it will be resolved once
// all of its auth events have been fetched.
if have_all_auth {
debug!(elapsed=?start.elapsed(),%target_id, "Have all auth events");
discovered_events.insert(target_id, value);
} else {
debug_warn!(elapsed=?start.elapsed(),
"Fetched {target_id} but missing some auth events, will have to re-fetch."
);
seen.insert(target_id.clone(), attempts.saturating_add(1));
todo.push_back(target_id);
}
}
}
let refmap: HashMap<OwnedEventId, &CanonicalJsonObject> = discovered_events
.iter()
.map(|(id, data)| (id.clone(), data))
.collect();
let seeded_ordered = build_local_dag(&refmap)
.await
.expect("failed to build local DAG");
let mut pdus = HashMap::with_capacity(seeded_ordered.len());
for discovered_event_id in seeded_ordered {
let pdu_json = discovered_events.remove(&discovered_event_id).unwrap();
debug_info!(
elapsed=?start.elapsed(),
"Handling missing event {discovered_event_id} as outlier"
);
assert_eq!(pdu_json.get("event_id"), None, "pdu_json had event_id");
match Box::pin(self.handle_outlier_pdu(
origin,
create_event,
&next_id,
&discovered_event_id,
room_id,
value.clone(),
true,
pdu_json,
))
.await
{
| Ok((pdu, json)) =>
if next_id == *id {
trace!("Handled outlier {next_id} (original request)");
pdus.push((pdu, Some(json)));
},
| Err(e) => {
warn!("Authentication of event {next_id} failed: {e:?}");
back_off(next_id);
| Ok((pdu, _)) => {
trace!(elapsed=?start.elapsed(), "Persisted {discovered_event_id}");
let _ = pdus.insert(discovered_event_id, pdu);
},
| Err(e) => warn!(
elapsed=?start.elapsed(),
"Authentication of event {discovered_event_id} failed: {e:?}"
),
}
}
trace!(
elapsed=?start.elapsed(),
"Finished fetch_and_handle_missing_events: fetched and handled {} missing PDUs",
pdus.len()
);
pdus.retain(|id, _| events.contains(id)); // Only return state events
trace!(elapsed=?start.elapsed(), "Filtered return value down to {} PDUs", pdus.len());
pdus
}
/// Similar to `fetch_and_handle_missing_events`, but simply walks the
/// prev events tree instead of the auth events tree. Additionally, it does
/// not *handle* fetched PDUs in any capacity.
#[tracing::instrument(name = "get_missing_prev_events_atomic", skip_all)]
pub(super) async fn fetch_prev_events<Pdu>(
&self,
origin: &ServerName,
events: Vec<OwnedEventId>,
create_event: &Pdu,
room_id: &RoomId,
) -> HashMap<OwnedEventId, PduEvent>
where
Pdu: Event + Send + Sync,
{
let room_version_rules =
&get_room_version_rules(create_event).unwrap_or(RoomVersionRules::V1);
let mut candidates = self
.services
.timeline
.candidate_backfill_servers(room_id)
.await;
candidates.insert(origin.to_owned());
let mut todo: VecDeque<OwnedEventId> = VecDeque::from(events);
let mut discovered_events = HashMap::new();
while let Some(next_id) = todo.pop_front() {
if discovered_events.len() >= self.services.server.config.max_fetch_prev_events.into()
{
debug_warn!(
"Encountered a gap too large to fill, giving up (fetched {} events)",
discovered_events.len()
);
break;
}
if discovered_events.contains_key(&next_id) {
continue;
}
let pdu = match self
.fetch_event_vias(candidates.iter(), &next_id, room_version_rules)
.await
{
| Ok((_, data)) => data,
| Err(e) => {
warn!("Failed to fetch prev event {next_id} from any candidate: {e}");
continue;
},
};
let prev_events = match expect_event_id_array(&pdu, "prev_events").map_err(|e| {
err!(Request(BadJson(warn!(
event_id=%next_id,
"Failed to parse event fetched from remote: {e}"
))))
}) {
| Ok(auth_events) => auth_events,
| Err(e) => {
warn!(?e, "event {next_id} is malformed (bad prev_events), skipping");
continue;
},
};
let missing_prev = prev_events
.iter()
.stream()
.broad_filter_map(|event_id| async {
if discovered_events.contains_key(event_id)
|| self.services.timeline.pdu_exists(event_id).await
{
None
} else {
Some(event_id.to_owned())
}
})
.collect::<Vec<_>>()
.await;
todo.extend(missing_prev);
discovered_events.insert(
next_id.clone(),
PduEvent::from_id_val(&next_id, pdu).expect("fetched PDU was already validated"),
);
}
discovered_events
}
trace!("Fetched and handled {} outlier pdus", pdus.len());
pdus
}
+142 -120
View File
@@ -1,128 +1,150 @@
use std::{
collections::{BTreeMap, HashMap, HashSet, VecDeque},
iter::once,
};
use std::{collections::HashMap, time::Instant};
use conduwuit::{
Event, PduEvent, Result, debug_warn, err, implement,
state_res::{self},
};
use futures::{FutureExt, future};
use ruma::{
CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName,
int, uint,
Event, PduEvent, debug, debug_info, info, trace,
utils::{BoolExt, IterStream, stream::BroadbandExt},
warn,
};
use futures::StreamExt;
use ruma::{CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName};
use super::check_room_id;
use crate::rooms::event_handler::build_local_dag;
#[implement(super::Service)]
#[tracing::instrument(
level = "debug",
skip_all,
fields(%origin),
)]
#[allow(clippy::type_complexity)]
pub(super) async fn fetch_prev<'a, Pdu, Events>(
&self,
origin: &ServerName,
create_event: &Pdu,
room_id: &RoomId,
first_ts_in_room: MilliSecondsSinceUnixEpoch,
initial_set: Events,
) -> Result<(
Vec<OwnedEventId>,
HashMap<OwnedEventId, (PduEvent, BTreeMap<String, CanonicalJsonValue>)>,
)>
where
Pdu: Event + Send + Sync,
Events: Iterator<Item = &'a EventId> + Clone + Send,
{
let num_ids = initial_set.clone().count();
let mut eventid_info = HashMap::new();
let mut graph: HashMap<OwnedEventId, _> = HashMap::with_capacity(num_ids);
let mut todo_outlier_stack: VecDeque<OwnedEventId> =
initial_set.map(ToOwned::to_owned).collect();
let mut amount = 0;
while let Some(prev_event_id) = todo_outlier_stack.pop_front() {
self.services.server.check_running()?;
match self
.fetch_and_handle_outliers(
origin,
once(prev_event_id.as_ref()),
create_event,
room_id,
)
.boxed()
.await
.pop()
{
| Some((pdu, mut json_opt)) => {
check_room_id(room_id, &pdu)?;
let limit = self.services.server.config.max_fetch_prev_events;
if amount > limit {
debug_warn!("Max prev event limit reached! Limit: {limit}");
graph.insert(prev_event_id.clone(), HashSet::new());
continue;
}
if json_opt.is_none() {
json_opt = self
.services
.outlier
.get_outlier_pdu_json(&prev_event_id)
.await
.ok();
}
if let Some(json) = json_opt {
if pdu.origin_server_ts() > first_ts_in_room {
amount = amount.saturating_add(1);
for prev_prev in pdu.prev_events() {
if !graph.contains_key(prev_prev) {
todo_outlier_stack.push_back(prev_prev.to_owned());
}
}
graph.insert(
prev_event_id.clone(),
pdu.prev_events().map(ToOwned::to_owned).collect(),
);
} else {
// Time based check failed
graph.insert(prev_event_id.clone(), HashSet::new());
}
eventid_info.insert(prev_event_id.clone(), (pdu, json));
} else {
// Get json failed, so this was not fetched over federation
graph.insert(prev_event_id.clone(), HashSet::new());
}
},
| _ => {
// Fetch and handle failed
graph.insert(prev_event_id.clone(), HashSet::new());
},
impl super::Service {
/// Fetches any missing prev_events for this event and persists them before
/// returning.
pub(super) async fn fetch_prevs(
&self,
room_id: &RoomId,
create_event: &PduEvent,
incoming_pdu: &PduEvent,
origin: &ServerName,
first_ts_in_room: MilliSecondsSinceUnixEpoch,
) -> conduwuit::Result<()> {
let start = Instant::now();
let mut missing = incoming_pdu
.prev_events()
.stream()
.broad_filter_map(|event_id| async move {
self.services
.timeline
.get_non_outlier_pdu_json(event_id)
.await
.is_ok()
.or(|| event_id.to_owned())
})
.collect::<Vec<_>>()
.await;
if missing.is_empty() {
debug!(elapsed=?start.elapsed(), event_id=%incoming_pdu.event_id(), "No missing prev events.");
return Ok(());
}
debug!(elapsed=?start.elapsed(), %room_id, event_id=%incoming_pdu.event_id(), ?missing, "Fetching previous events");
let tail = self
.services
.state
.get_forward_extremities(room_id)
.collect::<Vec<_>>()
.await;
let mut gapfilled = self
.get_missing_events(
room_id,
incoming_pdu,
tail,
origin,
self.services
.metadata
.get_mindepth(room_id)
.await
.saturating_sub(
u8::try_from(incoming_pdu.prev_events.len())
.unwrap()
.saturating_mul(2)
.into(),
),
)
.await?;
debug_info!(elapsed=?start.elapsed(), "Fetched {} missing events", gapfilled.len());
missing.retain(|eid| !gapfilled.contains_key(eid));
if !missing.is_empty() {
warn!(elapsed=?start.elapsed(), "Still missing {} events, falling back to atomic fetch.", missing.len());
gapfilled.extend(
self.fetch_prev_events(origin, missing, create_event, room_id)
.await,
);
}
// Persist all fetched events
let mapped = gapfilled
.iter()
.map(|(eid, evt)| {
let mut obj = evt.to_canonical_object();
obj.remove("event_id"); // event_id is inserted by backfill_missing_events
(eid.clone(), obj)
})
.collect::<HashMap<_, _>>();
let to_persist = if mapped.len() <= 1 {
mapped.keys().map(ToOwned::to_owned).collect()
} else {
let refmap: HashMap<OwnedEventId, &CanonicalJsonObject> =
mapped.iter().map(|(id, data)| (id.clone(), data)).collect();
build_local_dag(&refmap).await?
};
let job_start = Instant::now();
trace!("Starting to persist {} prev events", to_persist.len());
for (i, event_id) in to_persist.iter().enumerate() {
info!(
elapsed=?start.elapsed(),
"Persisting fetched prev event: {event_id} ({}/{})",
i.saturating_add(1),
to_persist.len(),
);
let obj = mapped.get(event_id).cloned().unwrap();
let persist_start = Instant::now();
match self
.handle_outlier_pdu(origin, create_event, event_id, room_id, obj)
.await
{
| Ok((pdu, val)) if pdu.origin_server_ts() >= first_ts_in_room => {
self.upgrade_outlier_to_timeline_pdu(pdu, val, create_event, origin, room_id)
.await
.inspect_err(|e| {
warn!(
total_elapsed=?start.elapsed(),
job_elapsed=?job_start.elapsed(),
task_elapsed=?persist_start.elapsed(),
"Failed to upgrade prev event {event_id}: {e}",
);
})
.inspect(|_| {
info!(
total_elapsed=?start.elapsed(),
job_elapsed=?job_start.elapsed(),
task_elapsed=?persist_start.elapsed(),
"Upgraded prev event {event_id}",
);
})
.ok();
},
| Err(e) => warn!(
total_elapsed=?start.elapsed(),
job_elapsed=?job_start.elapsed(),
task_elapsed=?persist_start.elapsed(),
"Failed to persist prev event {event_id}: {e}",
),
| _ => {},
}
}
// NOTE because i keep forgetting: the caller persists incoming_pdu.
// we only care about its prev events
trace!(
total_elapsed=?start.elapsed(),
persist_elapsed=?job_start.elapsed(),
);
Ok(())
}
let event_fetch = |event_id| {
let origin_server_ts = eventid_info
.get(&event_id)
.map_or_else(|| uint!(0), |info| info.0.origin_server_ts().get());
// This return value is the key used for sorting events,
// events are then sorted by power level, time,
// and lexically by event_id.
future::ok((int!(0), MilliSecondsSinceUnixEpoch(origin_server_ts)))
};
let sorted = state_res::lexicographical_topological_sort(&graph, &event_fetch)
.await
.map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?;
Ok((sorted, eventid_info))
}
+379 -73
View File
@@ -1,86 +1,392 @@
use std::collections::{HashMap, hash_map};
use conduwuit::{Err, Event, Result, debug, debug_warn, err, implement};
use futures::FutureExt;
use ruma::{
EventId, OwnedEventId, RoomId, ServerName, api::federation::event::get_room_state_ids,
events::StateEventType,
use std::{
cmp::max,
collections::{HashMap, HashSet, hash_map},
hash::{BuildHasherDefault, DefaultHasher},
time::{Duration, Instant},
};
use crate::rooms::short::ShortStateKey;
use conduwuit::{
Err, Event, PduEvent, Result, debug, debug_warn, err, info, trace,
utils::{BoolExt, IterStream},
warn,
};
use futures::{StreamExt, TryFutureExt, future::select_ok};
use ruma::{
EventId, OwnedEventId, OwnedRoomId, RoomId, ServerName,
api::federation::event::{get_room_state, get_room_state_ids},
};
/// Call /state_ids to find out what the state at this pdu is. We trust the
/// server's response to some extend (sic), but we still do a lot of checks
/// on the events
#[implement(super::Service)]
#[tracing::instrument(
level = "debug",
skip_all,
fields(%origin),
)]
pub(super) async fn fetch_state<Pdu>(
&self,
origin: &ServerName,
create_event: &Pdu,
room_id: &RoomId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, OwnedEventId>>>
where
Pdu: Event + Send + Sync,
{
let res = self
.services
.sending
.send_federation_request(
origin,
get_room_state_ids::v1::Request::new(event_id.to_owned(), room_id.to_owned()),
)
.await
.inspect_err(|e| debug_warn!("Fetching state for event failed: {e}"))?;
use crate::{conduwuit::utils::stream::BroadbandExt, rooms::short::ShortStateKey};
debug!("Fetching state events");
let state_ids = res.pdu_ids.iter().map(AsRef::as_ref);
let state_vec = self
.fetch_and_handle_outliers(origin, state_ids, create_event, room_id)
.boxed()
.await;
let mut state: HashMap<ShortStateKey, OwnedEventId> = HashMap::with_capacity(state_vec.len());
for (pdu, _) in state_vec {
let state_key = pdu
.state_key()
.ok_or_else(|| err!(Database("Found non-state pdu in state events.")))?;
let shortstatekey = self
impl super::Service {
/// Asks a remote server what the state at this event is.
/// It first attempts to call `GET /_matrix/federation/v1/state_ids` (fast).
/// If any events are missing, they are fetched from the remote, and
/// persisted as outliers, before being returned back to this function. If
/// we are missing a lot of events locally (>=50), this function falls back
/// to requesting the full state in PDU format from the remote (`GET
/// /_matrix/federation/v1/state, very slow in large rooms), and persists
/// them directly.
#[tracing::instrument(skip_all)]
pub(super) async fn fetch_state(
&self,
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashMap<u64, OwnedEventId>> {
let start = Instant::now();
trace!(%origin, "Asking remote for state_ids");
let res: get_room_state_ids::v1::Response = match self
.services
.short
.get_or_create_shortstatekey(&pdu.kind().to_string().into(), state_key)
.sending
.send_federation_request(
origin,
get_room_state_ids::v1::Request::new(event_id.to_owned(), room_id.to_owned()),
)
.await
.inspect_err(
|e| debug_warn!(elapsed=?start.elapsed(), "Fetching state for event failed: {e}"),
) {
| Ok(resp) => Ok(resp),
| Err(e) =>
if e.is_not_found() {
self.fetch_state_ids_from_backfill_servers(
event_id.to_owned(),
room_id.to_owned(),
)
.await
} else {
Err(e)
},
}?;
debug!(elapsed=?start.elapsed(), events = res.pdu_ids.len(), "Fetching state events");
let mut state_events: HashMap<OwnedEventId, PduEvent> =
HashMap::with_capacity(res.pdu_ids.len());
let to_fetch: Vec<OwnedEventId> = res
.pdu_ids
.clone()
.into_iter()
.stream()
.broad_filter_map(|event_id| async move {
self.services
.timeline
.pdu_exists(&event_id)
.await
.or_some(event_id)
})
.collect()
.await;
if to_fetch.is_empty() {
debug!(elapsed=?start.elapsed(), "All required state events are already known.");
state_events = res
.pdu_ids
.iter()
.stream()
.broad_filter_map(|event_id| async move {
Some((
event_id.clone(),
self.services
.timeline
.get_pdu(event_id)
.await
.expect("Event disappeared between filtering and fetching"),
))
})
.collect()
.await;
assert_eq!(
state_events.len(),
res.pdu_ids.len(),
"Failed to load all required state events despite allegedly knowing all of them \
already",
);
} else {
let total_count = res.pdu_ids.len();
let missing_count = to_fetch.len();
let missing_threshold = max(50, total_count >> 2);
if missing_count >= missing_threshold {
// If there's more than 50 events to fetch, or we're missing 25% or more of the
// state, we would need to make a lot of atomic requests, so we'll just try
// to fetch the full state from the remote instead.
// Since this endpoint might fail in huge rooms, we fall back to atomic fetch
// anyway.
warn!(
elapsed=?start.elapsed(),
%missing_count,
%total_count,
%missing_threshold,
"Fetching full state from remote server for event"
);
let state_response = tokio::time::timeout(
Duration::from_secs(30),
self.fetch_full_state(origin, create_event, room_id, event_id),
)
.await;
info!(
elapsed=?start.elapsed(),
%missing_count,
%total_count,
%missing_threshold,
"Fetched full state from remote server for event"
);
let fetched_state = match state_response {
| Ok(Ok(state)) => {
// Filter to ensure we only use the PDUs we were expecting, preventing
// arbitrary state injection.
// Atomic fetch does not have this problem as each PDU is evaluated
// individually.
let expected: &HashSet<OwnedEventId, BuildHasherDefault<DefaultHasher>> =
&HashSet::from_iter(res.pdu_ids.clone());
state
.into_iter()
.stream()
.broad_filter_map(|(event_id, pdu)| async move {
expected.contains(&event_id).then_some((event_id, pdu))
})
.collect()
.await
},
| Ok(Err(e)) => {
warn!(
elapsed=?start.elapsed(),
error=?e,
%origin,
"Failed to fetch full state from remote, falling back to atomic fetch"
);
self.fetch_and_handle_auth_events(
origin,
res.pdu_ids.clone(),
create_event,
room_id,
)
.await
},
| Err(e) => {
warn!(
elapsed=?start.elapsed(),
error=?e,
%origin,
"Remote did not return room state in an acceptable timeframe, falling back to atomic fetch"
);
self.fetch_and_handle_auth_events(
origin,
res.pdu_ids.clone(),
create_event,
room_id,
)
.await
},
};
match state.entry(shortstatekey) {
| hash_map::Entry::Vacant(v) => {
v.insert(pdu.event_id().to_owned());
},
| hash_map::Entry::Occupied(_) => {
return Err!(Database(
"State event's type and state_key combination exists multiple times: {}, {}",
pdu.kind(),
state_key
));
},
assert!(
!fetched_state.is_empty(),
"fetch_full_state or fetch_and_handle_missing_events returned empty state \
map"
);
state_events.extend(fetched_state);
} else {
state_events = res
.pdu_ids
.iter()
.stream()
.broad_filter_map(|event_id| async move {
self.services
.timeline
.get_pdu(event_id)
.await
.map(|p| (event_id.to_owned(), p))
.ok()
})
.collect()
.await;
assert!(
!state_events.is_empty(),
"Only missing {} events but read-ahead state vec was empty",
to_fetch.len()
);
debug!(
elapsed=?start.elapsed(),
to_fetch = to_fetch.len(),
"Fetching missing events for state from remote"
);
let fetched_state = self
.fetch_and_handle_auth_events(origin, to_fetch, create_event, room_id)
.await;
state_events.extend(fetched_state);
}
}
if state_events.is_empty() {
return Ok(HashMap::new());
}
let mut state: HashMap<ShortStateKey, OwnedEventId> =
HashMap::with_capacity(state_events.len());
debug!(elapsed=?start.elapsed(), events = state_events.len(), "Processing state events");
for (event_id, pdu) in state_events {
let state_key = pdu.state_key().ok_or_else(|| {
err!(Request(BadJson("Found non-state pdu in state events: {event_id}")))
})?;
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&pdu.kind().to_string().into(), state_key)
.await;
match state.entry(shortstatekey) {
| hash_map::Entry::Vacant(v) => {
v.insert(pdu.event_id().to_owned());
},
| hash_map::Entry::Occupied(existing) => {
return Err!(Request(Forbidden(
"State event's type and state_key combination exists multiple times \
({event_id} + {}): ({}, \"{}\")",
existing.get(),
pdu.kind(),
state_key,
)));
},
}
}
trace!(elapsed=?start.elapsed(), "fetch_state finished");
Ok(state)
}
// The original create event must still be in the state
let create_shortstatekey = self
.services
.short
.get_shortstatekey(&StateEventType::RoomCreate, "")
.await?;
if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(create_event.event_id()) {
return Err!(Database("Incoming event refers to wrong create event."));
async fn fetch_state_ids_from_backfill_servers(
&self,
event_id: OwnedEventId,
room_id: OwnedRoomId,
) -> Result<get_room_state_ids::v1::Response> {
let candidates = self
.services
.timeline
.candidate_backfill_servers(&room_id)
.await;
if candidates.is_empty() {
return Err!(Request(NotFound(
"Cannot ask any other servers for the state at this event"
)));
}
debug!(%room_id, ?candidates, "Asking backfill servers for state_ids");
let futures = candidates.iter().map(|server_name| {
Box::pin(
self.services
.sending
.send_federation_request(
server_name,
get_room_state_ids::v1::Request::new(event_id.clone(), room_id.clone()),
)
.inspect_err(|e| {
debug_warn!("Fallback fetching state for event failed: {e}");
}),
)
});
Ok(select_ok(futures).await?.0)
}
Ok(Some(state))
/// Fetches the full state via `GET /_matrix/federation/v1/state` from a
/// remote server, and persists all the incoming auth chain events and
/// state events as outliers, for use later.
///
/// Any events that cannot be persisted are dropped with a warning.
pub(super) async fn fetch_full_state(
&self,
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashMap<OwnedEventId, PduEvent>> {
let start = Instant::now();
trace!("Fetching full state from remote server");
let res: get_room_state::v1::Response = self
.services
.sending
.send_federation_request(
origin,
get_room_state::v1::Request::new(event_id.to_owned(), room_id.to_owned()),
)
.await
.inspect_err(|e| debug_warn!("Fetching state for event failed: {e}"))?;
debug!(elapsed=?start.elapsed(), count = res.auth_chain.len(), "Handling incoming auth chain...");
res.auth_chain
.iter()
.stream()
.broad_filter_map(|raw_event_json| async {
if let Some(parsed) = self.parse_incoming_pdu(raw_event_json).await.ok()
&& parsed.0 == room_id
{
Some(parsed)
} else {
None
}
})
.for_each_concurrent(
None,
|(incoming_room_id, incoming_event_id, incoming_event_json)| async move {
self.handle_outlier_pdu(
origin,
create_event,
&incoming_event_id,
&incoming_room_id,
incoming_event_json,
)
.await
.inspect_err(|e| {
warn!(
%incoming_room_id,
%incoming_event_id,
?e,
"Failed to handle auth chain event from state fetch"
);
})
.ok();
},
)
.await;
debug!(elapsed=?start.elapsed(), count = res.pdus.len(), "Handling incoming state PDUs...");
let r = res
.pdus
.iter()
.stream()
.broad_filter_map(|raw_event_json| async {
if let Some(parsed) = self.parse_incoming_pdu(raw_event_json).await.ok()
&& parsed.0 == room_id
{
Some(parsed)
} else {
None
}
})
.broad_filter_map(
|(incoming_room_id, incoming_event_id, incoming_event_json)| async move {
self.handle_outlier_pdu(
origin,
create_event,
&incoming_event_id,
&incoming_room_id,
incoming_event_json,
)
.await
.inspect_err(|e| {
warn!(
elapsed=?start.elapsed(),
%incoming_room_id,
%incoming_event_id,
?e,
"Failed to handle state event from state fetch"
);
})
.ok()
},
)
.fold(HashMap::new(), |mut acc, (event, _)| async move {
acc.insert(event.event_id().to_owned(), event);
acc
})
.await;
trace!(elapsed=?start.elapsed(), "fetch_full_state finished");
Ok(r)
}
}
@@ -1,14 +1,14 @@
use std::{
collections::{BTreeMap, hash_map},
time::Instant,
collections::BTreeMap,
time::{Duration, Instant},
};
use conduwuit::{
Err, Event, PduEvent, Result, debug::INFO_SPAN_LEVEL, debug_error, debug_info, defer, err,
implement, info, trace, utils::stream::IterStream, warn,
Err, Event, PduEvent, Result, debug, debug_error, debug_info, debug_warn, defer, err, error,
implement, info, matrix::PartialPdu, result::DebugInspect, trace, warn,
};
use futures::{
FutureExt, TryFutureExt, TryStreamExt,
FutureExt, StreamExt,
future::{OptionFuture, try_join4},
};
use ruma::{
@@ -18,7 +18,6 @@
room::member::{MembershipState, RoomMemberEventContent},
},
};
use tracing::debug;
use crate::rooms::timeline::{RawPduId, pdu_fits};
@@ -111,7 +110,6 @@ async fn should_rescind_invite(
#[implement(super::Service)]
#[tracing::instrument(
name = "pdu",
level = INFO_SPAN_LEVEL,
skip_all,
fields(%room_id, %event_id),
)]
@@ -151,7 +149,7 @@ pub async fn handle_incoming_pdu<'a>(
.and_then(|v| v.as_str())
.ok_or_else(|| err!("No sender in object"))
.and_then(|v| Ok(UserId::parse(v)?))
.map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?;
.map_err(|e| err!(Request(BadJson("PDU does not have a valid sender key: {e}"))))?;
let sender_acl_check: OptionFuture<_> = sender
.server_name()
@@ -224,75 +222,107 @@ pub async fn handle_incoming_pdu<'a>(
self.federation_handletime
.write()
.remove(room_id);
}};
}}
let (incoming_pdu, val) = self
.handle_outlier_pdu(origin, create_event, event_id, room_id, value, false)
.await?;
.handle_outlier_pdu(origin, create_event, event_id, room_id, value)
.await
.inspect_err(|e| error!("Failed to handle outlier PDU: {e:?}"))?;
// 8. if not timeline event: stop
if !is_timeline_event {
return Ok(None);
}
// Skip old events
// Skip events sent before we joined (they need to be persisted as backfilled
// events, not timeline events, which is handled elsewhere).
let first_ts_in_room = self
.services
.timeline
.first_pdu_in_room(room_id)
.await?
.origin_server_ts();
if incoming_pdu.origin_server_ts() < first_ts_in_room {
return Ok(None);
}
// 9. Fetch any missing prev events doing all checks listed here starting at 1.
// These are timeline events
let (sorted_prev_events, mut eventid_info) = self
.fetch_prev(origin, create_event, room_id, first_ts_in_room, incoming_pdu.prev_events())
.await?;
debug!(
events = ?sorted_prev_events,
"Handling previous events"
);
sorted_prev_events
.iter()
.try_stream()
.map_ok(AsRef::as_ref)
.try_for_each(|prev_id| {
self.handle_prev_pdu(
origin,
event_id,
room_id,
eventid_info.remove(prev_id),
create_event,
first_ts_in_room,
prev_id,
)
.inspect_err(move |e| {
warn!("Prev {prev_id} failed: {e}");
match self
.services
.globals
.bad_event_ratelimiter
.write()
.entry(prev_id.into())
{
| hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
| hash_map::Entry::Occupied(mut e) => {
let tries = e.get().1.saturating_add(1);
*e.get_mut() = (Instant::now(), tries);
},
}
})
.map(|_| self.services.server.check_running())
})
.boxed()
.await?;
debug!("Fetching and persisting any missing prev events");
self.fetch_prevs(room_id, create_event, &incoming_pdu, origin, first_ts_in_room)
.await
.debug_inspect_err(|e| {
error!("Failed to fetch and persist incoming event's prev_events: {e:?}");
})?;
// Done with prev events, now handling the incoming event
self.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, create_event, origin, room_id)
.boxed()
.await
let pdu_id = self
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, create_event, origin, room_id)
.await?;
let extremities_count = self
.services
.state
.get_forward_extremities(room_id)
.count()
.await;
if extremities_count >= self.services.server.config.dummy_event_threshold.into() {
self.squash_extremities(room_id, extremities_count).await;
}
Ok(pdu_id)
}
#[implement(super::Service)]
async fn squash_extremities(&self, room_id: &RoomId, count: usize) {
let last_squash = {
let squash_timings = self.last_extremity_squash.read();
squash_timings.get(room_id).copied()
};
if last_squash.is_some_and(|s| s.elapsed() < Duration::from_mins(1)) {
// Avoid sending more than one squash per minute to avoid flooding rooms.
return;
}
debug_warn!(
%count,
threshold=%self.services.server.config.dummy_event_threshold,
"Attempting to squash extremities after upgrading pdu"
);
// Try to send a dummy event to squash extremities. See issue #1844
let power_levels = self
.services
.state_accessor
.get_room_power_levels(room_id)
.await;
let mut local_users = self.services.state_cache.local_users_in_room(room_id);
while let Some(user_id) = local_users.next().await {
if !power_levels.user_can_send_message(&user_id, "org.matrix.dummy_event".into()) {
trace!(%user_id, "user does not have power level to send dummy event, skipping");
continue;
}
let state_lock = self.services.state.mutex.lock(room_id).await;
if self
.services
.timeline
.build_and_append_pdu(
PartialPdu {
event_type: "org.matrix.dummy_event".into(),
..PartialPdu::default()
},
&user_id,
Some(room_id),
&state_lock,
)
.await
.inspect(|_| debug!(sender=%user_id, "Successfully sent a dummy event"))
.inspect_err(|e| debug!(sender=%user_id, ?e, "Failed to send a dummy event via user"))
.is_ok()
{
break;
}
}
let mut squash_timings = self.last_extremity_squash.write();
squash_timings.insert(room_id.to_owned(), Instant::now());
}
@@ -1,12 +1,13 @@
use std::collections::{BTreeMap, HashMap, hash_map};
use conduwuit::{
Err, Event, PduEvent, Result, debug, debug_info, debug_warn, err, implement, state_res,
Err, Event, PduEvent, Result, debug, debug_info, debug_warn, err, implement, info, state_res,
trace, warn,
};
use futures::future::ready;
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, RoomId, ServerName,
api::federation::authorization::get_event_authorization, canonical_json::redact,
events::StateEventType,
};
@@ -15,6 +16,7 @@
#[implement(super::Service)]
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(name="handle_outlier", skip_all, fields(%event_id))]
pub(super) async fn handle_outlier_pdu<'a, Pdu>(
&self,
origin: &'a ServerName,
@@ -22,7 +24,6 @@ pub(super) async fn handle_outlier_pdu<'a, Pdu>(
event_id: &'a EventId,
room_id: &'a RoomId,
mut value: CanonicalJsonObject,
auth_events_known: bool,
) -> Result<(PduEvent, BTreeMap<String, CanonicalJsonValue>)>
where
Pdu: Event + Send + Sync,
@@ -46,27 +47,38 @@ pub(super) async fn handle_outlier_pdu<'a, Pdu>(
.verify_event(&value, &room_version_rules)
.await
{
| Ok(ruma::signatures::Verified::All) => value,
| Ok(ruma::signatures::Verified::All) => {
if let Ok(pdu_event) = self.services.timeline.get_pdu(event_id).await {
debug!(
"Already have event {event_id} as an outlier or timeline event, not \
re-processing"
);
value.insert(
"event_id".to_owned(),
CanonicalJsonValue::String(event_id.as_str().to_owned()),
);
check_room_id(room_id, &pdu_event)?;
return Ok((pdu_event, value));
}
value
},
| Ok(ruma::signatures::Verified::Signatures) => {
// Redact
debug_info!("Calculated hash does not match (redaction): {event_id}");
let Ok(obj) =
ruma::canonical_json::redact(value, &room_version_rules.redaction, None)
else {
return Err!(Request(InvalidParam("Redaction failed")));
};
// Skip the PDU if it is redacted and we already have it as an outlier event
if self.services.timeline.pdu_exists(event_id).await {
return Err!(Request(InvalidParam(
"Event was redacted and we already knew about it"
)));
if let Ok(pdu_event) = self.services.timeline.get_pdu(event_id).await {
debug!(
"Received a redacted copy of {event_id}, but we already knew about it. \
Re-using known content instead."
);
check_room_id(room_id, &pdu_event)?;
let obj = pdu_event.to_canonical_object();
return Ok((pdu_event, obj));
}
obj
debug_info!("Calculated hash does not match (redaction): {event_id}");
redact(value, &room_version_rules.redaction, None)
.map_err(|e| err!(Request(BadJson("Failed to redact {event_id}: {e}"))))?
},
| Err(e) => {
return Err!(Request(InvalidParam(debug_error!(
return Err!(Request(Forbidden(debug_error!(
"Signature verification failed for {event_id}: {e}"
))));
},
@@ -87,65 +99,78 @@ pub(super) async fn handle_outlier_pdu<'a, Pdu>(
// Fetch all auth events
let mut auth_events: HashMap<OwnedEventId, PduEvent> = HashMap::new();
for aid in pdu_event.auth_events() {
if self.services.pdu_metadata.is_event_rejected(aid).await {
for auth_event_id in pdu_event.auth_events() {
if self
.services
.pdu_metadata
.is_event_rejected(auth_event_id)
.await
{
debug_warn!(
"Rejecting incoming event {} which depends on rejected auth event {aid}",
"Rejecting incoming event {} which depends on rejected auth event \
{auth_event_id}",
event_id,
);
self.services.pdu_metadata.mark_event_rejected(event_id);
return Err!(Request(InvalidParam("Event has rejected auth event: {aid}")));
return Err!(Request(Forbidden("Event has rejected auth event: {auth_event_id}")));
}
if let Ok(auth_event) = self.services.timeline.get_pdu(aid).await {
if let Ok(auth_event) = self.services.timeline.get_pdu(auth_event_id).await {
check_room_id(room_id, &auth_event)?;
trace!("Found auth event {aid} for outlier event {event_id} locally");
auth_events.insert(aid.to_owned(), auth_event);
trace!("Found auth event {auth_event_id} for outlier event {event_id} locally");
auth_events.insert(auth_event_id.to_owned(), auth_event);
} else {
debug_warn!("Could not find auth event {aid} for outlier event {event_id} locally");
debug_warn!(
"Could not find auth event {auth_event_id} for outlier event {event_id} locally"
);
}
}
// Fetch any missing ones & reject invalid ones
let missing_auth_events = if auth_events_known {
pdu_event
.auth_events()
.filter(|id| !auth_events.contains_key(*id))
.collect::<Vec<_>>()
} else {
pdu_event.auth_events().collect::<Vec<_>>()
};
if !missing_auth_events.is_empty() || !auth_events_known {
debug_info!(
"Fetching {} missing auth events for outlier event {event_id}",
missing_auth_events.len()
);
for (pdu, _) in self
.fetch_and_handle_outliers(
if auth_events.len() != pdu_event.auth_events().count() {
info!("Missing some auth events, asking remote for auth chain");
let response: get_event_authorization::v1::Response = self
.services
.sending
.send_federation_request(
origin,
missing_auth_events.iter().copied(),
create_event,
room_id,
get_event_authorization::v1::Request::new(
room_id.to_owned(),
event_id.to_owned(),
),
)
.await
{
auth_events.insert(pdu.event_id().to_owned(), pdu);
.map_err(|e| {
err!(Request(Forbidden(
"Remote server is not divulging incoming event's auth chain: {e}"
)))
})?;
let mut auth_chain_map = HashMap::with_capacity(response.auth_chain.len());
for auth_pdu_json in response.auth_chain {
let (auth_event_room_id, auth_event_id, auth_pdu_json) =
self.parse_incoming_pdu(&auth_pdu_json).await?;
if auth_event_room_id != room_id {
return Err!(Request(Forbidden(
"Auth event {auth_event_id} is in {auth_event_room_id}, not {room_id}."
)));
}
let auth_pdu = PduEvent::from_id_val(&auth_event_id, auth_pdu_json)
.map_err(|e| err!(Request(BadJson("Invalid PDU {auth_event_id}: {e}"))))?;
auth_chain_map.insert(auth_event_id, auth_pdu);
}
for auth_event_id in pdu_event.auth_events() {
if auth_events.contains_key(auth_event_id) {
continue;
}
if let Some(auth_event) = auth_chain_map.get(auth_event_id) {
auth_events.insert(auth_event_id.to_owned(), auth_event.clone());
} else {
return Err!(Request(Forbidden(
"Remote server is not divulging incoming event's auth events (missing: \
{auth_event_id})"
)));
}
}
} else {
debug!("No missing auth events for outlier event {event_id}");
}
// reject if we are still missing some
let still_missing = pdu_event
.auth_events()
.filter(|id| !auth_events.contains_key(*id))
.collect::<Vec<_>>();
if !still_missing.is_empty() {
// Don't reject: this could be a temporary condition
// TODO: use get_missing_events?
return Err!(Request(InvalidParam(
"Could not fetch all auth events for outlier event {event_id}, still missing: \
{still_missing:?}"
)));
}
// 6. Reject "due to auth events" if the event doesn't pass auth based on the
@@ -176,7 +201,7 @@ pub(super) async fn handle_outlier_pdu<'a, Pdu>(
.outlier
.add_pdu_outlier(pdu_event.event_id(), &incoming_pdu);
self.services.pdu_metadata.mark_event_rejected(event_id);
return Err!(Request(InvalidParam(
return Err!(Request(Forbidden(
"Auth event's type and state_key combination exists multiple times: {}, {}",
auth_event.kind,
auth_event.state_key().unwrap_or("")
@@ -185,18 +210,6 @@ pub(super) async fn handle_outlier_pdu<'a, Pdu>(
}
}
// The original create event must be in the auth events
if !matches!(
auth_events_by_key.get(&(StateEventType::RoomCreate, String::new().into())),
Some(_) | None
) {
self.services.pdu_metadata.mark_event_rejected(event_id);
self.services
.outlier
.add_pdu_outlier(pdu_event.event_id(), &incoming_pdu);
return Err!(Request(InvalidParam("Incoming event refers to wrong create event.")));
}
let state_fetch = |ty: &StateEventType, sk: &str| {
let key = (ty.to_owned(), sk.into());
ready(auth_events_by_key.get(&key).map(ToOwned::to_owned))
@@ -46,7 +46,7 @@ pub(super) async fn handle_prev_pdu<'a, Pdu>(
{
// Exponential backoff
const MIN_DURATION: u64 = 5 * 60;
const MAX_DURATION: u64 = 60 * 60 * 24;
const MAX_DURATION: u64 = 60 * 60;
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
debug!(
?tries,
+8 -2
View File
@@ -4,7 +4,6 @@
mod fetch_state;
mod handle_incoming_pdu;
mod handle_outlier_pdu;
mod handle_prev_pdu;
mod parse_incoming_pdu;
mod policy_server;
mod resolve_state;
@@ -15,6 +14,7 @@
use async_trait::async_trait;
use conduwuit::{Err, Event, PduEvent, Result, Server, SyncRwLock, utils::MutexMap};
pub use fetch_and_handle_outliers::{GET_MISSING_EVENTS_MAX_BATCH_SIZE, build_local_dag};
use ruma::{
OwnedEventId, OwnedRoomId, RoomId, events::room::create::RoomCreateEventContent,
room_version_rules::RoomVersionRules,
@@ -22,10 +22,10 @@
use tokio::sync::Notify;
use crate::{Dep, globals, rooms, sending, server_keys};
pub struct Service {
pub mutex_federation: RoomMutexMap,
pub federation_handletime: SyncRwLock<HandleTimeMap>,
pub last_extremity_squash: SyncRwLock<HashMap<OwnedRoomId, Instant>>,
services: Services,
server_shutdown: Notify,
}
@@ -56,6 +56,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
mutex_federation: RoomMutexMap::new(),
federation_handletime: HandleTimeMap::new().into(),
last_extremity_squash: SyncRwLock::new(HashMap::new()),
services: Services {
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
@@ -91,6 +92,11 @@ async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
fn interrupt(&self) { self.server_shutdown.notify_waiters(); }
async fn clear_cache(&self) {
let mut squashes = self.last_extremity_squash.write();
squashes.clear();
}
}
impl Service {
@@ -7,7 +7,7 @@
use itertools::Itertools;
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, RoomId,
RoomVersionId,
RoomVersionId, room_version_rules::RoomVersionRules,
};
use serde_json::value::RawValue as RawJsonValue;
@@ -56,7 +56,10 @@ fn extract_room_id(event_type: &str, pdu: &CanonicalJsonObject) -> Result<OwnedR
/// Parses every entry in an array as an event ID, returning an error if any
/// step fails.
fn expect_event_id_array(value: &CanonicalJsonObject, field: &str) -> Result<Vec<OwnedEventId>> {
pub(super) fn expect_event_id_array(
value: &CanonicalJsonObject,
field: &str,
) -> Result<Vec<OwnedEventId>> {
value
.get(field)
.ok_or_else(|| err!(Request(BadJson("missing field `{field}` on PDU"))))?
@@ -101,6 +104,21 @@ pub fn validate_pdu(&self, pdu: &CanonicalJsonObject) -> Result {
}
#[implement(super::Service)]
pub async fn parse_incoming_pdu_with_known_room(
&self,
pdu: &RawJsonValue,
room_version_rules: &RoomVersionRules,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let (event_id, value) =
gen_event_id_canonical_json(pdu, room_version_rules).map_err(|e| {
err!(Request(InvalidParam("Could not convert event to canonical json: {e}")))
})?;
self.validate_pdu(&value)?;
Ok((event_id, value))
}
#[implement(super::Service)]
#[tracing::instrument(name = "parse", skip_all)]
pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<Parsed> {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get()).map_err(|e| {
err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}")))
@@ -5,7 +5,7 @@
};
use conduwuit::{
Result, debug, err, error, implement,
Result, debug, debug_error, err, error, implement,
matrix::{Event, StateMap},
trace,
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
@@ -37,6 +37,7 @@ pub(super) async fn state_at_incoming_degree_one<Pdu>(
.pdu_shortstatehash(prev_event)
.await
else {
trace!("No shortstatehash for {prev_event}, cannot calculate one-degree state.");
return Ok(None);
};
@@ -99,6 +100,7 @@ pub(super) async fn state_at_incoming_resolved<Pdu>(
.map_ok(move |sstatehash| (sstatehash, prev_event))
})
.try_collect::<HashMap<_, _>>()
.inspect_err(|e| debug_error!("failed to calculate N-degree short state hashes: {e}"))
.await
else {
return Ok(None);
@@ -1,8 +1,9 @@
use std::{borrow::Borrow, sync::Arc, time::Instant};
use conduwuit::{
Err, Result, debug, debug_info, err, implement, info, is_equal_to,
Err, Result, debug, debug_error, debug_info, err, implement, info, is_equal_to,
matrix::{Event, EventTypeExt, PduEvent, StateKey, state_res},
result::DebugInspect,
trace,
utils::{
IterStream,
@@ -23,28 +24,17 @@
};
#[implement(super::Service)]
pub(super) async fn upgrade_outlier_to_timeline_pdu<Pdu>(
#[tracing::instrument(name="upgrade_outlier", skip_all, fields(event_id=%incoming_pdu.event_id()))]
pub(super) async fn upgrade_outlier_to_timeline_pdu(
&self,
incoming_pdu: PduEvent,
mut val: CanonicalJsonObject,
create_event: &Pdu,
create_event: &PduEvent,
origin: &ServerName,
room_id: &RoomId,
) -> Result<Option<RawPduId>>
where
Pdu: Event + Send + Sync,
{
// Skip the PDU if we already have it as a timeline event
if let Ok(pduid) = self
.services
.timeline
.get_pdu_id(incoming_pdu.event_id())
.await
{
return Ok(Some(pduid));
}
let (rejected, soft_failed) = join!(
) -> Result<Option<RawPduId>> {
let (pduid, rejected, soft_failed) = join!(
self.services.timeline.get_pdu_id(incoming_pdu.event_id()),
self.services
.pdu_metadata
.is_event_rejected(incoming_pdu.event_id()),
@@ -52,17 +42,27 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu<Pdu>(
.pdu_metadata
.is_event_soft_failed(incoming_pdu.event_id())
);
if rejected {
return Err!(Request(InvalidParam("Event has been rejected")));
if let Ok(id) = pduid {
trace!(event_id=%incoming_pdu.event_id(), "Skipping upgrade of already upgraded PDU");
return Ok(Some(id));
} else if rejected {
return Err!(Request(Forbidden("Event has been rejected")));
} else if soft_failed {
return Err!(Request(InvalidParam("Event has been soft-failed")));
return Err!(Request(Forbidden("Event has been soft-failed")));
}
assert_eq!(
*create_event.kind(),
StateEventType::RoomCreate.into(),
"tried to upgrade a PDU with a create_event that is not a room create event"
);
debug!(
event_id = %incoming_pdu.event_id,
"Upgrading PDU from outlier to timeline"
);
let timer = Instant::now();
let min_depth = self.services.metadata.get_mindepth(room_id).await;
let room_version_rules = get_room_version_rules(create_event)?;
// 10. Fetch missing state and auth chain events by calling /state_ids at
@@ -73,21 +73,32 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu<Pdu>(
event_id = %incoming_pdu.event_id,
"Resolving state at event"
);
let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
let state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
self.state_at_incoming_degree_one(&incoming_pdu).await?
} else {
self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_rules)
.await?
};
let state_at_incoming_event = match state_at_incoming_event {
| Some(s) => s,
| None => {
trace!("Could not calculate incoming state, asking remote {origin} for it");
self.fetch_state(origin, create_event, room_id, incoming_pdu.event_id())
.await
.debug_inspect_err(|e| debug_error!("Could not fetch state from {origin}: {e}"))?
},
};
if state_at_incoming_event.is_none() {
state_at_incoming_event = self
.fetch_state(origin, create_event, room_id, incoming_pdu.event_id())
.await?;
if state_at_incoming_event.is_empty()
&& *incoming_pdu.event_type() != StateEventType::RoomCreate.into()
{
// This can happen if the remote sends an event but cannot be reached to fetch
// the state at it, and all other servers in the room (which might just be the
// unreachable server) are unable to provide required info.
// returning an error here allows the upgrade to be attempted at another time.
return Err!(Request(Forbidden("Could not resolve incoming state at event")));
}
let state_at_incoming_event =
state_at_incoming_event.expect("we always set this to some above");
trace!(state_events = state_at_incoming_event.len(), "Calculated incoming state");
debug!(
event_id = %incoming_pdu.event_id,
@@ -382,6 +393,12 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu<Pdu>(
// Event has passed all auth/stateres checks
drop(state_lock);
if incoming_pdu.depth > min_depth && incoming_pdu.state_key().is_some() {
self.services
.metadata
.set_mindepth(room_id, incoming_pdu.depth.into());
trace!("Increased room's min depth from {} to {}", min_depth, incoming_pdu.depth);
}
Ok(pdu_id)
}
+43 -4
View File
@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use conduwuit::{
Err, Pdu, Result, Server, debug, debug_info, debug_warn, err, error, info, is_true,
Err, Event, Pdu, Result, Server, debug, debug_info, debug_warn, err, error, info, is_true,
matrix::{
StateKey,
event::{gen_event_id, gen_event_id_canonical_json},
@@ -34,7 +34,7 @@
use crate::{
Dep, antispam, globals,
rooms::{
metadata, outlier, pdu_metadata, short,
event_handler, metadata, outlier, pdu_metadata, short,
state::{self, RoomMutexGuard},
state_accessor, state_cache,
state_compressor::{self, CompressedState, HashSetCompressStateEvent},
@@ -51,6 +51,7 @@ struct Services {
server: Arc<Server>,
db: Arc<Database>,
antispam: Dep<antispam::Service>,
event_handler: Dep<event_handler::Service>,
globals: Dep<globals::Service>,
metadata: Dep<metadata::Service>,
outlier: Dep<outlier::Service>,
@@ -73,6 +74,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
server: args.server.clone(),
db: args.db.clone(),
antispam: args.depend::<antispam::Service>("antispam"),
event_handler: args.depend::<event_handler::Service>("rooms::event_handler"),
globals: args.depend::<globals::Service>("globals"),
metadata: args.depend::<metadata::Service>("rooms::metadata"),
outlier: args.depend::<outlier::Service>("rooms::outlier"),
@@ -381,8 +383,6 @@ pub async fn join_remote_room(
// It has enough fields to be called a proper event now
let mut join_event = join_event_stub;
info!("Asking {remote_server} for send_join in room {room_id}");
let send_join_request = federation::membership::create_join_event::v2::Request::new(
room_id.to_owned(),
event_id.clone(),
@@ -392,6 +392,18 @@ pub async fn join_remote_room(
.await,
);
// NOTE: send_join can take a long time to respond, but from the point of view
// of other servers, we may already have finished joining. This means they
// sometimes end up sending PDUs to us that we aren't yet ready to accept, and
// consequently drop. Holding the mutex over the room while processing mitigates
// this.
let _room_lock = self
.services
.event_handler
.mutex_federation
.lock(room_id.as_str())
.await;
info!("Asking {remote_server} for send_join in room {room_id}");
let send_join_response = match self
.services
.sending
@@ -577,7 +589,13 @@ pub async fn join_remote_room(
if !auth_check {
return Err!(Request(Forbidden("Auth check failed")));
}
let resident_before = self
.services
.state_cache
.server_in_room(self.services.globals.server_name(), room_id)
.await;
let cork = self.services.db.cork_and_flush();
info!("Compressing state from send_join");
let compressed: CompressedState = self
.services
@@ -626,6 +644,10 @@ pub async fn join_remote_room(
room_id,
)
.await?;
self.services
.metadata
.maybe_set_mindepth(room_id, parsed_join_pdu.depth.into())
.await;
info!("Setting final room state for new room");
// We set the room state after inserting the pdu, so that we never have a moment
@@ -633,6 +655,23 @@ pub async fn join_remote_room(
self.services
.state
.set_room_state(room_id, statehash_after_join, &state_lock);
if !resident_before {
// NOTE: We replace local extremities for this room if we were not a resident
// before. We might be doing a remote join to satisfy restricted join rules,
// so we don't want to do this if we're already a resident. Otherwise, we
// want to replace our forward extremities whole-sale in case we were
// desynced.
info!("Replacing local forward extremities");
self.services
.state
.set_forward_extremities(
room_id,
std::iter::once(parsed_join_pdu.event_id()),
&state_lock,
)
.await;
}
drop(cork);
Ok(())
}
+28 -2
View File
@@ -1,9 +1,9 @@
use std::sync::Arc;
use conduwuit::{Result, implement, utils::stream::TryIgnore};
use database::Map;
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{OwnedRoomId, RoomId};
use ruma::{OwnedRoomId, RoomId, UInt, uint};
use crate::{Dep, rooms};
@@ -17,6 +17,7 @@ struct Data {
bannedroomids: Arc<Map>,
roomid_shortroomid: Arc<Map>,
pduid_pdu: Arc<Map>,
roomid_mindepth: Arc<Map>,
}
struct Services {
@@ -31,6 +32,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
bannedroomids: args.db["bannedroomids"].clone(),
roomid_shortroomid: args.db["roomid_shortroomid"].clone(),
pduid_pdu: args.db["pduid_pdu"].clone(),
roomid_mindepth: args.db["roomid_mindepth"].clone(),
},
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
@@ -98,3 +100,27 @@ pub async fn is_disabled(&self, room_id: &RoomId) -> bool {
pub async fn is_banned(&self, room_id: &RoomId) -> bool {
self.db.bannedroomids.get(room_id).await.is_ok()
}
#[implement(Service)]
pub async fn get_mindepth(&self, room_id: &RoomId) -> UInt {
self.db
.roomid_mindepth
.get(room_id)
.await
.deserialized::<UInt>()
.unwrap_or_else(|_| uint!(0))
}
#[implement(Service)]
pub fn set_mindepth(&self, room_id: &RoomId, min_depth: u64) {
self.db
.roomid_mindepth
.put_raw(room_id.as_bytes(), min_depth.to_be_bytes());
}
#[implement(Service)]
pub async fn maybe_set_mindepth(&self, room_id: &RoomId, min_depth: u64) {
if min_depth > self.get_mindepth(room_id).await.into() {
self.set_mindepth(room_id, min_depth);
}
}
+10
View File
@@ -371,6 +371,16 @@ pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<ShortSta
.deserialized()
}
pub fn all_forward_extremities(
&self,
) -> impl Stream<Item = (OwnedRoomId, OwnedEventId)> + Send {
self.db
.roomid_pduleaves
.keys()
.map_ok(|(room_id, event_id): (OwnedRoomId, OwnedEventId)| (room_id, event_id))
.ignore_err()
}
pub fn get_forward_extremities<'a>(
&'a self,
room_id: &'a RoomId,
+4 -1
View File
@@ -221,7 +221,10 @@ pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) ->
}
#[implement(super::Service)]
async fn candidate_backfill_servers(&self, room_id: &RoomId) -> HashSet<OwnedServerName> {
pub(crate) async fn candidate_backfill_servers(
&self,
room_id: &RoomId,
) -> HashSet<OwnedServerName> {
let mut candidate_backfill_servers = HashSet::new();
let power_levels = self
+4
View File
@@ -173,6 +173,10 @@ pub async fn get_non_outlier_pdu_json(
self.db.get_non_outlier_pdu_json(event_id).await
}
pub async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> bool {
self.db.non_outlier_pdu_exists(event_id).await.is_ok()
}
/// Returns the pdu's id.
#[inline]
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> {
+3 -2
View File
@@ -34,8 +34,9 @@ pub(super) async fn batch_notary_request<'a, S, K>(
batch
});
debug_assert!(!server_keys.is_empty(), "empty batch request to notary");
if server_keys.is_empty() {
return Ok(vec![]);
}
let mut results = Vec::new();
while let Some(batch) = server_keys
+4 -4
View File
@@ -11,8 +11,8 @@
account_data, admin, announcements, antispam, appservice, client, config, emergency,
federation, firstrun, globals, key_backups, mailer,
manager::Manager,
media, moderation, password_reset, presence, pusher, registration_tokens, resolver, rooms,
sending, server_keys,
media, moderation, oauth, presence, pusher, registration_tokens, resolver, rooms, sending,
server_keys,
service::{self, Args, Map, Service},
sync, threepid, transactions, uiaa, users,
};
@@ -27,7 +27,7 @@ pub struct Services {
pub globals: Arc<globals::Service>,
pub key_backups: Arc<key_backups::Service>,
pub media: Arc<media::Service>,
pub password_reset: Arc<password_reset::Service>,
pub oauth: Arc<oauth::Service>,
pub mailer: Arc<mailer::Service>,
pub presence: Arc<presence::Service>,
pub pusher: Arc<pusher::Service>,
@@ -84,7 +84,7 @@ macro_rules! build {
globals: build!(globals::Service),
key_backups: build!(key_backups::Service),
media: build!(media::Service),
password_reset: build!(password_reset::Service),
oauth: build!(oauth::Service),
mailer: build!(mailer::Service),
presence: build!(presence::Service),
pusher: build!(pusher::Service),
+29 -7
View File
@@ -9,8 +9,9 @@
ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId,
api::error::{ErrorKind, LimitExceededErrorData},
};
use tokio::sync::MutexGuard;
mod session;
pub mod session;
use crate::{
Args, Dep, config,
@@ -26,6 +27,7 @@ pub struct Service {
ratelimiter: DefaultKeyedRateLimiter<Address>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EmailRequirement {
/// Users may change their email, but cannot remove it entirely.
Required,
@@ -219,13 +221,12 @@ pub async fn try_validate_session(
Ok(())
}
/// Consume a validated validation session, removing it from the database
/// and returning the newly validated email address.
pub async fn consume_valid_session(
/// Get a validated validation session.
pub async fn get_valid_session(
&self,
session_id: &SessionId,
client_secret: &ClientSecret,
) -> Result<Address, Cow<'static, str>> {
) -> Result<ValidSession<'_>, Cow<'static, str>> {
let mut sessions = self.sessions.lock().await;
let Some(session) = sessions.get_session(session_id) else {
@@ -235,9 +236,13 @@ pub async fn consume_valid_session(
if session.client_secret == client_secret
&& matches!(session.validation_state, ValidationState::Validated)
{
let session = sessions.remove_session(session_id);
let email = session.email.clone();
Ok(session.email)
Ok(ValidSession {
email,
session_id: session_id.to_owned(),
sessions,
})
} else {
Err("This email address has not been validated. Did you use the link that was sent \
to you?"
@@ -313,3 +318,20 @@ pub async fn get_localpart_for_email(&self, email: &Address) -> Option<String> {
.ok()
}
}
pub struct ValidSession<'lock> {
pub email: Address,
session_id: OwnedSessionId,
sessions: MutexGuard<'lock, ValidationSessions>,
}
impl ValidSession<'_> {
/// Consume this session, removing it from the database and releasing the
/// lock it holds.
#[must_use]
pub fn consume(mut self) -> Address {
self.sessions.remove_session(&self.session_id);
self.email
}
}
+5 -5
View File
@@ -8,14 +8,14 @@
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId};
#[derive(Default)]
pub(super) struct ValidationSessions {
pub struct ValidationSessions {
sessions: HashMap<OwnedSessionId, ValidationSession>,
client_secrets: HashMap<OwnedClientSecret, OwnedSessionId>,
}
/// A pending or completed email validation session.
#[derive(Debug)]
pub(crate) struct ValidationSession {
pub struct ValidationSession {
/// The session's ID
pub session_id: OwnedSessionId,
/// The client's supplied client secret
@@ -28,7 +28,7 @@ pub(crate) struct ValidationSession {
/// The state of an email validation session.
#[derive(Debug)]
pub(crate) enum ValidationState {
pub enum ValidationState {
/// The session is waiting for this validation token to be provided
Pending(ValidationToken),
/// The session has been validated
@@ -36,7 +36,7 @@ pub(crate) enum ValidationState {
}
#[derive(Clone, Debug)]
pub(crate) struct ValidationToken {
pub struct ValidationToken {
pub token: String,
pub issued_at: SystemTime,
}
@@ -69,7 +69,7 @@ impl ValidationSessions {
const RANDOM_SID_LENGTH: usize = 16;
#[must_use]
pub(super) fn generate_session_id() -> OwnedSessionId {
pub fn generate_session_id() -> OwnedSessionId {
SessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap()
}
+303 -151
View File
@@ -7,7 +7,7 @@
use conduwuit::{Err, Error, Result, error, utils};
use lettre::Address;
use ruma::{
UserId,
DeviceId, UserId,
api::{
client::uiaa::{
AuthData, AuthFlow, AuthType, EmailIdentity, EmailUserIdentifier,
@@ -16,11 +16,19 @@
},
error::{ErrorKind, StandardErrorBody},
},
assign,
};
use serde_json::{
json,
value::{RawValue, to_raw_value},
};
use serde_json::value::RawValue;
use tokio::sync::Mutex;
use crate::{Dep, config, globals, registration_tokens, threepid, users};
use crate::{
Dep, config, globals,
oauth::{self, OAuthTicket},
registration_tokens, threepid, users,
};
pub struct Service {
services: Services,
@@ -33,6 +41,7 @@ struct Services {
config: Dep<config::Service>,
registration_tokens: Dep<registration_tokens::Service>,
threepid: Dep<threepid::Service>,
oauth: Dep<oauth::Service>,
}
impl crate::Service for Service {
@@ -45,6 +54,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
registration_tokens: args
.depend::<registration_tokens::Service>("registration_tokens"),
threepid: args.depend::<threepid::Service>("threepid"),
oauth: args.depend::<oauth::Service>("oauth"),
},
uiaa_sessions: Mutex::new(HashMap::new()),
}))
@@ -54,8 +64,56 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
struct UiaaSession {
session_metadata: UiaaSessionMetadata,
info: UiaaInfo,
identity: Identity,
}
#[derive(Clone)]
enum UiaaSessionMetadata {
Legacy {
identity: Identity,
},
OAuth {
localpart: String,
ticket: OAuthTicket,
},
}
impl UiaaSessionMetadata {
fn into_identity(self) -> Identity {
match self {
| Self::Legacy { identity } => identity,
| Self::OAuth { localpart, .. } =>
assign!(Identity::default(), { localpart: Some(localpart) }),
}
}
}
/// Information about the user which is initiating this UIAA session.
pub struct UiaaInitiator<'a> {
user_id: &'a UserId,
device_id: Option<&'a DeviceId>,
oauth_ticket: Option<OAuthTicket>,
}
impl<'a> UiaaInitiator<'a> {
#[must_use]
pub fn new(user_id: &'a UserId, device_id: Option<&'a DeviceId>) -> Self {
Self { user_id, device_id, oauth_ticket: None }
}
#[must_use]
pub fn with_oauth_ticket(
user_id: &'a UserId,
device_id: Option<&'a DeviceId>,
oauth_ticket: OAuthTicket,
) -> Self {
Self {
user_id,
device_id,
oauth_ticket: Some(oauth_ticket),
}
}
}
/// Information about the authenticated user's identity.
@@ -106,7 +164,7 @@ impl Identity {
/// Create an Identity with the localpart of the provided user ID
/// and all other fields set to None.
#[must_use]
pub fn from_user_id(user_id: &UserId) -> Self {
fn from_user_id(user_id: &UserId) -> Self {
Self {
localpart: Some(user_id.localpart().to_owned()),
..Default::default()
@@ -124,11 +182,11 @@ pub async fn authenticate(
auth: &Option<AuthData>,
flows: Vec<AuthFlow>,
params: Box<RawValue>,
identity: Option<Identity>,
initiator: Option<UiaaInitiator<'_>>,
) -> Result<Identity> {
match auth.as_ref() {
| None => {
let info = self.create_session(flows, params, identity).await;
let info = self.create_session(flows, params, initiator).await?;
Err(Error::Uiaa(info))
},
@@ -140,8 +198,8 @@ pub async fn authenticate(
// session if they want to start the UIAA exchange with existing
// authentication data. If that happens, we create a new session
// here.
self.create_session(flows, params, identity)
.await
self.create_session(flows, params, initiator)
.await?
.session
.unwrap()
.into()
@@ -161,13 +219,15 @@ pub async fn authenticate(
pub async fn authenticate_password(
&self,
auth: &Option<AuthData>,
identity: Option<Identity>,
user_id: &UserId,
device_id: Option<&DeviceId>,
oauth_ticket: Option<OAuthTicket>,
) -> Result<Identity> {
self.authenticate(
auth,
vec![AuthFlow::new(vec![AuthType::Password])],
Box::default(),
identity,
Some(UiaaInitiator { user_id, device_id, oauth_ticket }),
)
.await
}
@@ -183,20 +243,88 @@ async fn create_session(
&self,
flows: Vec<AuthFlow>,
params: Box<RawValue>,
identity: Option<Identity>,
) -> UiaaInfo {
initiator: Option<UiaaInitiator<'_>>,
) -> Result<UiaaInfo> {
let mut uiaa_sessions = self.uiaa_sessions.lock().await;
let session_id = utils::random_string(Self::SESSION_ID_LENGTH);
let mut info = assign::assign!(UiaaInfo::new(flows), {params: Some(params)});
info.session = Some(session_id.clone());
uiaa_sessions.insert(session_id, UiaaSession {
info: info.clone(),
identity: identity.unwrap_or_default(),
});
let mut info = assign!(UiaaInfo::new(flows), { params: Some(params), session: Some(session_id.clone()) });
info
let session_metadata = if let Some(initiator) = initiator {
let is_oauth = if let Some(device_id) = initiator.device_id {
self.services
.oauth
.get_session_info_for_device(initiator.user_id, device_id)
.await
.is_some()
} else {
// Appservices never have oauth sessions
false
};
if is_oauth {
if let Some(oauth_ticket) = initiator.oauth_ticket {
let ticket_url = self
.services
.config
.get_client_domain()
.join(&format!(
"{}{}",
conduwuit_core::ROUTE_PREFIX,
oauth_ticket.ticket_issue_path()
))
.unwrap();
info.flows = vec![AuthFlow::new(vec![AuthType::OAuth])];
info.params = Some(
to_raw_value(&json!({
AuthType::OAuth.as_str(): {
"url": ticket_url,
},
// TODO(compat): This is necessary for older versions of matrix-rust-sdk
"org.matrix.cross_signing_reset": {
"url": ticket_url,
}
}))
.unwrap(),
);
UiaaSessionMetadata::OAuth {
localpart: initiator.user_id.localpart().to_owned(),
ticket: oauth_ticket,
}
} else {
return Err!(Request(Forbidden(
"Clients authorized with OAuth cannot use this route."
)));
}
} else {
UiaaSessionMetadata::Legacy {
identity: Identity::from_user_id(initiator.user_id),
}
}
} else {
UiaaSessionMetadata::Legacy { identity: Identity::default() }
};
// Legacy sessions aren't available if OAuth is required
if matches!(&session_metadata, UiaaSessionMetadata::Legacy { .. })
&& !self
.services
.config
.oauth
.compatibility_mode
.uiaa_available()
{
return Err!(Request(Unrecognized(
"User-interactive authentication is unavailable on this server"
)));
}
uiaa_sessions.insert(session_id, UiaaSession { session_metadata, info: info.clone() });
Ok(info)
}
/// Proceed with UIAA authentication given a client's authorization data.
@@ -225,7 +353,7 @@ async fn continue_session(
}
let completed = {
let UiaaSession { info, identity } = session.get_mut();
let UiaaSession { session_metadata, info } = session.get_mut();
let auth_type = auth.auth_type().expect("auth type should be set");
@@ -258,12 +386,12 @@ async fn continue_session(
// If the provided stage hasn't already been completed, check it for completion
if !completed_stages.contains(auth_type.as_str()) {
match self.check_stage(auth, identity.clone()).await {
| Ok((completed_stage, updated_identity)) => {
match self.check_stage(auth, session_metadata.clone()).await {
| Ok((completed_stage, updated_metadata)) => {
info.auth_error = None;
completed_stages.insert(completed_stage.to_string());
info.completed.push(completed_stage);
*identity = updated_identity;
*session_metadata = updated_metadata;
},
| Err(error) => {
info.auth_error = Some(error);
@@ -279,9 +407,9 @@ async fn continue_session(
if completed {
// This session is complete, remove it and return success
let (_, UiaaSession { identity, .. }) = session.remove_entry();
let (_, UiaaSession { session_metadata, .. }) = session.remove_entry();
Ok(Ok(identity))
Ok(Ok(session_metadata.into_identity()))
} else {
// The client needs to try again, return the updated session
Ok(Err(session.get().info.clone()))
@@ -295,152 +423,176 @@ async fn continue_session(
async fn check_stage(
&self,
auth: &AuthData,
mut identity: Identity,
) -> Result<(AuthType, Identity), StandardErrorBody> {
// Note: This function takes ownership of `identity` because mutations to the
// identity must not be applied unless checking the stage succeeds. The
// updated identity is returned as part of the Ok value, and
// `continue_session` handles saving it to `uiaa_sessions`.
mut session_metadata: UiaaSessionMetadata,
) -> Result<(AuthType, UiaaSessionMetadata), StandardErrorBody> {
// Note: This function takes ownership of `session_metadata` because mutations
// to the identity (if it's a legacy session) must not be applied unless
// checking the stage succeeds. The updated identity is returned as part of
// the Ok value, and `continue_session` handles saving it to `uiaa_sessions`.
//
// This also means it's fine to mutate `identity` at any point in this function,
// because those mutations won't be saved unless the function returns Ok.
match auth {
| AuthData::Dummy(_) => Ok(AuthType::Dummy),
| AuthData::EmailIdentity(EmailIdentity {
thirdparty_id_creds: ThirdpartyIdCredentials { client_secret, sid, .. },
..
}) => {
match self
.services
.threepid
.consume_valid_session(sid, client_secret)
.await
{
| Ok(email) => {
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_localpart(localpart)?;
}
let completed_auth_type = match &mut session_metadata {
| UiaaSessionMetadata::OAuth { localpart, ticket } => {
// m.oauth is the only valid stage for oauth sessions
assert!(
matches!(auth, AuthData::OAuth(_)),
"got non-oauth auth data for oauth session"
);
identity.try_set_email(email)?;
Ok(AuthType::EmailIdentity)
},
| Err(message) => Err(StandardErrorBody::new(
ErrorKind::ThreepidAuthFailed,
message.into_owned(),
)),
if self.services.oauth.try_consume_ticket(localpart, *ticket) {
Ok(AuthType::OAuth)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"No OAuth ticket available".to_owned(),
))
}
},
#[allow(clippy::useless_let_if_seq)]
| AuthData::Password(Password { identifier, password, .. }) => {
let user_id_or_localpart = match identifier {
| UserIdentifier::Matrix(MatrixUserIdentifier { user, .. }) =>
user.to_owned(),
| UserIdentifier::Email(EmailUserIdentifier { address, .. }) => {
let Ok(email) = Address::try_from(address.to_owned()) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"Email is malformed".to_owned(),
));
};
| UiaaSessionMetadata::Legacy { identity } => match auth {
| AuthData::Dummy(_) => Ok(AuthType::Dummy),
| AuthData::EmailIdentity(EmailIdentity {
thirdparty_id_creds: ThirdpartyIdCredentials { client_secret, sid, .. },
..
}) => {
match self
.services
.threepid
.get_valid_session(sid, client_secret)
.await
{
| Ok(session) => {
let email = session.consume();
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_localpart(localpart)?;
}
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_email(email)?;
localpart
} else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid identifier or password".to_owned(),
));
}
},
| _ =>
return Err(StandardErrorBody::new(
ErrorKind::Unrecognized,
"Identifier type not recognized".to_owned(),
Ok(AuthType::EmailIdentity)
},
| Err(message) => Err(StandardErrorBody::new(
ErrorKind::ThreepidAuthFailed,
message.into_owned(),
)),
};
}
},
#[allow(clippy::useless_let_if_seq)]
| AuthData::Password(Password { identifier, password, .. }) => {
let user_id_or_localpart = match identifier {
| UserIdentifier::Matrix(MatrixUserIdentifier { user, .. }) =>
user.to_owned(),
| UserIdentifier::Email(EmailUserIdentifier { address, .. }) => {
let Ok(email) = Address::try_from(address.to_owned()) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"Email is malformed".to_owned(),
));
};
let Ok(user_id) = UserId::parse_with_server_name(
user_id_or_localpart,
self.services.globals.server_name(),
) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"User ID is malformed".to_owned(),
));
};
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_email(email)?;
if self
.services
.users
.check_password(&user_id, password)
.await
.is_ok()
{
identity.try_set_localpart(user_id.localpart().to_owned())?;
localpart
} else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid identifier or password".to_owned(),
));
}
},
| _ =>
return Err(StandardErrorBody::new(
ErrorKind::Unrecognized,
"Identifier type not recognized".to_owned(),
)),
};
Ok(AuthType::Password)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid identifier or password".to_owned(),
))
}
},
| AuthData::ReCaptcha(ReCaptcha { response, .. }) => {
let Some(ref private_site_key) = self.services.config.recaptcha_private_site_key
else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"ReCaptcha is not configured".to_owned(),
));
};
let Ok(user_id) = UserId::parse_with_server_name(
user_id_or_localpart,
self.services.globals.server_name(),
) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"User ID is malformed".to_owned(),
));
};
match recaptcha_verify::verify_v3(private_site_key, response, None).await {
| Ok(()) => Ok(AuthType::ReCaptcha),
| Err(e) => {
error!("ReCaptcha verification failed: {e:?}");
if self
.services
.users
.check_password(&user_id, password)
.await
.is_ok()
{
identity.try_set_localpart(user_id.localpart().to_owned())?;
Ok(AuthType::Password)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"ReCaptcha verification failed".to_owned(),
"Invalid identifier or password".to_owned(),
))
},
}
},
| AuthData::RegistrationToken(RegistrationToken { token, .. }) => {
let token = token.trim().to_owned();
}
},
| AuthData::ReCaptcha(ReCaptcha { response, .. }) => {
let Some(ref private_site_key) =
self.services.config.recaptcha_private_site_key
else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"ReCaptcha is not configured".to_owned(),
));
};
if let Some(valid_token) = self
.services
.registration_tokens
.validate_token(token)
.await
{
self.services
match recaptcha_verify::verify_v3(private_site_key, response, None).await {
| Ok(()) => Ok(AuthType::ReCaptcha),
| Err(e) => {
error!("ReCaptcha verification failed: {e:?}");
Err(StandardErrorBody::new(
ErrorKind::CaptchaInvalid,
"ReCaptcha verification failed".to_owned(),
))
},
}
},
| AuthData::RegistrationToken(RegistrationToken { token, .. }) => {
let token = token.trim().to_owned();
if let Some(valid_token) = self
.services
.registration_tokens
.mark_token_as_used(valid_token);
.validate_token(token)
.await
{
self.services
.registration_tokens
.mark_token_as_used(valid_token);
Ok(AuthType::RegistrationToken)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid registration token".to_owned(),
))
}
Ok(AuthType::RegistrationToken)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid registration token".to_owned(),
))
}
},
| AuthData::Terms(_) => Ok(AuthType::Terms),
| unknown => {
// We already checked that the stage type is one that exists in the flow,
// so we can only get here if we ourselves served a flow with a stage that we
// don't understand.
panic!("tried to check an unsupported stage type: {unknown:?}");
},
},
| AuthData::Terms(_) => Ok(AuthType::Terms),
| _ => Err(StandardErrorBody::new(
ErrorKind::Unrecognized,
"Unsupported stage type".into(),
)),
}
.map(|auth_type| (auth_type, identity))
}?;
Ok((completed_auth_type, session_metadata))
}
}
+1 -1
View File
@@ -54,6 +54,7 @@ pub async fn set_dehydrated_device(&self, user_id: &UserId, request: Request) ->
user_id,
&request.device_id,
"",
None,
request.initial_device_display_name.clone(),
None,
)
@@ -138,7 +139,6 @@ pub async fn get_dehydrated_device_id(&self, user_id: &UserId) -> Result<OwnedDe
level = "debug",
skip_all,
fields(%user_id),
ret,
)]
pub async fn get_dehydrated_device(&self, user_id: &UserId) -> Result<DehydratedDevice> {
self.db
+334 -18
View File
@@ -1,13 +1,21 @@
pub(super) mod dehydrated_device;
use std::{collections::BTreeMap, mem, net::IpAddr, sync::Arc};
use std::{
collections::BTreeMap,
mem,
net::IpAddr,
sync::Arc,
time::{Duration, SystemTime},
};
use conduwuit::{
Err, Error, Result, Server, debug_error, debug_warn, err, trace,
Err, Error, Result, debug_error, debug_warn, err, info, trace,
utils::{self, ReadyExt, stream::TryIgnore, string::Unquoted},
warn,
};
use database::{Deserialized, Ignore, Interfix, Json, Map};
use futures::{Stream, StreamExt, TryFutureExt};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use lettre::Address;
use ruma::{
DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName,
OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedOneTimeKeyId, OwnedUserId, RoomId, UInt, UserId,
@@ -18,15 +26,24 @@
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{
AnyToDeviceEvent, GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent,
push_rules::PushRulesEvent, room::message::RoomMessageEventContent,
},
push::Ruleset,
serde::Raw,
uint,
};
use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigEvent};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::error;
use crate::{Dep, account_data, admin, appservice, globals, rooms};
use crate::{
Dep, account_data, admin,
appservice::{self, RegistrationInfo},
config, firstrun, globals, oauth,
rooms::{self, alias, membership},
threepid,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserSuspension {
@@ -41,6 +58,7 @@ pub struct UserSuspension {
/// A password hash. This is only for use when setting a user's password,
/// if the hash needs to be kept around for a while without keeping the password
/// in memory.
#[derive(Serialize, Deserialize)]
pub struct HashedPassword(String);
impl HashedPassword {
@@ -51,19 +69,30 @@ pub fn new(password: &str) -> Result<Self> {
}
}
/// The status of an access token.
pub enum AccessTokenStatus {
Valid,
Expired,
}
pub struct Service {
services: Services,
db: Data,
}
struct Services {
server: Arc<Server>,
account_data: Dep<account_data::Service>,
admin: Dep<admin::Service>,
alias: Dep<alias::Service>,
appservice: Dep<appservice::Service>,
config: Dep<config::Service>,
firstrun: Dep<firstrun::Service>,
globals: Dep<globals::Service>,
membership: Dep<membership::Service>,
oauth: Dep<oauth::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
threepid: Dep<threepid::Service>,
}
struct Data {
@@ -75,6 +104,7 @@ struct Data {
logintoken_expiresatuserid: Arc<Map>,
todeviceid_events: Arc<Map>,
token_userdeviceid: Arc<Map>,
userdeviceid_tokenexpires: Arc<Map>,
userdeviceid_metadata: Arc<Map>,
userdeviceid_token: Arc<Map>,
userfilterid_filter: Arc<Map>,
@@ -97,14 +127,19 @@ impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
server: args.server.clone(),
account_data: args.depend::<account_data::Service>("account_data"),
admin: args.depend::<admin::Service>("admin"),
alias: args.depend::<alias::Service>("alias"),
appservice: args.depend::<appservice::Service>("appservice"),
config: args.depend::<config::Service>("config"),
firstrun: args.depend::<firstrun::Service>("firstrun"),
globals: args.depend::<globals::Service>("globals"),
membership: args.depend::<membership::Service>("membership"),
oauth: args.depend::<oauth::Service>("oauth"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
threepid: args.depend::<threepid::Service>("threepid"),
},
db: Data {
keychangeid_userid: args.db["keychangeid_userid"].clone(),
@@ -131,6 +166,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(),
userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(),
useridprofilekey_value: args.db["useridprofilekey_value"].clone(),
userdeviceid_tokenexpires: args.db["userdeviceid_tokenexpires"].clone(),
},
}))
}
@@ -206,12 +242,239 @@ pub async fn create(&self, user_id: &UserId, password: Option<HashedPassword>) -
Ok(())
}
// /// Create a new account for a local human or bot user.
// pub async fn create_local_account(
// &self,
// username: String,
// password:
// )
/// Create a new account for a local human or bot user.
pub async fn create_local_account(
&self,
user_id: &UserId,
password: HashedPassword,
email: Option<Address>,
) {
self.create(user_id, Some(password))
.await
.expect("should be able to save a new local user. what happened?");
// Set an initial display name
{
let mut displayname = user_id.localpart().to_owned();
let suffix = &self.services.config.new_user_displayname_suffix;
if !suffix.is_empty() {
displayname.push(' ');
displayname.push_str(suffix);
}
self.set_displayname(user_id, Some(displayname));
};
// Set default push rules
self.services
.account_data
.update(
None,
user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(PushRulesEvent::new(
Ruleset::server_default(user_id).into(),
))
.expect("should be able to serialize push rules"),
)
.await
.expect("should be able to update account data");
// If the user registered with an email, associate it with their account.
if let Some(email) = email {
// This may fail if the email is already in use, but we should have already
// checked that when we sent the validation email, so ignoring the error is
// acceptable here in the rare case that an email is sniped by another user
// between the validation email being sent and the account being created.
let _ = self
.services
.threepid
.associate_localpart_email(user_id.localpart(), &email)
.await;
}
// Attempt to empower the first user and disable first-run mode.
let was_first_user = self.services.firstrun.empower_first_user(user_id).await;
// If the registering user was not the first and we're suspending users on
// register, suspend them.
if !was_first_user && self.services.config.suspend_on_register {
// Note that we can still do auto joins for suspended users
self.suspend_account(user_id, &self.services.globals.server_user)
.await;
// And send an @room notice to the admin room, to prompt admins to review the
// new user and ideally unsuspend them if deemed appropriate.
if self.services.config.admin_room_notices {
self.services
.admin
.send_loud_message(RoomMessageEventContent::text_plain(format!(
"User {user_id} has been suspended as they are not the first user on \
this server. Please review and unsuspend them if appropriate."
)))
.await
.ok();
}
}
// Autojoin the user to the configured autojoin rooms
for room in &self.services.config.auto_join_rooms {
let Ok(room_id) = self.services.alias.resolve(room).await else {
error!(
"Failed to resolve room alias to room ID when attempting to auto join \
{room}, skipping"
);
continue;
};
if !self
.services
.state_cache
.server_in_room(self.services.globals.server_name(), &room_id)
.await
{
warn!(
"Skipping room {room} to automatically join as we have never joined before."
);
continue;
}
if let Some(room_server_name) = room.server_name() {
match self
.services
.membership
.join_room(
user_id,
&room_id,
Some("Automatically joining this room upon registration".to_owned()),
&[
self.services.globals.server_name().to_owned(),
room_server_name.to_owned(),
],
)
.boxed()
.await
{
| Err(e) => {
// don't return this error so we don't fail registrations
error!(
"Failed to automatically join room {room} for user {user_id}: {e}"
);
},
| _ => {
info!("Automatically joined room {room} for user {user_id}");
},
}
}
}
info!("Created new user account for {user_id}");
}
pub async fn determine_registration_user_id(
&self,
supplied_username: Option<String>,
email: Option<&Address>,
appservice_info: Option<&RegistrationInfo>,
) -> Result<OwnedUserId> {
const RANDOM_USER_ID_LENGTH: usize = 10;
let emergency_mode_enabled = self.services.config.emergency_password.is_some();
let supplied_username = supplied_username.or_else(|| {
// If the user didn't supply a username but did supply an email, use
// the email's user part to avoid falling back to a random username
email.map(|address| address.user().to_owned())
});
if let Some(supplied_username) = supplied_username {
// The user gets to pick their username. Do some validation to make sure it's
// acceptable.
// Don't allow registration with forbidden usernames.
if self
.services
.globals
.forbidden_usernames()
.is_match(&supplied_username)
&& !emergency_mode_enabled
{
return Err!(Request(Forbidden("Username is forbidden")));
}
// Create and validate the user ID
let user_id = match UserId::parse_with_server_name(
&supplied_username,
self.services.globals.server_name(),
) {
| Ok(user_id) => {
if let Err(e) = user_id.validate_strict() {
// Unless we are in emergency mode, we should follow synapse's behaviour
// on not allowing things like spaces and UTF-8 characters in
// usernames
if !emergency_mode_enabled {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} contains disallowed characters or \
spaces: {e}"
))));
}
}
// Don't allow registration with user IDs that aren't local
if !self.services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {supplied_username} is not local to this server"
)));
}
user_id
},
| Err(e) => {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} is not valid: {e}"
))));
},
};
if self.exists(&user_id).await {
return Err!(Request(UserInUse("User ID is not available.")));
}
// Check that the user ID is/is not in an appservice's namespace
if let Some(appservice_info) = appservice_info {
if !appservice_info.is_user_match(&user_id) && !emergency_mode_enabled {
return Err!(Request(Exclusive(
"Username is not in this appservice's namespace."
)));
}
} else if self
.services
.appservice
.is_exclusive_user_id(&user_id)
.await && !emergency_mode_enabled
{
return Err!(Request(Exclusive("Username is reserved by an appservice.")));
}
Ok(user_id)
} else {
// The user didn't specify a username. Generate a username for
// them.
loop {
let user_id = UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
self.services.globals.server_name(),
)
.unwrap();
if !self.exists(&user_id).await {
break Ok(user_id);
}
}
}
}
/// Deactivate account
pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
@@ -351,9 +614,42 @@ pub async fn is_active_local(&self, user_id: &UserId) -> bool {
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
/// Find out which user an access token belongs to.
pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
assert!(!token.is_empty(), "Empty access token");
self.db.token_userdeviceid.get(token).await.deserialized()
pub async fn find_from_token(
&self,
token: &str,
) -> Option<(OwnedUserId, OwnedDeviceId, AccessTokenStatus)> {
let user = self
.db
.token_userdeviceid
.get(token)
.await
.deserialized()
.ok();
// Check if the token has expired
if let Some((user_id, device_id)) = user {
if let Some(expires) = self
.db
.userdeviceid_tokenexpires
.qry(&(&user_id, &device_id))
.await
.deserialized::<u64>()
.ok()
.map(Duration::from_secs)
{
let expires_at = SystemTime::UNIX_EPOCH
.checked_add(expires)
.expect("expiry time should not overflow SystemTime");
if SystemTime::now() > expires_at {
return Some((user_id, device_id, AccessTokenStatus::Expired));
}
}
Some((user_id, device_id, AccessTokenStatus::Valid))
} else {
None
}
}
/// Returns an iterator over all users on this homeserver.
@@ -449,6 +745,7 @@ pub async fn create_device(
user_id: &UserId,
device_id: &DeviceId,
token: &str,
token_max_age: Option<Duration>,
initial_device_display_name: Option<String>,
client_ip: Option<String>,
) -> Result<()> {
@@ -466,7 +763,8 @@ pub async fn create_device(
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.put(key, Json(device));
self.set_token(user_id, device_id, token).await
self.set_token(user_id, device_id, token, token_max_age)
.await
}
/// Removes a device from a user.
@@ -482,6 +780,7 @@ pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
self.db.userdeviceid_token.del(userdeviceid);
self.db.token_userdeviceid.remove(&old_token);
self.db.userdeviceid_tokenexpires.del(userdeviceid);
}
// Remove todevice events
@@ -495,6 +794,9 @@ pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
// TODO: Remove onetimekeys
// Remove OAuth session information
self.services.oauth.remove_session(user_id, device_id).await;
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.del(userdeviceid);
@@ -550,6 +852,7 @@ pub async fn set_token(
user_id: &UserId,
device_id: &DeviceId,
token: &str,
token_max_age: Option<Duration>,
) -> Result<()> {
let key = (user_id, device_id);
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
@@ -576,6 +879,7 @@ pub async fn set_token(
// Remove old token
if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
self.db.token_userdeviceid.remove(&old_token);
self.db.userdeviceid_tokenexpires.remove(&old_token);
// It will be removed from userdeviceid_token by the insert later
}
@@ -583,6 +887,18 @@ pub async fn set_token(
self.db.userdeviceid_token.put_raw(key, token);
self.db.token_userdeviceid.raw_put(token, key);
if let Some(max_age) = token_max_age {
let expires = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("system time should not be before the epoch")
.saturating_add(max_age)
.as_secs();
self.db.userdeviceid_tokenexpires.put(key, expires);
} else {
self.db.userdeviceid_tokenexpires.del(key);
}
Ok(())
}
@@ -1253,7 +1569,7 @@ pub async fn get_filter(
pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
use std::num::Saturating as Sat;
let expires_in = self.services.server.config.openid_token_ttl;
let expires_in = self.services.config.openid_token_ttl;
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
let mut value = expires_at.0.to_be_bytes().to_vec();
@@ -1297,7 +1613,7 @@ pub async fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> {
pub fn create_login_token(&self, user_id: &UserId, token: &str) -> u64 {
use std::num::Saturating as Sat;
let expires_in = self.services.server.config.login_token_ttl;
let expires_in = self.services.config.login_token_ttl;
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in);
let value = (expires_at.0, user_id);
+11
View File
@@ -22,6 +22,8 @@ crate-type = [
conduwuit-build-metadata.workspace = true
conduwuit-service.workspace = true
conduwuit-core.workspace = true
conduwuit-database.workspace = true
conduwuit-api.workspace = true
async-trait.workspace = true
askama.workspace = true
axum.workspace = true
@@ -35,9 +37,18 @@ ruma.workspace = true
thiserror.workspace = true
tower-http.workspace = true
serde.workspace = true
serde_json.workspace = true
lettre.workspace = true
memory-serve = "2.1.0"
validator = { version = "0.20.0", features = ["derive"] }
tower-sec-fetch = { version = "0.1.2", features = ["tracing"] }
tower-sessions = { version = "0.15.0", default-features = false, features = ["axum-core"] }
tower-sessions-core = { version = "0.15.0", features = ["deletion-task"] }
serde_urlencoded.workspace = true
url.workspace = true
recaptcha-verify = { version = "0.2.0", default-features = false }
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
form_urlencoded = "1.2.2"
[build-dependencies]
memory-serve = "2.1.0"
+48
View File
@@ -0,0 +1,48 @@
use axum::{
extract::{FromRequest, FromRequestParts, Request},
http::{Method, request::Parts},
};
use serde::de::DeserializeOwned;
use crate::WebError;
/// An extractor which deserializes a struct from a POST request's body.
/// For GET requests the struct will be None.
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub(crate) struct PostForm<T>(pub Option<T>);
impl<T, S> FromRequest<S> for PostForm<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = WebError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::POST {
let axum::Form(data) = axum::Form::from_request(req, state).await?;
Ok(Self(Some(data)))
} else {
Ok(Self(None))
}
}
}
/// An extractor which wraps another extractor and converts its errors into
/// `WebError`s.
pub(crate) struct Expect<E>(pub E);
impl<E, S, R> FromRequestParts<S> for Expect<E>
where
E: FromRequestParts<S, Rejection = R>,
WebError: From<R>,
S: Send + Sync,
{
type Rejection = WebError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Ok(Self(E::from_request_parts(parts, state).await?))
}
}
+61 -15
View File
@@ -1,25 +1,34 @@
use std::any::Any;
use std::{any::Any, sync::Once, time::Duration};
use askama::Template;
use axum::{
Router,
extract::rejection::{FormRejection, QueryRejection},
http::{HeaderValue, StatusCode, header},
response::{Html, IntoResponse, Response},
extract::rejection::{FormRejection, PathRejection, QueryRejection},
http::StatusCode,
middleware::from_fn_with_state,
response::{Html, IntoResponse, Redirect, Response},
};
use conduwuit_service::state;
use tower_http::{catch_panic::CatchPanicLayer, set_header::SetResponseHeaderLayer};
use conduwuit_service::{Services, state};
use tower_http::catch_panic::CatchPanicLayer;
use tower_sec_fetch::SecFetchLayer;
use tower_sessions::{ExpiredDeletion, SessionManagerLayer, cookie::SameSite};
use crate::pages::TemplateContext;
use crate::{
pages::TemplateContext,
session::{LoginQuery, store::RocksDbSessionStore},
};
mod extract;
mod pages;
mod session;
type State = state::State;
const CATASTROPHIC_FAILURE: &str = "cat-astrophic failure! we couldn't even render the error template. \
please contact the team @ https://continuwuity.org";
const ROUTE_PREFIX: &str = conduwuit_core::ROUTE_PREFIX;
#[derive(Debug, thiserror::Error)]
enum WebError {
#[error("Failed to validate form body: {0}")]
@@ -29,10 +38,16 @@ enum WebError {
#[error("{0}")]
FormRejection(#[from] FormRejection),
#[error("{0}")]
PathRejection(#[from] PathRejection),
#[error("{0}")]
BadRequest(String),
#[error("This page does not exist.")]
NotFound,
#[error("You are not allowed to request this page: {0}")]
Forbidden(String),
#[error("You must log in to access this page")]
LoginRequired(LoginQuery),
#[error("Failed to render template: {0}")]
Render(#[from] askama::Error),
@@ -52,12 +67,26 @@ struct Error {
context: TemplateContext,
}
if let Self::LoginRequired(query) = self {
return Redirect::to(&format!(
"{}/account/login?{}",
ROUTE_PREFIX,
serde_urlencoded::to_string(query).unwrap()
))
.into_response();
}
let status = match &self {
| Self::ValidationError(_)
| Self::BadRequest(_)
| Self::QueryRejection(_)
| Self::FormRejection(_) => StatusCode::BAD_REQUEST,
| Self::FormRejection(_)
| Self::InternalError(_) => StatusCode::BAD_REQUEST,
| Self::NotFound => StatusCode::NOT_FOUND,
| Self::Forbidden(_) => StatusCode::FORBIDDEN,
| Self::LoginRequired(_) => {
unreachable!("LoginRequired is handled earlier")
},
| _ => StatusCode::INTERNAL_SERVER_ERROR,
};
@@ -67,6 +96,7 @@ struct Error {
context: TemplateContext {
// Statically set false to prevent error pages from being indexed.
allow_indexing: false,
csp_nonce: String::new(),
},
};
@@ -78,21 +108,40 @@ struct Error {
}
}
pub fn build() -> Router<state::State> {
static STORE_CLEANUP_TASK: Once = Once::new();
pub fn build(services: &Services) -> Router<state::State> {
#[allow(clippy::wildcard_imports)]
use pages::*;
let store = RocksDbSessionStore::new(&services.db);
STORE_CLEANUP_TASK.call_once(|| {
services.server.runtime().spawn(
store
.clone()
.continuously_delete_expired(Duration::from_hours(1)),
);
});
Router::new()
.merge(index::build())
.nest(
"/_continuwuity/",
Router::new()
.merge(resources::build())
.merge(password_reset::build())
.nest("/about", about::build())
.nest("/account/", account::build())
.merge(debug::build())
.nest("/oauth2/", oauth::build())
.merge(resources::build())
.merge(threepid::build())
.fallback(async || WebError::NotFound),
)
.layer(
SessionManagerLayer::new(store)
.with_name("_c10y_session")
.with_same_site(SameSite::Lax),
)
.layer(CatchPanicLayer::custom(|panic: Box<dyn Any + Send + 'static>| {
let details = if let Some(s) = panic.downcast_ref::<String>() {
s.clone()
@@ -104,10 +153,7 @@ pub fn build() -> Router<state::State> {
WebError::Panic(details).into_response()
}))
.layer(SetResponseHeaderLayer::if_not_present(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'self'; img-src 'self' data:;"),
))
.layer(from_fn_with_state(services.config.clone(), template_context_middleware))
.layer(SecFetchLayer::new(|policy| {
policy.allow_safe_methods().reject_missing_metadata();
}))
+38
View File
@@ -0,0 +1,38 @@
use std::collections::BTreeMap;
use axum::{Extension, Router, extract::State, routing::get};
use conduwuit_core::config::TermsDocument;
use ruma::{
OwnedServerName,
api::client::discovery::discover_support::{Contact, ContactRole},
};
use url::Url;
use crate::{
pages::{Result, TemplateContext},
response, template,
};
pub(crate) fn build() -> Router<crate::State> { Router::new().route("/", get(get_about)) }
template! {
struct About use "about.html.j2" {
server_name: OwnedServerName,
support_page: Option<Url>,
contacts: Vec<Contact>,
terms: BTreeMap<String, TermsDocument>
}
}
async fn get_about(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
) -> Result {
response!(About::new(
context,
services.globals.server_name().to_owned(),
services.config.well_known.support_page.clone(),
services.admin.get_support_contacts().await,
services.config.registration_terms.documents.clone()
))
}
@@ -0,0 +1,47 @@
use axum::{Extension, Router, extract::State, routing::on};
use conduwuit_service::oauth::OAuthTicket;
use crate::{
extract::PostForm,
pages::{GET_POST, Result, TemplateContext, components::UserCard},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/", on(GET_POST, route_cross_signing_reset))
}
template! {
struct CrossSigningReset use "cross_signing_reset.html.j2" {
user_card: UserCard,
body: CrossSigningResetBody
}
}
#[derive(Debug)]
enum CrossSigningResetBody {
Form,
Success,
}
async fn route_cross_signing_reset(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::CrossSigningReset)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if form.is_some() {
services
.oauth
.issue_ticket(user_id.localpart().to_owned(), OAuthTicket::CrossSigningReset);
response!(CrossSigningReset::new(context, user_card, CrossSigningResetBody::Success))
} else {
response!(CrossSigningReset::new(context, user_card, CrossSigningResetBody::Form))
}
}
+129
View File
@@ -0,0 +1,129 @@
use axum::{Extension, Router, extract::State, routing::on};
use conduwuit_api::client::full_user_deactivate;
use futures::StreamExt;
use ruma::{OwnedRoomId, OwnedUserId, UserId};
use tower_sessions::Session;
use validator::{Validate, ValidationError, ValidationErrors};
use crate::{
extract::PostForm,
form,
pages::{
GET_POST, Result, TemplateContext,
components::{UserCard, form::Form},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/", on(GET_POST, route_deactivate))
}
template! {
struct Deactivate use "deactivate.html.j2" {
body: DeactivateBody
}
}
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum DeactivateBody {
Unavailable,
Form {
user_id: OwnedUserId,
user_card: UserCard,
form: Form<'static>,
},
Success,
}
form! {
struct DeactivateForm {
password: String where {
input_type: "password",
label: "Enter your password to confirm",
autocomplete: "current-password"
},
#[validate(required(message = "This checkbox must be checked"))]
confirm: Option<String> where {
input_type: "checkbox",
label: "I understand that deactivating my account cannot be undone."
}
submit: "Deactivate my account",
slowdown: true
}
}
async fn route_deactivate(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
session: Session,
PostForm(form): PostForm<DeactivateForm>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::Deactivate)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
let body = {
if !services.config.allow_deactivation {
DeactivateBody::Unavailable
} else if let Some(form) = form {
if let Err(err) = validate_deactivate_form(&services, &user_id, form).await {
DeactivateBody::Form {
user_id,
user_card,
form: DeactivateForm::with_errors(context.clone(), err),
}
} else {
let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms
.state_cache
.rooms_joined(&user_id)
.collect()
.await;
full_user_deactivate(&services, &user_id, &all_joined_rooms).await?;
session.clear().await;
DeactivateBody::Success
}
} else {
DeactivateBody::Form {
user_id,
user_card,
form: DeactivateForm::build(context.clone()),
}
}
};
response!(Deactivate::new(context, body))
}
async fn validate_deactivate_form(
services: &crate::State,
user_id: &UserId,
form: DeactivateForm,
) -> Result<(), ValidationErrors> {
form.validate()?;
if services
.users
.check_password(user_id, &form.password)
.await
.is_err()
{
let mut errors = ValidationErrors::new();
errors.add(
"password",
ValidationError::new("wrong").with_message("Incorrect password".into()),
);
return Err(errors);
}
Ok(())
}
+126
View File
@@ -0,0 +1,126 @@
use axum::{
Extension, Router,
extract::{Path, State},
routing::{get, on},
};
use conduwuit_service::oauth::{SessionInfo, client_metadata::ClientMetadata};
use futures::StreamExt;
use ruma::OwnedDeviceId;
use serde::{Deserialize, Serialize};
use crate::{
WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result, TemplateContext,
components::{ClientScopes, DeviceCard, DeviceCardStyle},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/{device}/", get(get_device_info))
.route("/{device}/remove", on(GET_POST, route_remove_device))
}
template! {
struct DeviceInfo use "device_info.html.j2" {
device_card: DeviceCard,
client_metadata: Option<(ClientMetadata, SessionInfo)>
}
}
async fn get_device_info(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
Expect(Path(query)): Expect<Path<DevicePath>>,
) -> Result {
let user_id = user.expect(LoginTarget::RemoveDevice(query.clone()))?;
let Ok(device) = services
.users
.get_device_metadata(&user_id, &query.device)
.await
else {
return response!(WebError::BadRequest("Unknown device".to_owned()));
};
let client_metadata = async {
let session_info = services
.oauth
.get_session_info_for_device(&user_id, &device.device_id)
.await?;
let client_metadata = services
.oauth
.get_client_metadata(&session_info.client_id)
.await?;
Some((client_metadata, session_info))
}
.await;
let device_card =
DeviceCard::for_device(&services, &user_id, device, DeviceCardStyle::Detailed).await;
response!(DeviceInfo::new(context, device_card, client_metadata))
}
template! {
struct RemoveDevice use "remove_device.html.j2" {
body: RemoveDeviceBody
}
}
#[derive(Debug)]
enum RemoveDeviceBody {
Form {
device_card: Box<DeviceCard>,
last_device: bool,
},
Success,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct DevicePath {
pub device: OwnedDeviceId,
}
async fn route_remove_device(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
Expect(Path(query)): Expect<Path<DevicePath>>,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::RemoveDevice(query.clone()))?;
let Ok(device) = services
.users
.get_device_metadata(&user_id, &query.device)
.await
else {
return response!(WebError::BadRequest("Unknown device".to_owned()));
};
if form.is_some() {
services
.users
.remove_device(&user_id, &device.device_id)
.await;
response!(RemoveDevice::new(context, RemoveDeviceBody::Success))
} else {
let device_card =
DeviceCard::for_device(&services, &user_id, device, DeviceCardStyle::Minimal).await;
let last_device = services.users.all_devices_metadata(&user_id).count().await <= 1;
response!(RemoveDevice::new(context, RemoveDeviceBody::Form {
device_card: Box::new(device_card),
last_device
}))
}
}
+210
View File
@@ -0,0 +1,210 @@
use axum::{
Extension, Router,
extract::{Query, State},
routing::{get, on, post},
};
use conduwuit_core::warn;
use conduwuit_service::{mailer::messages, threepid::session::ValidationSessions};
use lettre::{Address, message::Mailbox};
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId};
use serde::{Deserialize, Serialize};
use crate::{
WebError,
extract::{Expect, PostForm},
form,
pages::{
GET_POST, Result, TemplateContext,
account::ThreepidQuery,
components::{UserCard, form::Form},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/change/", on(GET_POST, route_change_email))
.route("/change/validate", get(get_change_email_validate))
.route("/change/delete", post(post_delete_email))
}
template! {
struct ChangeEmail use "change_email.html.j2" {
user_card: UserCard,
email: Option<String>,
form: Form<'static>,
may_remove: bool
}
}
form! {
struct ChangeEmailForm {
email: Address where {
input_type: "email",
label: "Email address"
}
submit: "Change email"
}
}
template! {
struct ChangeEmailValidate use "change_email_validate.html.j2" {
user_card: UserCard,
body: ChangeEmailValidateBody
}
}
template! {
struct DeleteEmail use "delete_email.html.j2" {
user_card: UserCard
}
}
#[derive(Debug)]
enum ChangeEmailValidateBody {
ValidationPending {
session_id: OwnedSessionId,
client_secret: OwnedClientSecret,
validation_error: bool,
},
Success,
}
async fn route_change_email(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
PostForm(form): PostForm<ChangeEmailForm>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::ChangeEmail)?;
let Some(form) = form else {
return response!(ChangeEmail::new(
context.clone(),
UserCard::for_local_user(&services, user_id.clone()).await,
services
.threepid
.get_email_for_localpart(user_id.localpart())
.await
.map(|address| address.to_string()),
ChangeEmailForm::build(context),
services.threepid.email_requirement().may_remove(),
));
};
let client_secret = ClientSecret::new();
let session_id = {
let display_name = services.users.displayname(&user_id).await.ok();
match services
.threepid
.send_validation_email(
Mailbox::new(display_name, form.email.clone()),
|verification_link| messages::ChangeEmail {
server_name: services.globals.server_name().as_str(),
user_id: Some(&user_id),
verification_link,
},
&client_secret,
0,
)
.await
{
| Ok(session_id) => session_id,
| Err(err) => {
// If we couldn't send an email, generate a random session ID to not give that
// away
warn!(
"Failed to send email change message for {user_id} to {}: {err}",
form.email
);
ValidationSessions::generate_session_id()
},
}
};
response!(ChangeEmailValidate::new(
context,
UserCard::for_local_user(&services, user_id).await,
ChangeEmailValidateBody::ValidationPending {
session_id,
client_secret,
validation_error: false
}
))
}
#[derive(Deserialize, Serialize)]
struct ChangeEmailQuery {
#[serde(flatten)]
threepid: ThreepidQuery,
}
async fn get_change_email_validate(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
Expect(Query(ChangeEmailQuery {
threepid: ThreepidQuery { client_secret, session_id },
})): Expect<Query<ChangeEmailQuery>>,
user: User,
) -> Result {
let user_id = user.expect(LoginTarget::ChangeEmail)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if !services.threepid.email_requirement().may_change() {
return Err(WebError::Forbidden("You may not change your email address.".to_owned()));
}
let Ok(session) = services
.threepid
.get_valid_session(&session_id, &client_secret)
.await
else {
return response!(ChangeEmailValidate::new(
context,
user_card,
ChangeEmailValidateBody::ValidationPending {
session_id,
client_secret,
validation_error: true
}
));
};
let new_email = session.consume();
if let Err(err) = services
.threepid
.associate_localpart_email(user_id.localpart(), &new_email)
.await
{
return response!(BadRequest(err.message()));
}
response!(ChangeEmailValidate::new(context, user_card, ChangeEmailValidateBody::Success))
}
async fn post_delete_email(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
) -> Result {
let user_id = user.expect(LoginTarget::ChangeEmail)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if !services.threepid.email_requirement().may_remove() {
return Err(WebError::Forbidden("You may not remove your email address.".to_owned()));
}
let _ = services
.threepid
.disassociate_localpart_email(user_id.localpart())
.await;
response!(DeleteEmail::new(context, user_card))
}
+155
View File
@@ -0,0 +1,155 @@
use std::time::SystemTime;
use axum::{
Extension, Router,
extract::{Query, RawQuery, State},
response::{IntoResponse, Redirect},
routing::{get, on},
};
use conduwuit_api::client::handle_login;
use ruma::{
OwnedUserId,
api::client::uiaa::{EmailUserIdentifier, MatrixUserIdentifier, UserIdentifier},
};
use serde::Deserialize;
use tower_sessions::Session;
use crate::{
ROUTE_PREFIX, WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result, TemplateContext,
account::register::{TrustedFlowStatus, UntrustedFlowStatus, registration_flow_status},
components::UserCard,
},
response,
session::{LoginQuery, LoginTarget, User, UserSession},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/login", on(GET_POST, route_login))
.route("/logout", get(get_logout))
}
template! {
struct Login use "login.html.j2" {
body: LoginBody,
login_error: Option<String>
}
}
#[derive(Debug)]
enum LoginBody {
Unauthenticated {
server_name: String,
registration_available: bool,
next: Option<LoginTarget>,
},
Authenticated {
user_card: UserCard,
},
}
#[derive(Deserialize)]
struct LoginForm {
identifier: Option<String>,
password: String,
}
async fn route_login(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
Expect(Query(LoginQuery { next, reauthenticate })): Expect<Query<LoginQuery>>,
session_store: Session,
user: User<true>,
PostForm(form): PostForm<LoginForm>,
) -> Result {
let user_id = user.into_session().map(|session| session.user_id);
let body = match &user_id {
| None => {
let (trusted_flow_status, untrusted_flow_status) =
registration_flow_status(&services).await;
let registration_available =
matches!(trusted_flow_status, TrustedFlowStatus::Available)
|| matches!(untrusted_flow_status, UntrustedFlowStatus::Available { .. });
LoginBody::Unauthenticated {
server_name: services.globals.server_name().to_string(),
registration_available,
next: next.clone(),
}
},
| Some(user_id) => {
if !reauthenticate {
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
}
let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await;
LoginBody::Authenticated { user_card }
},
};
let mut template = Login::new(context, body, None);
if let Some(form) = form {
let login_result = match (user_id, form.identifier) {
| (Some(user_id), _) => {
// The user is already authenticated, we need to check their password
services
.users
.check_password(&user_id, &form.password)
.await
},
| (None, Some(identifier)) => {
// The user isn't authenticated, we need to log them in
let identifier = if identifier.parse::<lettre::Address>().is_ok() {
UserIdentifier::Email(EmailUserIdentifier::new(identifier))
} else {
UserIdentifier::Matrix(MatrixUserIdentifier::new(identifier))
};
handle_login(&services, Some(&identifier), &form.password, None).await
},
| (None, None) => {
// The user isn't authenticated and didn't supply an identity
return response!(WebError::BadRequest("No identity provided".to_owned()));
},
};
let user_id = match login_result {
| Ok(user_id) => user_id,
| Err(err) => {
let error_message = if let conduwuit_core::Error::Request(_, message, _) = err {
message.into_owned()
} else {
"Internal login error".to_owned()
};
template.login_error = Some(error_message);
return response!(template);
},
};
let user_session = UserSession { user_id, last_login: SystemTime::now() };
session_store
.insert(User::KEY, user_session)
.await
.expect("should be able to serialize user session");
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
}
response!(template)
}
async fn get_logout(session: Session, RawQuery(query): RawQuery) -> impl IntoResponse {
let _ = session.remove::<OwnedUserId>(User::KEY).await;
Redirect::to(&format!("{}/account/login?{}", ROUTE_PREFIX, query.unwrap_or_default()))
}
+173
View File
@@ -0,0 +1,173 @@
use axum::{
Extension, Router,
extract::{Query, State},
response::Redirect,
routing::get,
};
use conduwuit_core::utils::{IterStream, ReadyExt, stream::TryExpect};
use conduwuit_service::threepid::EmailRequirement;
use futures::StreamExt;
use ruma::{
OwnedClientSecret, OwnedDeviceId, OwnedSessionId,
api::client::discovery::get_authorization_server_metadata::v1::AccountManagementAction,
};
use serde::{Deserialize, Serialize};
use crate::{
WebError,
extract::Expect,
pages::{
Result, TemplateContext,
components::{DeviceCard, DeviceCardStyle, UserCard},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) mod cross_signing_reset;
pub(crate) mod deactivate;
pub(crate) mod device;
pub(crate) mod email;
pub(crate) mod login;
pub(crate) mod password;
pub(crate) mod register;
pub(crate) fn build() -> Router<crate::State> {
#[allow(clippy::wildcard_imports)]
use self::*;
Router::new()
.route("/", get(get_account))
.route("/deeplink", get(get_account_deeplink))
.merge(login::build())
.nest("/password/", password::build())
.nest("/email/", email::build())
.nest("/cross_signing_reset", cross_signing_reset::build())
.nest("/deactivate", deactivate::build())
.nest("/device/", device::build())
.nest("/register/", register::build())
}
#[derive(Deserialize, Serialize)]
struct ThreepidQuery {
client_secret: OwnedClientSecret,
session_id: OwnedSessionId,
}
template! {
struct Account use "account.html.j2" {
user_card: UserCard,
body: AccountBody
}
}
#[derive(Debug)]
enum AccountBody {
Unlocked {
suspended: bool,
email_requirement: EmailRequirement,
email: Option<String>,
devices: Vec<DeviceCard>,
},
Locked,
}
async fn get_account(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User<true>,
) -> Result {
let user_id = user.expect(LoginTarget::Account)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if services.users.is_locked(&user_id).await.unwrap() {
return response!(Account::new(context, user_card, AccountBody::Locked));
}
let email_requirement = services.threepid.email_requirement();
let email = services
.threepid
.get_email_for_localpart(user_id.localpart())
.await
.map(|address| address.to_string());
let dehydrated_device_id = services.users.get_dehydrated_device_id(&user_id).await.ok();
let mut devices: Vec<_> = services
.users
.all_device_ids(&user_id)
.then(async |device_id| {
services
.users
.get_device_metadata(&user_id, &device_id)
.await
})
.expect_ok()
.ready_filter(|device| {
dehydrated_device_id
.as_ref()
.is_none_or(|id| device.device_id != *id)
})
.collect()
.await;
devices.sort_unstable_by(|a, b| a.last_seen_ts.cmp(&b.last_seen_ts).reverse());
let device_cards = devices
.into_iter()
.stream()
.then(async |device| {
DeviceCard::for_device(&services, &user_id, device, DeviceCardStyle::Minimal).await
})
.collect()
.await;
let suspended = services.users.is_suspended(&user_id).await.unwrap();
response!(Account::new(context, user_card, AccountBody::Unlocked {
suspended,
email_requirement,
email,
devices: device_cards
}))
}
#[derive(Deserialize)]
struct AccountDeeplinkQuery {
action: Option<AccountManagementAction>,
device_id: Option<OwnedDeviceId>,
}
async fn get_account_deeplink(
Expect(Query(query)): Expect<Query<AccountDeeplinkQuery>>,
) -> Result {
let redirect_target = match query.action.unwrap_or(AccountManagementAction::Profile) {
| AccountManagementAction::AccountDeactivate => "deactivate".to_owned(),
| AccountManagementAction::CrossSigningReset => "cross_signing_reset".to_owned(),
| AccountManagementAction::DeviceDelete => {
let Some(device_id) = query.device_id else {
return response!(WebError::BadRequest(
"A device ID is required for this action".to_owned()
));
};
format!("device/{device_id}/delete")
},
| AccountManagementAction::DeviceView => {
let Some(device_id) = query.device_id else {
return response!(WebError::BadRequest(
"A device ID is required for this action".to_owned()
));
};
format!("device/{device_id}/")
},
| AccountManagementAction::DevicesList => "#devices".to_owned(),
| AccountManagementAction::Profile => String::new(),
| _ => return response!(WebError::BadRequest("Unknown action".to_owned())),
};
response!(Redirect::to(&format!("{}/account/{}", crate::ROUTE_PREFIX, redirect_target)))
}
+122
View File
@@ -0,0 +1,122 @@
use axum::{Extension, Router, extract::State, routing::on};
use conduwuit_service::users::HashedPassword;
use ruma::UserId;
use validator::{Validate, ValidationError, ValidationErrors};
use crate::{
extract::PostForm,
form,
pages::{
GET_POST, Result, TemplateContext,
components::{UserCard, form::Form},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/", on(GET_POST, route_change_password))
}
template! {
struct ChangePassword use "change_password.html.j2" {
user_card: UserCard,
body: ChangePasswordBody
}
}
#[derive(Debug)]
enum ChangePasswordBody {
Form(Form<'static>),
Success,
}
form! {
struct ChangePasswordForm {
#[validate(length(min = 1, message = "Current password cannot be empty"))]
current_password: String where {
input_type: "password",
label: "Current password",
autocomplete: "current-password"
},
#[validate(length(min = 1, message = "New password cannot be empty"))]
new_password: String where {
input_type: "password",
label: "New password",
autocomplete: "new-password"
},
#[validate(must_match(other = "new_password", message = "Passwords must match"))]
confirm_new_password: String where {
input_type: "password",
label: "Confirm new password",
autocomplete: "new-password"
}
submit: "Change password"
}
}
async fn route_change_password(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
PostForm(form): PostForm<ChangePasswordForm>,
) -> Result {
let user_id = user.expect(LoginTarget::ChangePassword)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
let body = if let Some(form) = form {
match change_password(&services, &user_id, form).await {
| Ok(()) => ChangePasswordBody::Success,
| Err(errors) =>
ChangePasswordBody::Form(ChangePasswordForm::with_errors(context.clone(), errors)),
}
} else {
ChangePasswordBody::Form(ChangePasswordForm::build(context.clone()))
};
response!(ChangePassword::new(context, user_card, body))
}
async fn change_password(
services: &crate::State,
user_id: &UserId,
form: ChangePasswordForm,
) -> Result<(), ValidationErrors> {
form.validate()?;
if services
.users
.check_password(user_id, &form.current_password)
.await
.is_err()
{
let mut errors = ValidationErrors::new();
errors.add(
"current_password",
ValidationError::new("wrong").with_message("Incorrect password".into()),
);
return Err(errors);
}
match HashedPassword::new(&form.new_password) {
| Ok(hash) => {
services.users.set_password(user_id, Some(hash));
},
| Err(err) => {
let mut errors = ValidationErrors::new();
errors.add(
"new_password",
ValidationError::new("malformed").with_message(err.message().into()),
);
return Err(errors);
},
}
Ok(())
}

Some files were not shown because too many files have changed in this diff Show More