Compare commits

..

5 Commits

Author SHA1 Message Date
timedout
dfd89425a1 fix: Update some of the trait bounds to be not broken 2026-02-19 03:38:12 +00:00
timedout
c69e7c7d1b feat: Do more refactoring 2026-02-19 02:40:42 +00:00
timedout
bd404e808c feat: Add invite membership check 2026-02-19 02:40:42 +00:00
timedout
0899985476 feat: Start on membership auth 2026-02-19 02:40:42 +00:00
timedout
b3cf649732 wip: Refactor event auth 2026-02-19 02:40:41 +00:00
87 changed files with 2687 additions and 4637 deletions

1054
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -12,7 +12,7 @@ license = "Apache-2.0"
# See also `rust-toolchain.toml`
readme = "README.md"
repository = "https://forgejo.ellis.link/continuwuation/continuwuity"
version = "0.5.6-alpha"
version = "0.5.5"
[workspace.metadata.crane]
name = "conduwuit"
@@ -68,7 +68,7 @@ default-features = false
version = "0.1.3"
[workspace.dependencies.rand]
version = "0.10.0"
version = "0.8.5"
# Used for the http request / response body type for Ruma endpoints used with reqwest
[workspace.dependencies.bytes]
@@ -97,7 +97,7 @@ features = [
]
[workspace.dependencies.axum-extra]
version = "0.12.0"
version = "0.10.1"
default-features = false
features = ["typed-header", "tracing"]
@@ -144,7 +144,6 @@ features = [
"socks",
"hickory-dns",
"http2",
"stream",
]
[workspace.dependencies.serde]
@@ -254,7 +253,7 @@ features = [
version = "0.4.0"
[workspace.dependencies.libloading]
version = "0.9.0"
version = "0.8.6"
# Validating urls in config, was already a transitive dependency
[workspace.dependencies.url]
@@ -299,7 +298,7 @@ default-features = false
features = ["env", "toml"]
[workspace.dependencies.hickory-resolver]
version = "0.25.2"
version = "0.25.1"
default-features = false
features = [
"serde",
@@ -344,7 +343,7 @@ version = "0.1.2"
[workspace.dependencies.ruma]
git = "https://forgejo.ellis.link/continuwuation/ruwuma"
#branch = "conduwuit-changes"
rev = "bb12ed288a31a23aa11b10ba0fad22b7f985eb88"
rev = "3126cb5eea991ec40590e54d8c9d75637650641a"
features = [
"compat",
"rand",
@@ -364,7 +363,6 @@ features = [
"unstable-msc2870",
"unstable-msc3026",
"unstable-msc3061",
"unstable-msc3814",
"unstable-msc3245",
"unstable-msc3266",
"unstable-msc3381", # polls
@@ -383,7 +381,6 @@ features = [
"unstable-pdu",
"unstable-msc4155",
"unstable-msc4143", # livekit well_known response
"unstable-msc4284"
]
[workspace.dependencies.rust-rocksdb]
@@ -428,7 +425,7 @@ features = ["http", "grpc-tonic", "trace", "logs", "metrics"]
# optional sentry metrics for crash/panic reporting
[workspace.dependencies.sentry]
version = "0.46.0"
version = "0.45.0"
default-features = false
features = [
"backtrace",
@@ -444,9 +441,9 @@ features = [
]
[workspace.dependencies.sentry-tracing]
version = "0.46.0"
version = "0.45.0"
[workspace.dependencies.sentry-tower]
version = "0.46.0"
version = "0.45.0"
# jemalloc usage
[workspace.dependencies.tikv-jemalloc-sys]
@@ -475,7 +472,7 @@ features = ["use_std"]
version = "0.5"
[workspace.dependencies.nix]
version = "0.31.0"
version = "0.30.1"
default-features = false
features = ["resource"]

View File

@@ -57,15 +57,10 @@ ### What are the project's goals?
### Can I try it out?
Check out the [documentation](https://continuwuity.org) for installation instructions.
Check out the [documentation](https://continuwuity.org) for installation instructions, or join one of these vetted public homeservers running Continuwuity to get a feel for things!
If you want to try it out as a user, we have some partnered homeservers you can use:
* You can head over to [https://federated.nexus](https://federated.nexus/) in your browser.
* Hit the `Apply to Join` button. Once your request has been accepted, you will receive an email with your username and password.
* Head over to [https://app.federated.nexus](https://app.federated.nexus/) and you can sign in there, or use any other matrix chat client you wish elsewhere.
* Your username for matrix will be in the form of `@username:federated.nexus`, however you can simply use the `username` part to log in. Your password is your password.
* There's also [https://continuwuity.rocks/](https://continuwuity.rocks/). You can register a new account using Cinny via [this convenient link](https://app.cinny.in/register/continuwuity.rocks), or you can use Element or another matrix client *that supports registration*.
- https://continuwuity.rocks -- A public demo server operated by the Continuwuity Team.
- https://federated.nexus -- Federated Nexus is a community resource hosting multiple FOSS (especially federated) services, including Matrix and Forgejo.
### What are we working on?

View File

@@ -1 +0,0 @@
Fixed a startup crash in the sender service if we can't detect the number of CPU cores, even if the `sender_workers' config option is set correctly. Contributed by @katie.

View File

@@ -1 +0,0 @@
Improved the concurrency handling of federation transactions, vastly improving performance and reliability by more accurately handling inbound transactions and reducing the amount of repeated wasted work. Contributed by @nex and @Jade.

View File

@@ -1 +0,0 @@
Added MSC3202 Device masquerading (not all of MSC3202). This should fix issues with enabling MSC4190 for some Mautrix bridges. Contributed by @Jade

View File

@@ -1 +0,0 @@
Added MSC3814 Dehydrated Devices - you can now decrypt messages sent while all devices were logged out.

View File

@@ -1 +0,0 @@
Removed the `allow_public_room_directory_without_auth` config option. Contributed by @0xnim.

View File

@@ -1 +0,0 @@
Implement MSC4143 MatrixRTC transport discovery endpoint. Move RTC foci configuration from `[global.well_known]` to a new `[global.matrix_rtc]` section with a `foci` field. Contributed by @0xnim

View File

@@ -1 +0,0 @@
Fixed sliding sync v5 list ranges always starting from 0, causing extra rooms to be unnecessarily processed and returned. Contributed by @0xnim

View File

@@ -1 +0,0 @@
Improved URL preview fetching with a more compatible user agent for sites like YouTube Music. Added `!admin media delete-url-preview <url>` command to clear cached URL previews that were stuck and broken.

View File

@@ -15,18 +15,6 @@ disallowed-macros = [
{ path = "log::trace", reason = "use conduwuit_core::trace" },
]
[[disallowed-methods]]
path = "tokio::spawn"
reason = "use and pass conduwuit_core::server::Server::runtime() to spawn from"
[[disallowed-methods]]
path = "reqwest::Response::bytes"
reason = "bytes is unsafe, use limit_read via the conduwuit_core::utils::LimitReadExt trait instead"
[[disallowed-methods]]
path = "reqwest::Response::text"
reason = "text is unsafe, use limit_read_text via the conduwuit_core::utils::LimitReadExt trait instead"
[[disallowed-methods]]
path = "reqwest::Response::json"
reason = "json is unsafe, use limit_read_text via the conduwuit_core::utils::LimitReadExt trait instead"
disallowed-methods = [
{ path = "tokio::spawn", reason = "use and pass conduuwit_core::server::Server::runtime() to spawn from" },
]

View File

@@ -9,6 +9,7 @@ address = "0.0.0.0"
allow_device_name_federation = true
allow_guest_registration = true
allow_public_room_directory_over_federation = true
allow_public_room_directory_without_auth = true
allow_registration = true
database_path = "/database"
log = "trace,h2=debug,hyper=debug"

View File

@@ -290,25 +290,6 @@
#
#max_fetch_prev_events = 192
# How many incoming federation transactions the server is willing to be
# processing at any given time before it becomes overloaded and starts
# rejecting further transactions until some slots become available.
#
# Setting this value too low or too high may result in unstable
# federation, and setting it too high may cause runaway resource usage.
#
#max_concurrent_inbound_transactions = 150
# Maximum age (in seconds) for cached federation transaction responses.
# Entries older than this will be removed during cleanup.
#
#transaction_id_cache_max_age_secs = 7200 (2 hours)
# Maximum number of cached federation transaction responses.
# When the cache exceeds this limit, older entries will be removed.
#
#transaction_id_cache_max_entries = 8192
# Default/base connection timeout (seconds). This is used only by URL
# previews and update/news endpoint checks.
#
@@ -546,6 +527,12 @@
#
#allow_public_room_directory_over_federation = false
# Set this to true to allow your server's public room directory to be
# queried without client authentication (access token) through the Client
# APIs. Set this to false to protect against /publicRooms spiders.
#
#allow_public_room_directory_without_auth = false
# Allow guests/unauthenticated users to access TURN credentials.
#
# This is the equivalent of Synapse's `turn_allow_guests` config option.
@@ -1338,7 +1325,7 @@
# sender user's server name, inbound federation X-Matrix origin, and
# outbound federation handler.
#
# You can set this to [".*"] to block all servers by default, and then
# You can set this to ["*"] to block all servers by default, and then
# use `allowed_remote_server_names` to allow only specific servers.
#
# example: ["badserver\\.tld$", "badphrase", "19dollarfortnitecards"]
@@ -1844,13 +1831,14 @@
#
#support_mxid =
# **DEPRECATED**: Use `[global.matrix_rtc].foci` instead.
#
# A list of MatrixRTC foci URLs which will be served as part of the
# MSC4143 client endpoint at /.well-known/matrix/client.
# MSC4143 client endpoint at /.well-known/matrix/client. If you're
# setting up livekit, you'd want something like:
# rtc_focus_server_urls = [
# { type = "livekit", livekit_service_url = "https://livekit.example.com" },
# ]
#
# This option is deprecated and will be removed in a future release.
# Please migrate to the new `[global.matrix_rtc]` config section.
# To disable, set this to be an empty vector (`[]`).
#
#rtc_focus_server_urls = []
@@ -1872,23 +1860,6 @@
#
#blurhash_max_raw_size = 33554432
[global.matrix_rtc]
# A list of MatrixRTC foci (transports) which will be served via the
# MSC4143 RTC transports endpoint at
# `/_matrix/client/v1/rtc/transports`. If you're setting up livekit,
# you'd want something like:
# ```toml
# [global.matrix_rtc]
# foci = [
# { type = "livekit", livekit_service_url = "https://livekit.example.com" },
# ]
# ```
#
# To disable, set this to an empty list (`[]`).
#
#foci = []
[global.ldap]
# Whether to enable LDAP login.

View File

@@ -52,7 +52,7 @@ ENV BINSTALL_VERSION=1.17.5
# renovate: datasource=github-releases depName=psastras/sbom-rs
ENV CARGO_SBOM_VERSION=0.9.1
# renovate: datasource=crate depName=lddtree
ENV LDDTREE_VERSION=0.5.0
ENV LDDTREE_VERSION=0.4.0
# renovate: datasource=crate depName=timelord-cli
ENV TIMELORD_VERSION=3.0.1
@@ -180,11 +180,6 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \
export RUSTFLAGS="${RUSTFLAGS}"
fi
RUST_PROFILE_DIR="${RUST_PROFILE}"
if [[ "${RUST_PROFILE}" == "dev" ]]; then
RUST_PROFILE_DIR="debug"
fi
TARGET_DIR=($(cargo metadata --no-deps --format-version 1 | \
jq -r ".target_directory"))
mkdir /out/sbin
@@ -196,8 +191,8 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \
jq -r ".packages[] | select(.name == \"$PACKAGE\") | .targets[] | select( .kind | map(. == \"bin\") | any ) | .name"))
for BINARY in "${BINARIES[@]}"; do
echo $BINARY
xx-verify $TARGET_DIR/$(xx-cargo --print-target-triple)/${RUST_PROFILE_DIR}/$BINARY
cp $TARGET_DIR/$(xx-cargo --print-target-triple)/${RUST_PROFILE_DIR}/$BINARY /out/sbin/$BINARY
xx-verify $TARGET_DIR/$(xx-cargo --print-target-triple)/${RUST_PROFILE}/$BINARY
cp $TARGET_DIR/$(xx-cargo --print-target-triple)/${RUST_PROFILE}/$BINARY /out/sbin/$BINARY
done
EOF

View File

@@ -22,7 +22,7 @@ ENV BINSTALL_VERSION=1.17.5
# renovate: datasource=github-releases depName=psastras/sbom-rs
ENV CARGO_SBOM_VERSION=0.9.1
# renovate: datasource=crate depName=lddtree
ENV LDDTREE_VERSION=0.5.0
ENV LDDTREE_VERSION=0.4.0
# Install unpackaged tools
RUN <<EOF

View File

@@ -78,19 +78,47 @@ #### Firewall hints
### 3. Telling clients where to find LiveKit
To tell clients where to find LiveKit, you need to add the address of your `lk-jwt-service` to the `[global.matrix_rtc]` config section using the `foci` option.
To tell clients where to find LiveKit, you need to add the address of your `lk-jwt-service` to your client .well-known file. To do so, in the config section `global.well-known`, add (or modify) the option `rtc_focus_server_urls`.
The variable should be a list of servers serving as MatrixRTC endpoints. Clients discover these via the `/_matrix/client/v1/rtc/transports` endpoint (MSC4143).
The variable should be a list of servers serving as MatrixRTC endpoints to serve in the well-known file to the client.
```toml
[global.matrix_rtc]
foci = [
rtc_focus_server_urls = [
{ type = "livekit", livekit_service_url = "https://livekit.example.com" },
]
```
Remember to replace the URL with the address you are deploying your instance of lk-jwt-service to.
#### Serving .well-known manually
If you don't let Continuwuity serve your `.well-known` files, you need to add the following lines to your `.well-known/matrix/client` file, remembering to replace the URL with your own `lk-jwt-service` deployment:
```json
"org.matrix.msc4143.rtc_foci": [
{
"type": "livekit",
"livekit_service_url": "https://livekit.example.com"
}
]
```
The final file should look something like this:
```json
{
"m.homeserver": {
"base_url":"https://matrix.example.com"
},
"org.matrix.msc4143.rtc_foci": [
{
"type": "livekit",
"livekit_service_url": "https://livekit.example.com"
}
]
}
```
### 4. Configure your Reverse Proxy
Reverse proxies can be configured in many different ways - so we can't provide a step by step for this.

View File

@@ -51,13 +51,7 @@ ## Can I try it out?
Check out the [documentation](https://continuwuity.org) for installation instructions.
If you want to try it out as a user, we have some partnered homeservers you can use:
* You can head over to [https://federated.nexus](https://federated.nexus/) in your browser.
* Hit the `Apply to Join` button. Once your request has been accepted, you will receive an email with your username and password.
* Head over to [https://app.federated.nexus](https://app.federated.nexus/) and you can sign in there, or use any other matrix chat client you wish elsewhere.
* Your username for matrix will be in the form of `@username:federated.nexus`, however you can simply use the `username` part to log in. Your password is your password.
* There's also [https://continuwuity.rocks/](https://continuwuity.rocks/). You can register a new account using Cinny via [this convenient link](https://app.cinny.in/register/continuwuity.rocks), or you can use Element or another matrix client *that supports registration*.
There are currently no open registration continuwuity instances available.
## What are we working on?

View File

@@ -36,7 +36,3 @@ ## `!admin media delete-all-from-user`
## `!admin media delete-all-from-server`
Deletes all remote media from the specified remote server. This will always ignore errors by default
## `!admin media delete-url-preview`
Deletes a cached URL preview, forcing it to be re-fetched. Use --all to purge all cached URL previews

View File

@@ -77,12 +77,7 @@ rec {
craneLib.buildDepsOnly (
(commonAttrs commonAttrsArgs)
// {
env = uwuenv.buildDepsOnlyEnv
// (makeRocksDBEnv { inherit rocksdb; })
// {
# required since we started using unstable reqwest apparently ... otherwise the all-features build will fail
RUSTFLAGS = "--cfg reqwest_unstable";
};
env = uwuenv.buildDepsOnlyEnv // (makeRocksDBEnv { inherit rocksdb; });
inherit (features) cargoExtraArgs;
}
@@ -107,13 +102,7 @@ rec {
'';
cargoArtifacts = deps;
doCheck = true;
env =
uwuenv.buildPackageEnv
// rocksdbEnv
// {
# required since we started using unstable reqwest apparently ... otherwise the all-features build will fail
RUSTFLAGS = "--cfg reqwest_unstable";
};
env = uwuenv.buildPackageEnv // rocksdbEnv;
passthru.env = uwuenv.buildPackageEnv // rocksdbEnv;
meta.mainProgram = crateInfo.pname;
inherit (features) cargoExtraArgs;

View File

@@ -1,6 +1,6 @@
use std::fmt::Write;
use conduwuit::{Err, Result, utils::response::LimitReadExt};
use conduwuit::{Err, Result};
use futures::StreamExt;
use ruma::{OwnedRoomId, OwnedServerName, OwnedUserId};
@@ -30,15 +30,12 @@ pub(super) async fn incoming_federation(&self) -> Result {
.federation_handletime
.read();
let mut msg = format!(
"Handling {} incoming PDUs across {} active transactions:\n",
map.len(),
self.services.transactions.txn_active_handle_count()
);
let mut msg = format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() {
let elapsed = i.elapsed();
writeln!(msg, "{} {}: {}m{}s", r, e, elapsed.as_secs() / 60, elapsed.as_secs() % 60)?;
}
msg
};
@@ -55,15 +52,7 @@ pub(super) async fn fetch_support_well_known(&self, server_name: OwnedServerName
.send()
.await?;
let text = response
.limit_read_text(
self.services
.config
.max_request_size
.try_into()
.expect("u64 fits into usize"),
)
.await?;
let text = response.text().await?;
if text.is_empty() {
return Err!("Response text/body is empty.");

View File

@@ -29,9 +29,7 @@ pub(super) async fn delete(
.delete(&mxc.as_str().try_into()?)
.await?;
return self
.write_str("Deleted the MXC from our database and on our filesystem.")
.await;
return Err!("Deleted the MXC from our database and on our filesystem.",);
}
if let Some(event_id) = event_id {
@@ -390,19 +388,3 @@ pub(super) async fn get_remote_thumbnail(
self.write_str(&format!("```\n{result:#?}\nreceived {len} bytes for file content.\n```"))
.await
}
#[admin_command]
pub(super) async fn delete_url_preview(&self, url: Option<String>, all: bool) -> Result {
if all {
self.services.media.clear_url_previews().await;
return self.write_str("Deleted all cached URL previews.").await;
}
let url = url.expect("clap enforces url is required unless --all");
self.services.media.remove_url_preview(&url).await?;
self.write_str(&format!("Deleted cached URL preview for: {url}"))
.await
}

View File

@@ -108,16 +108,4 @@ pub enum MediaCommand {
#[arg(long, default_value("800"))]
height: u32,
},
/// Deletes a cached URL preview, forcing it to be re-fetched.
/// Use --all to purge all cached URL previews.
DeleteUrlPreview {
/// The URL to clear from the saved preview data
#[arg(required_unless_present = "all")]
url: Option<String>,
/// Purge all cached URL previews
#[arg(long, conflicts_with = "url")]
all: bool,
},
}

View File

@@ -209,7 +209,7 @@ pub(super) async fn compact(
let parallelism = parallelism.unwrap_or(1);
let results = maps
.into_iter()
.try_stream::<conduwuit::Error>()
.try_stream()
.paralleln_and_then(runtime, parallelism, move |map| {
map.compact_blocking(options.clone())?;
Ok(map.name().to_owned())

View File

@@ -20,17 +20,7 @@ pub enum ResolverCommand {
name: Option<String>,
},
/// Flush a given server from the resolver caches or flush them completely
///
/// * Examples:
/// * Flush a specific server:
///
/// `!admin query resolver flush-cache matrix.example.com`
///
/// * Flush all resolver caches completely:
///
/// `!admin query resolver flush-cache --all`
#[command(verbatim_doc_comment)]
/// Flush a specific server from the resolver caches or everything
FlushCache {
name: Option<OwnedServerName>,

View File

@@ -252,13 +252,6 @@ pub(crate) async fn register_route(
}
}
// Don't allow registration with user IDs that aren't local
if !services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {body_username} is not local to this server"
)));
}
user_id
},
| Err(e) => {

View File

@@ -9,7 +9,7 @@
},
events::{
AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent,
RoomAccountDataEventType,
GlobalAccountDataEventType, RoomAccountDataEventType,
},
serde::Raw,
};
@@ -126,6 +126,12 @@ async fn set_account_data(
)));
}
if event_type_s == GlobalAccountDataEventType::PushRules.to_cow_str() {
return Err!(Request(BadJson(
"This endpoint cannot be used for setting/configuring push rules."
)));
}
let data: serde_json::Value = serde_json::from_str(data.get())
.map_err(|e| err!(Request(BadJson(warn!("Invalid JSON provided: {e}")))))?;

View File

@@ -1,121 +0,0 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{Err, Result, at};
use futures::StreamExt;
use ruma::api::client::dehydrated_device::{
delete_dehydrated_device::unstable as delete_dehydrated_device,
get_dehydrated_device::unstable as get_dehydrated_device, get_events::unstable as get_events,
put_dehydrated_device::unstable as put_dehydrated_device,
};
use crate::Ruma;
const MAX_BATCH_EVENTS: usize = 50;
/// # `PUT /_matrix/client/../dehydrated_device`
///
/// Creates or overwrites the user's dehydrated device.
#[tracing::instrument(skip_all, fields(%client))]
pub(crate) async fn put_dehydrated_device_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<put_dehydrated_device::Request>,
) -> Result<put_dehydrated_device::Response> {
let sender_user = body
.sender_user
.as_deref()
.expect("AccessToken authentication required");
let device_id = body.body.device_id.clone();
services
.users
.set_dehydrated_device(sender_user, body.body)
.await?;
Ok(put_dehydrated_device::Response { device_id })
}
/// # `DELETE /_matrix/client/../dehydrated_device`
///
/// Deletes the user's dehydrated device without replacement.
#[tracing::instrument(skip_all, fields(%client))]
pub(crate) async fn delete_dehydrated_device_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<delete_dehydrated_device::Request>,
) -> Result<delete_dehydrated_device::Response> {
let sender_user = body.sender_user();
let device_id = services.users.get_dehydrated_device_id(sender_user).await?;
services.users.remove_device(sender_user, &device_id).await;
Ok(delete_dehydrated_device::Response { device_id })
}
/// # `GET /_matrix/client/../dehydrated_device`
///
/// Gets the user's dehydrated device
#[tracing::instrument(skip_all, fields(%client))]
pub(crate) async fn get_dehydrated_device_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_dehydrated_device::Request>,
) -> Result<get_dehydrated_device::Response> {
let sender_user = body.sender_user();
let device = services.users.get_dehydrated_device(sender_user).await?;
Ok(get_dehydrated_device::Response {
device_id: device.device_id,
device_data: device.device_data,
})
}
/// # `GET /_matrix/client/../dehydrated_device/{device_id}/events`
///
/// Paginates the events of the dehydrated device.
#[tracing::instrument(skip_all, fields(%client))]
pub(crate) async fn get_dehydrated_events_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_events::Request>,
) -> Result<get_events::Response> {
let sender_user = body.sender_user();
let device_id = &body.body.device_id;
let existing_id = services.users.get_dehydrated_device_id(sender_user).await;
if existing_id.as_ref().is_err()
|| existing_id
.as_ref()
.is_ok_and(|existing_id| existing_id != device_id)
{
return Err!(Request(Forbidden("Not the dehydrated device_id.")));
}
let since: Option<u64> = body
.body
.next_batch
.as_deref()
.map(str::parse)
.transpose()?;
let mut next_batch: Option<u64> = None;
let events = services
.users
.get_to_device_events(sender_user, device_id, since, None)
.take(MAX_BATCH_EVENTS)
.inspect(|&(count, _)| {
next_batch.replace(count);
})
.map(at!(1))
.collect()
.await;
Ok(get_events::Response {
events,
next_batch: next_batch.as_ref().map(ToString::to_string),
})
}

View File

@@ -6,7 +6,6 @@
Err, Result, err,
utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize},
};
use conduwuit_core::error;
use conduwuit_service::{
Services,
media::{CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, Dim, FileMeta, MXC_LENGTH},
@@ -145,22 +144,12 @@ pub(crate) async fn get_content_route(
server_name: &body.server_name,
media_id: &body.media_id,
};
let FileMeta {
content,
content_type,
content_disposition,
} = match fetch_file(&services, &mxc, user, body.timeout_ms, None).await {
| Ok(meta) => meta,
| Err(conduwuit::Error::Io(e)) => match e.kind() {
| std::io::ErrorKind::NotFound => return Err!(Request(NotFound("Media not found."))),
| std::io::ErrorKind::PermissionDenied => {
error!("Permission denied when trying to read file: {e:?}");
return Err!(Request(Unknown("Unknown error when fetching file.")));
},
| _ => return Err!(Request(Unknown("Unknown error when fetching file."))),
},
| Err(_) => return Err!(Request(Unknown("Unknown error when fetching file."))),
};
} = fetch_file(&services, &mxc, user, body.timeout_ms, None).await?;
Ok(get_content::v1::Response {
file: content.expect("entire file contents"),
@@ -196,18 +185,7 @@ pub(crate) async fn get_content_as_filename_route(
content,
content_type,
content_disposition,
} = match fetch_file(&services, &mxc, user, body.timeout_ms, None).await {
| Ok(meta) => meta,
| Err(conduwuit::Error::Io(e)) => match e.kind() {
| std::io::ErrorKind::NotFound => return Err!(Request(NotFound("Media not found."))),
| std::io::ErrorKind::PermissionDenied => {
error!("Permission denied when trying to read file: {e:?}");
return Err!(Request(Unknown("Unknown error when fetching file.")));
},
| _ => return Err!(Request(Unknown("Unknown error when fetching file."))),
},
| Err(_) => return Err!(Request(Unknown("Unknown error when fetching file."))),
};
} = fetch_file(&services, &mxc, user, body.timeout_ms, Some(&body.filename)).await?;
Ok(get_content_as_filename::v1::Response {
file: content.expect("entire file contents"),

View File

@@ -6,7 +6,6 @@
pub(super) mod backup;
pub(super) mod capabilities;
pub(super) mod context;
pub(super) mod dehydrated_device;
pub(super) mod device;
pub(super) mod directory;
pub(super) mod filter;
@@ -50,7 +49,6 @@
pub(super) use backup::*;
pub(super) use capabilities::*;
pub(super) use context::*;
pub(super) use dehydrated_device::*;
pub(super) use device::*;
pub(super) use directory::*;
pub(super) use filter::*;

View File

@@ -4,6 +4,7 @@
use axum_client_ip::InsecureClientIp;
use conduwuit::{Err, Event, Result, debug_info, info, matrix::pdu::PduEvent, utils::ReadyExt};
use conduwuit_service::Services;
use rand::Rng;
use ruma::{
EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
api::client::{
@@ -243,7 +244,7 @@ fn build_report(report: Report) -> RoomMessageEventContent {
/// random delay sending a response per spec suggestion regarding
/// enumerating for potential events existing in our server.
async fn delay_response() {
let time_to_wait = rand::random_range(2..5);
let time_to_wait = rand::thread_rng().gen_range(2..5);
debug_info!(
"Got successful /report request, waiting {time_to_wait} seconds before sending \
successful response."

View File

@@ -50,8 +50,8 @@ pub(crate) async fn send_message_event_route(
// Check if this is a new transaction id
if let Ok(response) = services
.transactions
.get_client_txn(sender_user, sender_device, &body.txn_id)
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)
.await
{
// The client might have sent a txnid of the /sendToDevice endpoint
@@ -92,7 +92,7 @@ pub(crate) async fn send_message_event_route(
)
.await?;
services.transactions.add_client_txnid(
services.transaction_ids.add_txnid(
sender_user,
sender_device,
&body.txn_id,

View File

@@ -11,7 +11,7 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{
Result, at, extract_variant,
Result, extract_variant,
utils::{
ReadyExt, TryFutureExtExt,
stream::{BroadbandExt, Tools, WidebandExt},
@@ -385,7 +385,6 @@ pub(crate) async fn build_sync_events(
last_sync_end_count,
Some(current_count),
)
.map(at!(1))
.collect::<Vec<_>>();
let device_one_time_keys_count = services

View File

@@ -336,9 +336,7 @@ async fn handle_lists<'a, Rooms, AllRooms>(
let ranges = list.ranges.clone();
for mut range in ranges {
range.0 = range
.0
.min(UInt::try_from(active_rooms.len()).unwrap_or(UInt::MAX));
range.0 = uint!(0);
range.1 = range.1.checked_add(uint!(1)).unwrap_or(range.1);
range.1 = range
.1
@@ -1029,7 +1027,6 @@ async fn collect_to_device(
events: services
.users
.get_to_device_events(sender_user, sender_device, None, Some(next_batch))
.map(at!(1))
.collect()
.await,
})

View File

@@ -26,8 +26,8 @@ pub(crate) async fn send_event_to_device_route(
// Check if this is a new transaction id
if services
.transactions
.get_client_txn(sender_user, sender_device, &body.txn_id)
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)
.await
.is_ok()
{
@@ -104,8 +104,8 @@ pub(crate) async fn send_event_to_device_route(
// Save transaction id with empty data
services
.transactions
.add_client_txnid(sender_user, sender_device, &body.txn_id, &[]);
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[]);
Ok(send_event_to_device::v3::Response {})
}

View File

@@ -50,7 +50,6 @@ pub(crate) async fn get_supported_versions_route(
("org.matrix.msc2836".to_owned(), true), /* threading/threads (https://github.com/matrix-org/matrix-spec-proposals/pull/2836) */
("org.matrix.msc2946".to_owned(), true), /* spaces/hierarchy summaries (https://github.com/matrix-org/matrix-spec-proposals/pull/2946) */
("org.matrix.msc3026.busy_presence".to_owned(), true), /* busy presence status (https://github.com/matrix-org/matrix-spec-proposals/pull/3026) */
("org.matrix.msc3814".to_owned(), true), /* dehydrated devices */
("org.matrix.msc3827".to_owned(), true), /* filtering of /publicRooms by room type (https://github.com/matrix-org/matrix-spec-proposals/pull/3827) */
("org.matrix.msc3952_intentional_mentions".to_owned(), true), /* intentional mentions (https://github.com/matrix-org/matrix-spec-proposals/pull/3952) */
("org.matrix.msc3916.stable".to_owned(), true), /* authenticated media (https://github.com/matrix-org/matrix-spec-proposals/pull/3916) */

View File

@@ -27,32 +27,10 @@ pub(crate) async fn well_known_client(
identity_server: None,
sliding_sync_proxy: Some(SlidingSyncProxyInfo { url: client_url }),
tile_server: None,
rtc_foci: services
.config
.matrix_rtc
.effective_foci(&services.config.well_known.rtc_focus_server_urls)
.to_vec(),
rtc_foci: services.config.well_known.rtc_focus_server_urls.clone(),
})
}
/// # `GET /_matrix/client/v1/rtc/transports`
/// # `GET /_matrix/client/unstable/org.matrix.msc4143/rtc/transports`
///
/// Returns the list of MatrixRTC foci (transports) configured for this
/// homeserver, implementing MSC4143.
pub(crate) async fn get_rtc_transports(
State(services): State<crate::State>,
_body: Ruma<ruma::api::client::discovery::get_rtc_transports::Request>,
) -> Result<ruma::api::client::discovery::get_rtc_transports::Response> {
Ok(ruma::api::client::discovery::get_rtc_transports::Response::new(
services
.config
.matrix_rtc
.effective_foci(&services.config.well_known.rtc_focus_server_urls)
.to_vec(),
))
}
/// # `GET /.well-known/matrix/support`
///
/// Server support contact and support page of a homeserver's domain.

View File

@@ -160,10 +160,6 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::update_device_route)
.ruma_route(&client::delete_device_route)
.ruma_route(&client::delete_devices_route)
.ruma_route(&client::put_dehydrated_device_route)
.ruma_route(&client::delete_dehydrated_device_route)
.ruma_route(&client::get_dehydrated_device_route)
.ruma_route(&client::get_dehydrated_events_route)
.ruma_route(&client::get_tags_route)
.ruma_route(&client::update_tag_route)
.ruma_route(&client::delete_tag_route)
@@ -188,7 +184,6 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::put_suspended_status)
.ruma_route(&client::well_known_support)
.ruma_route(&client::well_known_client)
.ruma_route(&client::get_rtc_transports)
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
.route("/_continuwuity/server_version", get(client::conduwuit_server_version))
.ruma_route(&client::room_initial_sync_route)

View File

@@ -14,8 +14,7 @@
pin_mut,
};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
OwnedUserId, UserId,
CanonicalJsonObject, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
api::{
AuthScheme, IncomingRequest, Metadata,
client::{
@@ -67,17 +66,23 @@ pub(super) async fn auth(
if metadata.authentication == AuthScheme::None {
match metadata {
| &get_public_rooms::v3::Request::METADATA => {
match token {
| Token::Appservice(_) | Token::User(_) => {
// we should have validated the token above
// already
},
| Token::None | Token::Invalid => {
return Err(Error::BadRequest(
ErrorKind::MissingToken,
"Missing or invalid access token.",
));
},
if !services
.server
.config
.allow_public_room_directory_without_auth
{
match token {
| Token::Appservice(_) | Token::User(_) => {
// we should have validated the token above
// already
},
| Token::None | Token::Invalid => {
return Err(Error::BadRequest(
ErrorKind::MissingToken,
"Missing or invalid access token.",
));
},
}
}
},
| &get_profile::v3::Request::METADATA
@@ -229,33 +234,10 @@ async fn auth_appservice(
return Err!(Request(Exclusive("User is not in namespace.")));
}
// MSC3202/MSC4190: Handle device_id masquerading for appservices.
// The device_id can be provided via `device_id` or
// `org.matrix.msc3202.device_id` query parameter.
let sender_device = if let Some(ref device_id_str) = request.query.device_id {
let device_id: &DeviceId = device_id_str.as_str().into();
// Verify the device exists for this user
if services
.users
.get_device_metadata(&user_id, device_id)
.await
.is_err()
{
return Err!(Request(Forbidden(
"Device does not exist for user or appservice cannot masquerade as this device."
)));
}
Some(device_id.to_owned())
} else {
None
};
Ok(Auth {
origin: None,
sender_user: Some(user_id),
sender_device,
sender_device: None,
appservice_info: Some(*info),
})
}

View File

@@ -11,10 +11,6 @@
pub(super) struct QueryParams {
pub(super) access_token: Option<String>,
pub(super) user_id: Option<String>,
/// Device ID for appservice device masquerading (MSC3202/MSC4190).
/// Can be provided as `device_id` or `org.matrix.msc3202.device_id`.
#[serde(alias = "org.matrix.msc3202.device_id")]
pub(super) device_id: Option<String>,
}
pub(super) struct Request {

View File

@@ -40,7 +40,7 @@ pub(crate) async fn get_room_information_route(
servers.sort_unstable();
servers.dedup();
servers.shuffle(&mut rand::rng());
servers.shuffle(&mut rand::thread_rng());
// insert our server as the very first choice if in list
if let Some(server_index) = servers

View File

@@ -1,33 +1,27 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
net::IpAddr,
time::{Duration, Instant},
};
use std::{collections::BTreeMap, net::IpAddr, time::Instant};
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{
Err, Error, Result, debug, debug_warn, err, error,
result::LogErr,
state_res::lexicographical_topological_sort,
trace,
utils::{
IterStream, ReadyExt, millis_since_unix_epoch,
stream::{BroadbandExt, TryBroadbandExt, automatic_width},
},
warn,
};
use conduwuit_service::{
Services,
sending::{EDU_LIMIT, PDU_LIMIT},
};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use http::StatusCode;
use itertools::Itertools;
use ruma::{
CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId,
RoomId, ServerName, UserId,
CanonicalJsonObject, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, ServerName, UserId,
api::{
client::error::{ErrorKind, ErrorKind::LimitExceeded},
client::error::ErrorKind,
federation::transactions::{
edu::{
DeviceListUpdateContent, DirectDeviceContent, Edu, PresenceContent,
@@ -38,16 +32,9 @@
},
},
events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
int,
serde::Raw,
to_device::DeviceIdOrAllDevices,
uint,
};
use service::transactions::{
FederationTxnState, TransactionError, TxnKey, WrappedTransactionResponse,
};
use tokio::sync::watch::{Receiver, Sender};
use tracing::instrument;
use crate::Ruma;
@@ -57,6 +44,15 @@
/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
/// Push EDUs and PDUs to this server.
#[tracing::instrument(
name = "txn",
level = "debug",
skip_all,
fields(
%client,
origin = body.origin().as_str()
),
)]
pub(crate) async fn send_transaction_message_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
@@ -80,73 +76,16 @@ pub(crate) async fn send_transaction_message_route(
)));
}
let txn_key = (body.origin().to_owned(), body.transaction_id.clone());
// Atomically check cache, join active, or start new transaction
match services
.transactions
.get_or_start_federation_txn(txn_key.clone())?
{
| FederationTxnState::Cached(response) => {
// Already responded
Ok(response)
},
| FederationTxnState::Active(receiver) => {
// Another thread is processing
wait_for_result(receiver).await
},
| FederationTxnState::Started { receiver, sender } => {
// We're the first, spawn the processing task
services
.server
.runtime()
.spawn(process_inbound_transaction(services, body, client, txn_key, sender));
// and wait for it
wait_for_result(receiver).await
},
}
}
async fn wait_for_result(
mut recv: Receiver<WrappedTransactionResponse>,
) -> Result<send_transaction_message::v1::Response> {
if tokio::time::timeout(Duration::from_secs(50), recv.changed())
.await
.is_err()
{
// Took too long, return 429 to encourage the sender to try again
return Err(Error::BadRequest(
LimitExceeded { retry_after: None },
"Transaction is being still being processed. Please try again later.",
));
}
let value = recv.borrow_and_update();
match value.clone() {
| Some(Ok(response)) => Ok(response),
| Some(Err(err)) => Err(transaction_error_to_response(&err)),
| None => Err(Error::Request(
ErrorKind::Unknown,
"Transaction processing failed unexpectedly".into(),
StatusCode::INTERNAL_SERVER_ERROR,
)),
}
}
#[instrument(
skip_all,
fields(
id = ?body.transaction_id.as_str(),
origin = ?body.origin()
)
)]
async fn process_inbound_transaction(
services: crate::State,
body: Ruma<send_transaction_message::v1::Request>,
client: IpAddr,
txn_key: TxnKey,
sender: Sender<WrappedTransactionResponse>,
) {
let txn_start_time = Instant::now();
trace!(
pdus = body.pdus.len(),
edus = body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = %body.transaction_id,
origin = %body.origin(),
"Starting txn",
);
let pdus = body
.pdus
.iter()
@@ -163,79 +102,40 @@ async fn process_inbound_transaction(
.filter_map(Result::ok)
.stream();
debug!(pdus = body.pdus.len(), edus = body.edus.len(), "Processing transaction",);
let results = match handle(&services, &client, body.origin(), pdus, edus).await {
| Ok(results) => results,
| Err(err) => {
fail_federation_txn(services, &txn_key, &sender, err);
return;
},
};
for (id, result) in &results {
if let Err(e) = result {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
debug_warn!("Incoming PDU failed {id}: {e:?}");
}
}
}
let results = handle(&services, &client, body.origin(), txn_start_time, pdus, edus).await?;
debug!(
pdus = body.pdus.len(),
edus = body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
"Finished processing transaction"
id = %body.transaction_id,
origin = %body.origin(),
"Finished txn",
);
for (id, result) in &results {
if let Err(e) = result {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
warn!("Incoming PDU failed {id}: {e:?}");
}
}
}
let response = send_transaction_message::v1::Response {
Ok(send_transaction_message::v1::Response {
pdus: results
.into_iter()
.map(|(e, r)| (e, r.map_err(error::sanitized_message)))
.collect(),
};
services
.transactions
.finish_federation_txn(txn_key, sender, response);
})
}
/// Handles a failed federation transaction by sending the error through
/// the channel and cleaning up the transaction state. This allows waiters to
/// receive an appropriate error response.
fn fail_federation_txn(
services: crate::State,
txn_key: &TxnKey,
sender: &Sender<WrappedTransactionResponse>,
err: TransactionError,
) {
debug!("Transaction failed: {err}");
// Remove from active state so the transaction can be retried
services.transactions.remove_federation_txn(txn_key);
// Send the error to any waiters
if let Err(e) = sender.send(Some(Err(err))) {
debug_warn!("Failed to send transaction error to receivers: {e}");
}
}
/// Converts a TransactionError into an appropriate HTTP error response.
fn transaction_error_to_response(err: &TransactionError) -> Error {
match err {
| TransactionError::ShuttingDown => Error::Request(
ErrorKind::Unknown,
"Server is shutting down, please retry later".into(),
StatusCode::SERVICE_UNAVAILABLE,
),
}
}
async fn handle(
services: &Services,
client: &IpAddr,
origin: &ServerName,
started: Instant,
pdus: impl Stream<Item = Pdu> + Send,
edus: impl Stream<Item = Edu> + Send,
) -> std::result::Result<ResolvedMap, TransactionError> {
) -> Result<ResolvedMap> {
// group pdus by room
let pdus = pdus
.collect()
@@ -252,7 +152,7 @@ async fn handle(
.into_iter()
.try_stream()
.broad_and_then(|(room_id, pdus): (_, Vec<_>)| {
handle_room(services, client, origin, room_id, pdus.into_iter())
handle_room(services, client, origin, started, room_id, pdus.into_iter())
.map_ok(Vec::into_iter)
.map_ok(IterStream::try_stream)
})
@@ -269,51 +169,14 @@ async fn handle(
Ok(results)
}
/// Attempts to build a localised directed acyclic graph out of the given PDUs,
/// returning them in a topologically sorted order.
///
/// This is used to attempt to process PDUs in an order that respects their
/// dependencies, however it is ultimately the sender's responsibility to send
/// them in a processable order, so this is just a best effort attempt. It does
/// not account for power levels or other tie breaks.
async fn build_local_dag(
pdu_map: &HashMap<OwnedEventId, CanonicalJsonObject>,
) -> Result<Vec<OwnedEventId>> {
debug_assert!(pdu_map.len() >= 2, "needless call to build_local_dag with less than 2 PDUs");
let mut dag: HashMap<OwnedEventId, HashSet<OwnedEventId>> = HashMap::new();
for (event_id, value) in pdu_map {
let prev_events = value
.get("prev_events")
.expect("pdu must have prev_events")
.as_array()
.expect("prev_events must be an array")
.iter()
.map(|v| {
OwnedEventId::parse(v.as_str().expect("prev_events values must be strings"))
.expect("prev_events must be valid event IDs")
})
.collect::<HashSet<OwnedEventId>>();
dag.insert(event_id.clone(), prev_events);
}
lexicographical_topological_sort(&dag, &|_| async {
// Note: we don't bother fetching power levels because that would massively slow
// this function down. This is a best-effort attempt to order events correctly
// for processing, however ultimately that should be the sender's job.
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.await
.map_err(|e| err!("failed to resolve local graph: {e}"))
}
async fn handle_room(
services: &Services,
_client: &IpAddr,
origin: &ServerName,
txn_start_time: Instant,
room_id: OwnedRoomId,
pdus: impl Iterator<Item = Pdu> + Send,
) -> std::result::Result<Vec<(OwnedEventId, Result)>, TransactionError> {
) -> Result<Vec<(OwnedEventId, Result)>> {
let _room_lock = services
.rooms
.event_handler
@@ -322,40 +185,27 @@ async fn handle_room(
.await;
let room_id = &room_id;
let pdu_map: HashMap<OwnedEventId, CanonicalJsonObject> = pdus
.into_iter()
.map(|(_, event_id, value)| (event_id, value))
.collect();
// Try to sort PDUs by their dependencies, but fall back to arbitrary order on
// failure (e.g., cycles). This is best-effort; proper ordering is the sender's
// responsibility.
let sorted_event_ids = if pdu_map.len() >= 2 {
build_local_dag(&pdu_map).await.unwrap_or_else(|e| {
debug_warn!("Failed to build local DAG for room {room_id}: {e}");
pdu_map.keys().cloned().collect()
pdus.try_stream()
.and_then(|(_, event_id, value)| async move {
services.server.check_running()?;
let pdu_start_time = Instant::now();
let result = services
.rooms
.event_handler
.handle_incoming_pdu(origin, room_id, &event_id, value, true)
.await
.map(|_| ());
debug!(
pdu_elapsed = ?pdu_start_time.elapsed(),
txn_elapsed = ?txn_start_time.elapsed(),
"Finished PDU {event_id}",
);
Ok((event_id, result))
})
} else {
pdu_map.keys().cloned().collect()
};
let mut results = Vec::with_capacity(sorted_event_ids.len());
for event_id in sorted_event_ids {
let value = pdu_map
.get(&event_id)
.expect("sorted event IDs must be from the original map")
.clone();
services
.server
.check_running()
.map_err(|_| TransactionError::ShuttingDown)?;
let result = services
.rooms
.event_handler
.handle_incoming_pdu(origin, room_id, &event_id, value, true)
.await
.map(|_| ());
results.push((event_id, result));
}
Ok(results)
.try_collect()
.await
}
async fn handle_edu(services: &Services, client: &IpAddr, origin: &ServerName, edu: Edu) {
@@ -628,8 +478,8 @@ async fn handle_edu_direct_to_device(
// Check if this is a new transaction id
if services
.transactions
.get_client_txn(sender, None, message_id)
.transaction_ids
.existing_txnid(sender, None, message_id)
.await
.is_ok()
{
@@ -648,8 +498,8 @@ async fn handle_edu_direct_to_device(
// Save transaction id with empty data
services
.transactions
.add_client_txnid(sender, None, message_id, &[]);
.transaction_ids
.add_txnid(sender, None, message_id, &[]);
}
async fn handle_edu_direct_to_device_user<Event: Send + Sync>(

View File

@@ -86,7 +86,6 @@ libloading.optional = true
log.workspace = true
num-traits.workspace = true
rand.workspace = true
rand_core = { version = "0.6.4", features = ["getrandom"] }
regex.workspace = true
reqwest.workspace = true
ring.workspace = true

View File

@@ -368,31 +368,6 @@ pub struct Config {
#[serde(default = "default_max_fetch_prev_events")]
pub max_fetch_prev_events: u16,
/// How many incoming federation transactions the server is willing to be
/// processing at any given time before it becomes overloaded and starts
/// rejecting further transactions until some slots become available.
///
/// Setting this value too low or too high may result in unstable
/// federation, and setting it too high may cause runaway resource usage.
///
/// default: 150
#[serde(default = "default_max_concurrent_inbound_transactions")]
pub max_concurrent_inbound_transactions: usize,
/// Maximum age (in seconds) for cached federation transaction responses.
/// Entries older than this will be removed during cleanup.
///
/// default: 7200 (2 hours)
#[serde(default = "default_transaction_id_cache_max_age_secs")]
pub transaction_id_cache_max_age_secs: u64,
/// Maximum number of cached federation transaction responses.
/// When the cache exceeds this limit, older entries will be removed.
///
/// default: 8192
#[serde(default = "default_transaction_id_cache_max_entries")]
pub transaction_id_cache_max_entries: usize,
/// Default/base connection timeout (seconds). This is used only by URL
/// previews and update/news endpoint checks.
///
@@ -678,6 +653,12 @@ pub struct Config {
#[serde(default)]
pub allow_public_room_directory_over_federation: bool,
/// Set this to true to allow your server's public room directory to be
/// queried without client authentication (access token) through the Client
/// APIs. Set this to false to protect against /publicRooms spiders.
#[serde(default)]
pub allow_public_room_directory_without_auth: bool,
/// Allow guests/unauthenticated users to access TURN credentials.
///
/// This is the equivalent of Synapse's `turn_allow_guests` config option.
@@ -1544,7 +1525,7 @@ pub struct Config {
/// sender user's server name, inbound federation X-Matrix origin, and
/// outbound federation handler.
///
/// You can set this to [".*"] to block all servers by default, and then
/// You can set this to ["*"] to block all servers by default, and then
/// use `allowed_remote_server_names` to allow only specific servers.
///
/// example: ["badserver\\.tld$", "badphrase", "19dollarfortnitecards"]
@@ -2080,12 +2061,6 @@ pub struct Config {
/// display: nested
#[serde(default)]
pub blurhashing: BlurhashConfig,
/// Configuration for MatrixRTC (MSC4143) transport discovery.
/// display: nested
#[serde(default)]
pub matrix_rtc: MatrixRtcConfig,
#[serde(flatten)]
#[allow(clippy::zero_sized_map_values)]
// this is a catchall, the map shouldn't be zero at runtime
@@ -2151,16 +2126,17 @@ pub struct WellKnownConfig {
/// listed.
pub support_mxid: Option<OwnedUserId>,
/// **DEPRECATED**: Use `[global.matrix_rtc].foci` instead.
///
/// A list of MatrixRTC foci URLs which will be served as part of the
/// MSC4143 client endpoint at /.well-known/matrix/client.
/// MSC4143 client endpoint at /.well-known/matrix/client. If you're
/// setting up livekit, you'd want something like:
/// rtc_focus_server_urls = [
/// { type = "livekit", livekit_service_url = "https://livekit.example.com" },
/// ]
///
/// This option is deprecated and will be removed in a future release.
/// Please migrate to the new `[global.matrix_rtc]` config section.
/// To disable, set this to be an empty vector (`[]`).
///
/// default: []
#[serde(default)]
#[serde(default = "default_rtc_focus_urls")]
pub rtc_focus_server_urls: Vec<RtcFocusInfo>,
}
@@ -2189,43 +2165,6 @@ pub struct BlurhashConfig {
pub blurhash_max_raw_size: u64,
}
#[derive(Clone, Debug, Deserialize, Default)]
#[config_example_generator(filename = "conduwuit-example.toml", section = "global.matrix_rtc")]
pub struct MatrixRtcConfig {
/// A list of MatrixRTC foci (transports) which will be served via the
/// MSC4143 RTC transports endpoint at
/// `/_matrix/client/v1/rtc/transports`. If you're setting up livekit,
/// you'd want something like:
/// ```toml
/// [global.matrix_rtc]
/// foci = [
/// { type = "livekit", livekit_service_url = "https://livekit.example.com" },
/// ]
/// ```
///
/// To disable, set this to an empty list (`[]`).
///
/// default: []
#[serde(default)]
pub foci: Vec<RtcFocusInfo>,
}
impl MatrixRtcConfig {
/// Returns the effective foci, falling back to the deprecated
/// `rtc_focus_server_urls` if the new config is empty.
#[must_use]
pub fn effective_foci<'a>(
&'a self,
deprecated_foci: &'a [RtcFocusInfo],
) -> &'a [RtcFocusInfo] {
if !self.foci.is_empty() {
&self.foci
} else {
deprecated_foci
}
}
}
#[derive(Clone, Debug, Default, Deserialize)]
#[config_example_generator(filename = "conduwuit-example.toml", section = "global.ldap")]
pub struct LdapConfig {
@@ -2419,7 +2358,6 @@ pub struct DraupnirConfig {
"well_known_support_email",
"well_known_support_mxid",
"registration_token_file",
"well_known.rtc_focus_server_urls",
];
impl Config {
@@ -2602,12 +2540,6 @@ fn default_pusher_idle_timeout() -> u64 { 15 }
fn default_max_fetch_prev_events() -> u16 { 192_u16 }
fn default_max_concurrent_inbound_transactions() -> usize { 150 }
fn default_transaction_id_cache_max_age_secs() -> u64 { 60 * 60 * 2 }
fn default_transaction_id_cache_max_entries() -> usize { 8192 }
fn default_tracing_flame_filter() -> String {
cfg!(debug_assertions)
.then_some("trace,h2=off")
@@ -2703,6 +2635,9 @@ fn default_rocksdb_stats_level() -> u8 { 1 }
#[inline]
pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V11 }
#[must_use]
pub fn default_rtc_focus_urls() -> Vec<RtcFocusInfo> { vec![] }
fn default_ip_range_denylist() -> Vec<String> {
vec![
"127.0.0.0/8".to_owned(),

View File

@@ -14,7 +14,6 @@
static VERSION: OnceLock<String> = OnceLock::new();
static VERSION_UA: OnceLock<String> = OnceLock::new();
static USER_AGENT: OnceLock<String> = OnceLock::new();
static USER_AGENT_MEDIA: OnceLock<String> = OnceLock::new();
#[inline]
#[must_use]
@@ -22,22 +21,14 @@ pub fn name() -> &'static str { BRANDING }
#[inline]
pub fn version() -> &'static str { VERSION.get_or_init(init_version) }
#[inline]
pub fn version_ua() -> &'static str { VERSION_UA.get_or_init(init_version_ua) }
#[inline]
pub fn user_agent() -> &'static str { USER_AGENT.get_or_init(init_user_agent) }
#[inline]
pub fn user_agent_media() -> &'static str { USER_AGENT_MEDIA.get_or_init(init_user_agent_media) }
fn init_user_agent() -> String { format!("{}/{} (bot; +{WEBSITE})", name(), version_ua()) }
fn init_user_agent_media() -> String {
format!("{}/{} (embedbot; facebookexternalhit/1.1; +{WEBSITE})", name(), version_ua())
}
fn init_version_ua() -> String {
conduwuit_build_metadata::version_tag()
.map_or_else(|| SEMANTIC.to_owned(), |extra| format!("{SEMANTIC}+{extra}"))

View File

@@ -1,552 +0,0 @@
#[cfg(conduwuit_bench)]
extern crate test;
use std::{
borrow::Borrow,
collections::{HashMap, HashSet},
sync::atomic::{AtomicU64, Ordering::SeqCst},
};
use futures::{future, future::ready};
use maplit::{btreemap, hashmap, hashset};
use ruma::{
EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, RoomVersionId, Signatures, UserId,
events::{
StateEventType, TimelineEventType,
room::{
join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent},
},
},
int, room_id, uint, user_id,
};
use serde_json::{
json,
value::{RawValue as RawJsonValue, to_raw_value as to_raw_json_value},
};
use crate::{
matrix::{Event, Pdu, pdu::EventHash},
state_res::{self as state_res, Error, Result, StateMap},
};
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
#[cfg(conduwuit_bench)]
#[cfg_attr(conduwuit_bench, bench)]
fn lexico_topo_sort(c: &mut test::Bencher) {
let graph = hashmap! {
event_id("l") => hashset![event_id("o")],
event_id("m") => hashset![event_id("n"), event_id("o")],
event_id("n") => hashset![event_id("o")],
event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
event_id("p") => hashset![event_id("o")],
};
c.iter(|| {
let _ = state_res::lexicographical_topological_sort(&graph, &|_| {
future::ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
});
});
}
#[cfg(conduwuit_bench)]
#[cfg_attr(conduwuit_bench, bench)]
fn resolution_shallow_auth_chain(c: &mut test::Bencher) {
let mut store = TestStore(hashmap! {});
// build up the DAG
let (state_at_bob, state_at_charlie, _) = store.set_up();
c.iter(|| async {
let ev_map = store.0.clone();
let state_sets = [&state_at_bob, &state_at_charlie];
let fetch = |id: OwnedEventId| ready(ev_map.get(&id).map(ToOwned::to_owned));
let exists = |id: OwnedEventId| ready(ev_map.get(&id).is_some());
let auth_chain_sets: Vec<HashSet<_>> = state_sets
.iter()
.map(|map| {
store
.auth_event_ids(room_id(), map.values().cloned().collect())
.unwrap()
})
.collect();
let _ = match state_res::resolve(
&RoomVersionId::V6,
state_sets.into_iter(),
&auth_chain_sets,
&fetch,
&exists,
)
.await
{
| Ok(state) => state,
| Err(e) => panic!("{e}"),
};
});
}
#[cfg(conduwuit_bench)]
#[cfg_attr(conduwuit_bench, bench)]
fn resolve_deeper_event_set(c: &mut test::Bencher) {
let mut inner = INITIAL_EVENTS();
let ban = BAN_STATE_SET();
inner.extend(ban);
let store = TestStore(inner.clone());
let state_set_a = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("MB")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
let state_set_b = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("IME")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
c.iter(|| async {
let state_sets = [&state_set_a, &state_set_b];
let auth_chain_sets: Vec<HashSet<_>> = state_sets
.iter()
.map(|map| {
store
.auth_event_ids(room_id(), map.values().cloned().collect())
.unwrap()
})
.collect();
let fetch = |id: OwnedEventId| ready(inner.get(&id).map(ToOwned::to_owned));
let exists = |id: OwnedEventId| ready(inner.get(&id).is_some());
let _ = match state_res::resolve(
&RoomVersionId::V6,
state_sets.into_iter(),
&auth_chain_sets,
&fetch,
&exists,
)
.await
{
| Ok(state) => state,
| Err(_) => panic!("resolution failed during benchmarking"),
};
});
}
//*/////////////////////////////////////////////////////////////////////
//
// IMPLEMENTATION DETAILS AHEAD
//
/////////////////////////////////////////////////////////////////////*/
struct TestStore<E: Event>(HashMap<OwnedEventId, E>);
#[allow(unused)]
impl<E: Event + Clone> TestStore<E> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<E> {
self.0
.get(event_id)
.cloned()
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id)))
}
/// Returns the events that correspond to the `event_ids` sorted in the same
/// order.
fn get_events(&self, room_id: &RoomId, event_ids: &[OwnedEventId]) -> Result<Vec<E>> {
let mut events = vec![];
for id in event_ids {
events.push(self.get_event(room_id, id)?);
}
Ok(events)
}
/// Returns a Vec of the related auth events to the given `event`.
fn auth_event_ids(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
) -> Result<HashSet<OwnedEventId>> {
let mut result = HashSet::new();
let mut stack = event_ids;
// DFS for auth event chain
while !stack.is_empty() {
let ev_id = stack.pop().unwrap();
if result.contains(&ev_id) {
continue;
}
result.insert(ev_id.clone());
let event = self.get_event(room_id, ev_id.borrow())?;
stack.extend(event.auth_events().map(ToOwned::to_owned));
}
Ok(result)
}
/// Returns a vector representing the difference in auth chains of the given
/// `events`.
fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<OwnedEventId>>,
) -> Result<Vec<OwnedEventId>> {
let mut auth_chain_sets = vec![];
for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = self
.auth_event_ids(room_id, ids)?
.into_iter()
.collect::<HashSet<_>>();
auth_chain_sets.push(chain);
}
if let Some(first) = auth_chain_sets.first().cloned() {
let common = auth_chain_sets
.iter()
.skip(1)
.fold(first, |a, b| a.intersection(b).cloned().collect::<HashSet<_>>());
Ok(auth_chain_sets
.into_iter()
.flatten()
.filter(|id| !common.contains(id))
.collect())
} else {
Ok(vec![])
}
}
}
impl TestStore<Pdu> {
#[allow(clippy::type_complexity)]
fn set_up(
&mut self,
) -> (StateMap<OwnedEventId>, StateMap<OwnedEventId>, StateMap<OwnedEventId>) {
let create_event = to_pdu_event::<&EventId>(
"CREATE",
alice(),
TimelineEventType::RoomCreate,
Some(""),
to_raw_json_value(&json!({ "creator": alice() })).unwrap(),
&[],
&[],
);
let cre = create_event.event_id().to_owned();
self.0.insert(cre.clone(), create_event.clone());
let alice_mem = to_pdu_event(
"IMA",
alice(),
TimelineEventType::RoomMember,
Some(alice().to_string().as_str()),
member_content_join(),
&[cre.clone()],
&[cre.clone()],
);
self.0
.insert(alice_mem.event_id().to_owned(), alice_mem.clone());
let join_rules = to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(),
&[cre.clone(), alice_mem.event_id().to_owned()],
&[alice_mem.event_id().to_owned()],
);
self.0
.insert(join_rules.event_id().to_owned(), join_rules.clone());
// Bob and Charlie join at the same time, so there is a fork
// this will be represented in the state_sets when we resolve
let bob_mem = to_pdu_event(
"IMB",
bob(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_join(),
&[cre.clone(), join_rules.event_id().to_owned()],
&[join_rules.event_id().to_owned()],
);
self.0
.insert(bob_mem.event_id().to_owned(), bob_mem.clone());
let charlie_mem = to_pdu_event(
"IMC",
charlie(),
TimelineEventType::RoomMember,
Some(charlie().to_string().as_str()),
member_content_join(),
&[cre, join_rules.event_id().to_owned()],
&[join_rules.event_id().to_owned()],
);
self.0
.insert(charlie_mem.event_id().to_owned(), charlie_mem.clone());
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
let expected = [&create_event, &alice_mem, &join_rules, &bob_mem, &charlie_mem]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
(state_at_bob, state_at_charlie, expected)
}
}
fn event_id(id: &str) -> OwnedEventId {
if id.contains('$') {
return id.try_into().unwrap();
}
format!("${}:foo", id).try_into().unwrap()
}
fn alice() -> &'static UserId { user_id!("@alice:foo") }
fn bob() -> &'static UserId { user_id!("@bob:foo") }
fn charlie() -> &'static UserId { user_id!("@charlie:foo") }
fn ella() -> &'static UserId { user_id!("@ella:foo") }
fn room_id() -> &'static RoomId { room_id!("!test:foo") }
fn member_content_ban() -> Box<RawJsonValue> {
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Ban)).unwrap()
}
fn member_content_join() -> Box<RawJsonValue> {
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap()
}
fn to_pdu_event<S>(
id: &str,
sender: &UserId,
ev_type: TimelineEventType,
state_key: Option<&str>,
content: Box<RawJsonValue>,
auth_events: &[S],
prev_events: &[S],
) -> Pdu
where
S: AsRef<str>,
{
// We don't care if the addition happens in order just that it is atomic
// (each event has its own value)
let ts = SERVER_TIMESTAMP.fetch_add(1, SeqCst);
let id = if id.contains('$') {
id.to_owned()
} else {
format!("${}:foo", id)
};
let auth_events = auth_events
.iter()
.map(AsRef::as_ref)
.map(event_id)
.collect::<Vec<_>>();
let prev_events = prev_events
.iter()
.map(AsRef::as_ref)
.map(event_id)
.collect::<Vec<_>>();
Pdu {
event_id: id.try_into().unwrap(),
room_id: Some(room_id().to_owned()),
sender: sender.to_owned(),
origin_server_ts: ts.try_into().unwrap(),
state_key: state_key.map(Into::into),
kind: ev_type,
content,
origin: None,
redacts: None,
unsigned: None,
auth_events,
prev_events,
depth: uint!(0),
hashes: EventHash { sha256: String::new() },
signatures: None,
}
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> HashMap<OwnedEventId, Pdu> {
vec![
to_pdu_event::<&EventId>(
"CREATE",
alice(),
TimelineEventType::RoomCreate,
Some(""),
to_raw_json_value(&json!({ "creator": alice() })).unwrap(),
&[],
&[],
),
to_pdu_event(
"IMA",
alice(),
TimelineEventType::RoomMember,
Some(alice().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
),
to_pdu_event(
"IPOWER",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100 } })).unwrap(),
&["CREATE", "IMA"],
&["IMA"],
),
to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
),
to_pdu_event(
"IMB",
bob(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IJR"],
),
to_pdu_event(
"IMC",
charlie(),
TimelineEventType::RoomMember,
Some(charlie().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IMB"],
),
to_pdu_event::<&EventId>(
"START",
charlie(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
&[],
&[],
),
to_pdu_event::<&EventId>(
"END",
charlie(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
&[],
&[],
),
]
.into_iter()
.map(|ev| (ev.event_id().to_owned(), ev))
.collect()
}
// State fixtures for the ban test cases: two competing power-level events
// ("PA" branching from START, "PB" branching from END), a ban of Ella ("MB")
// authorised against PB, and Ella's join ("IME") authorised against PA.
// (The previous comment here was a copy-paste of the INITIAL_EVENTS one.)
#[allow(non_snake_case)]
fn BAN_STATE_SET() -> HashMap<OwnedEventId, Pdu> {
    vec![
        to_pdu_event(
            "PA",
            alice(),
            TimelineEventType::RoomPowerLevels,
            Some(""),
            to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
            &["CREATE", "IMA", "IPOWER"], // auth_events
            &["START"], // prev_events
        ),
        to_pdu_event(
            "PB",
            alice(),
            TimelineEventType::RoomPowerLevels,
            Some(""),
            to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
            &["CREATE", "IMA", "IPOWER"],
            &["END"],
        ),
        to_pdu_event(
            "MB",
            alice(),
            TimelineEventType::RoomMember,
            Some(ella().as_str()),
            member_content_ban(),
            &["CREATE", "IMA", "PB"],
            &["PA"],
        ),
        to_pdu_event(
            "IME",
            ella(),
            TimelineEventType::RoomMember,
            Some(ella().as_str()),
            member_content_join(),
            &["CREATE", "IJR", "PA"],
            &["MB"],
        ),
    ]
    .into_iter()
    .map(|ev| (ev.event_id().to_owned(), ev))
    .collect()
}

View File

@@ -1,3 +1,4 @@
use ruma::OwnedEventId;
use serde_json::Error as JsonError;
use thiserror::Error;
@@ -14,10 +15,28 @@ pub enum Error {
Unsupported(String),
/// The given event was not found.
#[error("Not found error: {0}")]
#[error("Event not found: {0}")]
NotFound(String),
/// A required event this event depended on could not be fetched,
/// either as it was missing, or because it was invalid
#[error("Failed to fetch required {0} event: {1}")]
DependencyFailed(OwnedEventId, String),
/// Invalid fields in the given PDU.
#[error("Invalid PDU: {0}")]
InvalidPdu(String),
/// This event failed an authorization condition.
#[error("Auth check failed: {0}")]
AuthConditionFailed(String),
/// This event contained multiple auth events of the same type and state
/// key.
#[error("Duplicate auth events: {0}")]
DuplicateAuthEvents(String),
/// This event contains unnecessary auth events.
#[error("Unknown or unnecessary auth events present: {0}")]
UnselectedAuthEvents(String),
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,238 @@
//! Auth checks relevant to any event's `auth_events`.
//!
//! See: <https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules>
use std::collections::{HashMap, HashSet};
use ruma::{
EventId, OwnedEventId, RoomId, UserId,
events::{
StateEventType, TimelineEventType,
room::member::{MembershipState, RoomMemberEventContent, ThirdPartyInvite},
},
};
use crate::{Event, EventTypeExt, Pdu, RoomVersion, matrix::StateKey, state_res::Error, warn};
/// For the given event `kind` what are the relevant auth events that are needed
/// to authenticate this `content`.
///
/// Implements the auth event selection algorithm: every non-create event needs
/// the sender's membership and the power levels (plus the create event itself
/// in room versions whose room IDs are not create-event hashes); member events
/// additionally need the target's membership and, depending on the membership
/// change, the join rules, the third-party invite, and the authorising user's
/// membership.
///
/// # Errors
///
/// As currently written this never returns `Err`; the `serde_json::Result`
/// return type is kept so callers can uniformly propagate JSON errors.
///
/// # Panics
///
/// Panics if `event_type` is `m.room.member` but `member_content` or
/// `state_key` is `None`.
pub fn auth_types_for_event(
    room_version: &RoomVersion,
    event_type: &TimelineEventType,
    state_key: Option<&StateKey>,
    sender: &UserId,
    member_content: Option<RoomMemberEventContent>,
) -> serde_json::Result<Vec<(StateEventType, StateKey)>> {
    if event_type == &TimelineEventType::RoomCreate {
        // Create events never have auth events
        return Ok(vec![]);
    }
    let mut auth_types = if room_version.room_ids_as_hashes {
        // v12+: the create event is implied by the room ID, so it is not an
        // auth event.
        vec![
            StateEventType::RoomMember.with_state_key(sender.as_str()),
            StateEventType::RoomPowerLevels.with_state_key(""),
        ]
    } else {
        // For room versions that do not use room IDs as hashes, include the
        // RoomCreate event as an auth event.
        vec![
            StateEventType::RoomMember.with_state_key(sender.as_str()),
            StateEventType::RoomPowerLevels.with_state_key(""),
            StateEventType::RoomCreate.with_state_key(""),
        ]
    };
    if event_type == &TimelineEventType::RoomMember {
        let member_content =
            member_content.expect("member_content must be provided for RoomMember events");
        // Include the target's membership (if available)
        auth_types.push((
            StateEventType::RoomMember,
            state_key
                .expect("state_key must be provided for RoomMember events")
                .to_owned(),
        ));
        if matches!(
            member_content.membership,
            MembershipState::Join | MembershipState::Invite | MembershipState::Knock
        ) {
            // Include the join rules
            auth_types.push(StateEventType::RoomJoinRules.with_state_key(""));
        }
        if matches!(member_content.membership, MembershipState::Invite) {
            // If this is an invite, include the third party invite if it exists
            if let Some(ThirdPartyInvite { signed, .. }) = member_content.third_party_invite {
                auth_types
                    .push(StateEventType::RoomThirdPartyInvite.with_state_key(signed.token));
            }
        }
        if matches!(member_content.membership, MembershipState::Join)
            && room_version.restricted_join_rules
        {
            // If this is a restricted join, include the authorizing user's membership
            if let Some(authorizing_user) = member_content.join_authorized_via_users_server {
                auth_types
                    .push(StateEventType::RoomMember.with_state_key(authorizing_user.as_str()));
            }
        }
    }
    Ok(auth_types)
}
/// Checks for duplicate auth events in the `auth_events` field of an event.
/// Note: the caller should already have all of the auth events fetched.
///
/// If there are multiple auth events of the same type and state key, this
/// returns an error. Otherwise, it returns a map of (type, state_key) to the
/// corresponding auth event.
pub async fn check_duplicate_auth_events<FE>(
auth_events: &[OwnedEventId],
fetch_event: FE,
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
where
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
{
let mut seen: HashMap<(StateEventType, StateKey), Pdu> = HashMap::new();
// Considering all of the event's auth events:
for auth_event_id in auth_events {
if let Ok(Some(auth_event)) = fetch_event(auth_event_id).await {
let event_type = auth_event.kind();
// If this is not a state event, reject it.
let Some(state_key) = &auth_event.state_key() else {
return Err(Error::InvalidPdu(format!(
"Auth event {:?} is not a state event",
auth_event_id
)));
};
let type_key_pair: (StateEventType, StateKey) =
event_type.clone().with_state_key(state_key.clone());
// If there are duplicate entries for a given type and state_key pair, reject.
if seen.contains_key(&type_key_pair) {
return Err(Error::DuplicateAuthEvents(format!(
"({:?},\"{:?}\")",
event_type, state_key
)));
}
seen.insert(type_key_pair, auth_event);
} else {
return Err(Error::NotFound(auth_event_id.as_str().to_owned()));
}
}
Ok(seen)
}
/// Checks that the event does not refer to any auth events that it does not
/// need to.
///
/// # Errors
///
/// Returns `Error::UnselectedAuthEvents` listing any type/state-key pairs in
/// `auth_events` that are not produced by the auth events selection algorithm
/// (i.e. not present in `expected`).
pub fn check_unnecessary_auth_events(
    auth_events: &HashSet<(StateEventType, StateKey)>,
    // Accept a slice rather than `&Vec`; `&Vec<_>` callers coerce transparently.
    expected: &[(StateEventType, StateKey)],
) -> Result<(), Error> {
    // If there are entries whose type and state_key don't match those specified by
    // the auth events selection algorithm described in the server specification,
    // reject.
    let remaining = auth_events
        .iter()
        .filter(|key| !expected.contains(key))
        .collect::<HashSet<_>>();
    if !remaining.is_empty() {
        return Err(Error::UnselectedAuthEvents(format!("{:?}", remaining)));
    }
    Ok(())
}
// Checks that all provided auth events were not rejected previously.
//
// TODO: this is currently a no-op and always returns Ok(()). Rejection
// tracking is not yet wired in, so every fetched auth event is treated as
// accepted; the parameter is kept so callers already pass the right data.
pub fn check_all_auth_events_accepted(
    _auth_events: &HashMap<(StateEventType, StateKey), Pdu>,
) -> Result<(), Error> {
    Ok(())
}
/// Checks that all auth events are from the same room as the event being
/// validated.
///
/// Returns `false` (after logging a warning identifying the offending event)
/// if any auth event is missing a room ID or carries a room ID different from
/// `room_id`; otherwise returns `true`.
pub fn check_auth_same_room(auth_events: &[Pdu], room_id: &RoomId) -> bool {
    // Accept a slice rather than `&Vec`; `&Vec<Pdu>` callers coerce transparently.
    for auth_event in auth_events {
        // Guard clause: an auth event without a room ID can never match.
        let Some(auth_room_id) = auth_event.room_id() else {
            warn!(auth_event_id=%auth_event.event_id(), "Auth event has no room_id");
            return false;
        };
        if auth_room_id.as_str() != room_id.as_str() {
            warn!(
                auth_event_id=%auth_event.event_id(),
                "Auth event room id {} does not match expected room id {}",
                auth_room_id,
                room_id
            );
            return false;
        }
    }
    true
}
/// Performs all auth event checks for the given event.
///
/// Runs, in order: duplicate detection, the auth-event selection check, the
/// (currently no-op) rejection check, and the same-room check, returning the
/// validated map of (type, state_key) -> auth event on success.
///
/// # Errors
///
/// Propagates any error from the individual checks; see
/// [`check_duplicate_auth_events`], [`check_unnecessary_auth_events`],
/// [`check_all_auth_events_accepted`] and [`check_auth_same_room`].
pub async fn check_auth_events<FE>(
    event: &Pdu,
    room_id: &RoomId,
    room_version: &RoomVersion,
    fetch_event: &FE,
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
where
    FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
{
    // If there are duplicate entries for a given type and state_key pair, reject.
    let auth_events_map = check_duplicate_auth_events(&event.auth_events, fetch_event).await?;
    let auth_events_set: HashSet<(StateEventType, StateKey)> =
        auth_events_map.keys().cloned().collect();
    // Member events need their content parsed so the selection algorithm can
    // consider the target membership, join rules, etc.
    let member_event_content = match event.kind() {
        | TimelineEventType::RoomMember =>
            Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
                Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
            })?),
        | _ => None,
    };
    let expected_auth_events = auth_types_for_event(
        room_version,
        event.kind(),
        event.state_key.as_ref(),
        event.sender(),
        member_event_content,
    )?;
    // If there are entries whose type and state_key don't match those specified
    // by the auth events selection algorithm described in the server
    // specification, reject. (`?` replaces the former `if let Err(e)` re-wrap.)
    check_unnecessary_auth_events(&auth_events_set, &expected_auth_events)?;
    // If there are entries which were themselves rejected under the checks
    // performed on receipt of a PDU, reject.
    check_all_auth_events_accepted(&auth_events_map)?;
    // If any event in auth_events has a room_id which does not match that of the
    // event being authorised, reject.
    let auth_event_refs: Vec<Pdu> = auth_events_map.values().cloned().collect();
    if !check_auth_same_room(&auth_event_refs, room_id) {
        return Err(Error::InvalidPdu(
            "One or more auth events are from a different room".to_owned(),
        ));
    }
    Ok(auth_events_map)
}

View File

@@ -0,0 +1,113 @@
//! Context for event authorisation checks
use ruma::{
Int, OwnedUserId, UserId,
events::{
StateEventType,
room::{create::RoomCreateEventContent, power_levels::RoomPowerLevelsEventContent},
},
};
use crate::{Event, EventTypeExt, Pdu, RoomVersion, matrix::StateKey, state_res::Error};
/// How a user's authority in the room is determined.
///
/// The hand-written `PartialEq` impl this replaces was exactly the structural
/// equality `#[derive(PartialEq)]` generates, so the traits are derived
/// instead (plus `Eq`, `Debug`, `Clone`, `Copy`, which are all trivially
/// correct for a fieldless enum and strictly widen the API).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UserPower {
    /// Creator indicates this user should be granted a power level above all.
    Creator,
    /// Standard indicates power levels should be used to determine rank.
    Standard,
}
/// Get the creators of the room.
/// If this room only supports one creator, a vec of one will be returned.
/// If multiple creators are supported, all will be returned, with the
/// m.room.create sender first.
///
/// The returned vec is never empty: every branch below yields at least one
/// user ID.
///
/// # Errors
///
/// Returns `Error::InvalidPdu` if the create event is missing, has invalid
/// content, or (for legacy versions) lacks a `creator` field; errors from
/// `fetch_state` are propagated.
pub async fn calculate_creators<FS>(
    room_version: &RoomVersion,
    fetch_state: FS,
) -> Result<Vec<OwnedUserId>, Error>
where
    FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
    let create_event = fetch_state(StateEventType::RoomCreate.with_state_key(""))
        .await?
        .ok_or_else(|| Error::InvalidPdu("Room create event not found".to_owned()))?;
    let content = create_event
        .get_content::<RoomCreateEventContent>()
        .map_err(|e| {
            Error::InvalidPdu(format!("Room create event has invalid content: {}", e))
        })?;
    if room_version.explicitly_privilege_room_creators {
        // v12+: the sender plus any `additional_creators`, deduplicated and
        // sender-first as documented above.
        let mut creators = vec![create_event.sender().to_owned()];
        if let Some(additional) = content.additional_creators {
            for user_id in additional {
                if !creators.contains(&user_id) {
                    creators.push(user_id);
                }
            }
        }
        Ok(creators)
    } else if room_version.use_room_create_sender {
        // v11: the sender of the create event is the sole creator.
        Ok(vec![create_event.sender().to_owned()])
    } else {
        // Legacy versions: the creator is named in the event content.
        #[allow(deprecated)]
        if let Some(creator) = content.creator {
            Ok(vec![creator])
        } else {
            Err(Error::InvalidPdu("Room create event missing creator field".to_owned()))
        }
    }
}
/// Rank fetches the creatorship and power level of the target user
///
/// Returns (UserPower, power_level, Option<RoomPowerLevelsEventContent>)
/// If UserPower::Creator is returned, the power_level and
/// RoomPowerLevelsEventContent will be meaningless and can be ignored.
///
/// # Errors
///
/// Propagates errors from [`calculate_creators`] and from fetching or parsing
/// the `m.room.power_levels` state event.
pub async fn get_rank<FS>(
    room_version: &RoomVersion,
    fetch_state: &FS,
    user_id: &UserId,
) -> Result<(UserPower, Int, Option<RoomPowerLevelsEventContent>), Error>
where
    FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
    let creators = calculate_creators(room_version, &fetch_state).await?;
    // Compare by reference instead of allocating an OwnedUserId just for the
    // membership test (uses the same PartialEq impl as `creators[0] == user_id`).
    let is_creator = creators.iter().any(|c| *c == user_id);
    if is_creator && room_version.explicitly_privilege_room_creators {
        return Ok((UserPower::Creator, Int::MAX, None));
    }
    let power_levels = fetch_state(StateEventType::RoomPowerLevels.with_state_key("")).await?;
    if let Some(power_levels) = power_levels {
        let power_levels = power_levels
            .get_content::<RoomPowerLevelsEventContent>()
            .map_err(|e| {
                Error::InvalidPdu(format!("m.room.power_levels event has invalid content: {}", e))
            })?;
        Ok((
            UserPower::Standard,
            // Fall back to the room default when the user has no explicit entry.
            *power_levels
                .users
                .get(user_id)
                .unwrap_or(&power_levels.users_default),
            Some(power_levels),
        ))
    } else {
        // No power levels event, use defaults. calculate_creators never
        // returns an empty vec, so indexing [0] is safe here.
        if creators[0] == user_id {
            return Ok((UserPower::Creator, Int::MAX, None));
        }
        Ok((UserPower::Standard, Int::from(0), None))
    }
}

View File

@@ -0,0 +1,97 @@
//! Auth checks relevant to the `m.room.create` event specifically.
//!
//! See: <https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules>
use ruma::{OwnedUserId, RoomVersionId, events::room::create::RoomCreateEventContent};
use serde::Deserialize;
use serde_json::from_str;
use crate::{Event, Pdu, RoomVersion, state_res::Error, trace};
// A raw representation of the create event content, for initial parsing.
// This allows us to extract fields without fully validating the event first.
#[derive(Deserialize)]
struct RawCreateContent {
    // Creator user ID (legacy room versions only); validated by check 5 below.
    creator: Option<String>,
    // Raw room version string; treated as "1" (unsupported) when absent.
    room_version: Option<String>,
    // Extra privileged creators (v12+); validated by check 4 below.
    additional_creators: Option<Vec<String>>,
}
// Check whether an `m.room.create` event is valid.
// This ensures that:
//
// 1. The event has no `prev_events`
// 2. If the version disallows it, the event has no `room_id` present.
// 3. If the room version is present and recognised, otherwise assume invalid.
// 4. If the room version supports it, `additional_creators` is populated with
// valid user IDs.
// 5. If the room version supports it, `creator` is populated AND is a valid
// user ID.
// 6. Otherwise, this event is valid.
//
// The fully deserialized `RoomCreateEventContent` is returned for further calls
// to other checks.
pub fn check_room_create(event: &Pdu) -> Result<RoomCreateEventContent, Error> {
// Check 1: The event has no `prev_events`
if !event.prev_events.is_empty() {
return Err(Error::InvalidPdu("m.room.create event has prev_events".to_owned()));
}
let create_content = from_str::<RawCreateContent>(event.content().get())?;
// Note: Here we attempt to both load the raw room version string and validate
// it, and then cast it to the room features. If either step fails, we return
// an unsupported error. If the room version is missing, it defaults to "1",
// which we also do not support.
//
// This performs check 3, which then allows us to perform check 2.
let room_version = if let Some(raw_room_version) = create_content.room_version {
trace!("Parsing and interpreting room version: {}", raw_room_version);
let room_version_id = RoomVersionId::try_from(raw_room_version.as_str())
.map_err(|_| Error::Unsupported(raw_room_version))?;
RoomVersion::new(&room_version_id)
.map_err(|_| Error::Unsupported(room_version_id.as_str().to_owned()))?
} else {
return Err(Error::Unsupported("1".to_owned()));
};
// Check 2: If the version disallows it, the event has no `room_id` present.
if room_version.room_ids_as_hashes && event.room_id.is_some() {
return Err(Error::InvalidPdu(
"m.room.create event has room_id but room version disallows it".to_owned(),
));
}
// Check 4: If the room version supports it, `additional_creators` is populated
// with valid user IDs.
if room_version.explicitly_privilege_room_creators {
if let Some(additional_creators) = create_content.additional_creators {
for creator in additional_creators {
trace!("Validating additional creator user ID: {}", creator);
if OwnedUserId::parse(&creator).is_err() {
return Err(Error::InvalidPdu(format!(
"Invalid user ID in additional_creators: {creator}"
)));
}
}
}
}
// Check 5: If the room version supports it, `creator` is populated AND is a
// valid user ID.
if !room_version.use_room_create_sender {
if let Some(creator) = create_content.creator {
trace!("Validating creator user ID: {}", creator);
if OwnedUserId::parse(&creator).is_err() {
return Err(Error::InvalidPdu(format!("Invalid user ID in creator: {creator}")));
}
} else {
return Err(Error::InvalidPdu(
"m.room.create event missing creator field".to_owned(),
));
}
}
// Deserialise into the full create event for future checks.
Ok(from_str::<RoomCreateEventContent>(event.content().get())?)
}

View File

@@ -0,0 +1,650 @@
use ruma::{
EventId, OwnedUserId, RoomVersionId,
events::{
StateEventType, TimelineEventType,
room::{create::RoomCreateEventContent, member::MembershipState},
},
int,
serde::Raw,
};
use serde::{Deserialize, de::IgnoredAny};
use serde_json::from_str as from_json_str;
use crate::{
Event, EventTypeExt, Pdu, RoomVersion, debug, error,
matrix::StateKey,
state_res::{
error::Error,
event_auth::{
auth_events::check_auth_events,
context::{UserPower, calculate_creators, get_rank},
create_event::check_room_create,
member_event::check_member_event,
power_levels::check_power_levels,
},
},
trace, warn,
};
// FIXME: field extracting could be bundled for `content`
// Minimal view of m.room.member content: only the membership field.
#[derive(Deserialize)]
struct GetMembership {
    membership: MembershipState,
}
// Partial m.room.member content. Fields stay `Raw` so that one malformed
// field does not fail deserialisation of the whole content.
#[derive(Deserialize, Debug)]
struct RoomMemberContentFields {
    membership: Option<Raw<MembershipState>>,
    join_authorised_via_users_server: Option<Raw<OwnedUserId>>,
}
// Partial m.room.create content as used by auth_check.
#[derive(Deserialize)]
struct RoomCreateContentFields {
    room_version: Option<Raw<RoomVersionId>>,
    // Present only in legacy room versions; the value itself is unused here.
    creator: Option<Raw<IgnoredAny>>,
    additional_creators: Option<Vec<Raw<OwnedUserId>>>,
    // Defaults to true when absent, per the spec.
    #[serde(rename = "m.federate", default = "ruma::serde::default_true")]
    federate: bool,
}
/// Authenticate the incoming `event`.
///
/// The steps of authentication are:
///
/// * check that the event is being authenticated for the correct room
/// * then there are checks for specific event types
///
/// The `fetch_state` closure should gather state from a state snapshot. We need
/// to know if the event passes auth against some state not a recursive
/// collection of auth_events fields.
///
/// Returns `Ok(true)` when the event is allowed, `Ok(false)` when it fails an
/// auth rule, and `Err` when a check could not be carried out at all.
///
/// # Panics
///
/// Panics if `create_event` has unparsable content, and (deliberately) for
/// room versions requiring the extra v1/v2 redaction checks, which this
/// implementation does not support.
#[tracing::instrument(
    skip_all,
    fields(
        event_id = incoming_event.event_id().as_str(),
        event_type = ?incoming_event.event_type().to_string()
    )
)]
#[allow(clippy::suspicious_operation_groupings)]
pub async fn auth_check<FE, FS>(
    room_version: &RoomVersion,
    incoming_event: &Pdu,
    fetch_event: &FE,
    fetch_state: &FS,
    create_event: Option<&Pdu>,
) -> Result<bool, Error>
where
    FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
    FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
    debug!("auth_check beginning");
    let sender = incoming_event.sender();
    // Since v1, If type is m.room.create:
    if *incoming_event.event_type() == TimelineEventType::RoomCreate {
        debug!("start m.room.create check");
        if let Err(e) = check_room_create(incoming_event) {
            warn!("m.room.create event has been rejected: {}", e);
            return Ok(false);
        }
        debug!("m.room.create event was allowed");
        return Ok(true);
    }
    // Every non-create event needs the room's create event for context.
    let Some(create_event) = create_event else {
        error!("no create event provided for auth check");
        return Err(Error::InvalidPdu("missing create event".to_owned()));
    };
    // TODO: we need to know if events have previously been rejected or soft failed
    // For now, we'll just assume the create_event is valid.
    // NOTE(review): this panics on malformed create content; consider mapping
    // to Error::InvalidPdu instead of expect().
    let create_content = from_json_str::<RoomCreateEventContent>(create_event.content().get())
        .expect("provided create event must be valid");
    // Since v12, If the event's room_id is not an event ID for an accepted (not
    // rejected) m.room.create event, with the sigil ! instead of $, reject.
    if room_version.room_ids_as_hashes {
        let calculated_room_id = create_event.event_id().as_str().replace('$', "!");
        // NOTE(review): this inspects create_event.room_id(), but the spec rule
        // is about the *incoming* event's room_id — and check_room_create
        // rejects v12 create events that carry a room_id at all, so this
        // branch looks like it can never pass. Confirm intended behaviour.
        if let Some(claimed_room_id) = create_event.room_id() {
            if claimed_room_id.as_str() != calculated_room_id {
                warn!(
                    expected = %calculated_room_id,
                    received = %claimed_room_id,
                    "event's room ID does not match the hash of the m.room.create event ID"
                );
                return Ok(false);
            }
        } else {
            warn!("event is missing a room ID");
            return Ok(false);
        }
    }
    let room_id = incoming_event.room_id().expect("event must have a room ID");
    // Validate the auth_events field itself (duplicates, selection, same room).
    let auth_map =
        match check_auth_events(incoming_event, room_id, &room_version, fetch_event).await {
            | Ok(map) => map,
            | Err(e) => {
                warn!("event's auth events are invalid: {}", e);
                return Ok(false);
            },
        };
    // Considering the event's auth_events
    // Since v1, If the content of the m.room.create event in the room state has the
    // property m.federate set to false, and the sender domain of the event does
    // not match the sender domain of the create event, reject.
    if !create_content.federate {
        if create_event.sender().server_name() != incoming_event.sender().server_name() {
            warn!(
                sender = %incoming_event.sender(),
                create_sender = %create_event.sender(),
                "room is not federated and event's sender domain does not match create event's sender domain"
            );
            return Ok(false);
        }
    }
    // From v1 to v5, If type is m.room.aliases
    if room_version.special_case_aliases_auth
        && *incoming_event.event_type() == TimelineEventType::RoomAliases
    {
        if let Some(state_key) = incoming_event.state_key() {
            // If sender's domain doesn't matches state_key, reject
            if state_key != sender.server_name().as_str() {
                warn!("state_key does not match sender");
                return Ok(false);
            }
            // Otherwise, allow
            return Ok(true);
        }
        // If event has no state_key, reject.
        // NOTE(review): log message says "m.room.alias"; type is m.room.aliases.
        warn!("m.room.alias event has no state key");
        return Ok(false);
    }
    // From v1, If type is m.room.member
    if *incoming_event.event_type() == TimelineEventType::RoomMember {
        if let Err(e) =
            check_member_event(&room_version, incoming_event, fetch_event, fetch_state).await
        {
            warn!("m.room.member event has been rejected: {}", e);
            return Ok(false);
        }
    }
    // From v1, If the sender's current membership state is not join, reject
    let sender_member_event =
        match auth_map.get(&StateEventType::RoomMember.with_state_key(sender.as_str())) {
            | Some(ev) => ev,
            | None => {
                warn!(
                    %sender,
                    "sender is not joined - no membership event found for sender in auth events"
                );
                return Ok(false);
            },
        };
    let sender_membership_event_content: RoomMemberContentFields =
        from_json_str(sender_member_event.content().get())?;
    let Some(membership_state) = sender_membership_event_content.membership else {
        warn!(
            ?sender_membership_event_content,
            "Sender membership event content missing membership field"
        );
        return Err(Error::InvalidPdu("Missing membership field".to_owned()));
    };
    let membership_state = membership_state.deserialize()?;
    if membership_state != MembershipState::Join {
        warn!(
            %sender,
            ?membership_state,
            "sender cannot send events without being joined to the room"
        );
        return Ok(false);
    }
    // From v1, If type is m.room.third_party_invite
    // Rank is fetched once here and reused by the power-level checks below.
    let (rank, sender_pl, pl_evt) = get_rank(&room_version, fetch_state, sender).await?;
    // Allow if and only if sender's current power level is greater than
    // or equal to the invite level
    if *incoming_event.event_type() == TimelineEventType::RoomThirdPartyInvite {
        if rank == UserPower::Creator {
            trace!("sender is room creator, allowing m.room.third_party_invite");
            return Ok(true);
        }
        let invite_level = match &pl_evt {
            | Some(power_levels) => power_levels.invite,
            | None => int!(0),
        };
        if sender_pl < invite_level {
            warn!(
                %sender,
                has=%sender_pl,
                required=%invite_level,
                "sender cannot send invites in this room"
            );
            return Ok(false);
        }
        debug!("m.room.third_party_invite event was allowed");
        return Ok(true);
    }
    // Since v1, if the event type's required power level is greater than the
    // sender's power level, reject.
    let required_level = match &pl_evt {
        | Some(power_levels) => power_levels
            .events
            .get(incoming_event.kind())
            .unwrap_or_else(|| {
                // State events default to state_default, others to events_default.
                if incoming_event.state_key.is_some() {
                    &power_levels.state_default
                } else {
                    &power_levels.events_default
                }
            }),
        | None => &int!(0),
    };
    if rank != UserPower::Creator && sender_pl < *required_level {
        warn!(
            %sender,
            has=%sender_pl,
            required=%required_level,
            "sender does not have enough power level to send this event"
        );
        return Ok(false);
    }
    // Since v1, If the event has a state_key that starts with an @ and does not
    // match the sender, reject.
    if let Some(state_key) = incoming_event.state_key() {
        if state_key.starts_with('@') && state_key != sender.as_str() {
            warn!(
                %sender,
                %state_key,
                "event's state key starts with @ and does not match sender"
            );
            return Ok(false);
        }
    }
    // Since v1, If type is m.room.power_levels
    if *incoming_event.event_type() == TimelineEventType::RoomPowerLevels {
        let creators = calculate_creators(&room_version, fetch_state).await?;
        if let Err(e) =
            check_power_levels(&room_version, incoming_event, pl_evt.as_ref(), creators).await
        {
            warn!(
                %sender,
                "m.room.power_levels event has been rejected: {}", e
            );
            return Ok(false);
        }
    }
    // From v1 to v2: If type is m.room.redaction:
    // If the sender's power level is greater than or equal to the redact level,
    // allow.
    // If the domain of the event_id of the event being redacted is the same as the
    // domain of the event_id of the m.room.redaction, allow.
    // Otherwise, reject.
    if room_version.extra_redaction_checks {
        // We'll panic here, since while we don't theoretically support the room
        // versions that require this, we don't want to incorrectly permit an event
        // that should be rejected in this theoretically impossible scenario.
        unreachable!(
            "continuwuity does not support room versions that require extra redaction checks"
        );
    }
    debug!("allowing event passed all checks");
    Ok(true)
}
#[cfg(test)]
mod tests {
    use ruma::events::{
        StateEventType, TimelineEventType,
        room::{
            join_rules::{
                AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
            },
            member::{MembershipState, RoomMemberEventContent},
        },
    };
    use serde_json::value::to_raw_value as to_raw_json_value;
    use crate::{
        matrix::{Event, EventTypeExt, Pdu as PduEvent},
        state_res::{
            RoomVersion, StateMap,
            // Fix: `valid_membership_change` was imported twice (bare and via
            // `iterative_auth_checks`), a duplicate-name error (E0252); keep
            // only the `iterative_auth_checks` path.
            event_auth::iterative_auth_checks::valid_membership_change,
            test_utils::{
                INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, alice, charlie, ella, event_id,
                member_content_ban, member_content_join, room_id, to_pdu_event,
            },
        },
    };
    /// Alice (creator, PL 100) banning Charlie must be allowed.
    #[test]
    fn test_ban_pass() {
        let _ = tracing::subscriber::set_default(
            tracing_subscriber::fmt().with_test_writer().finish(),
        );
        let events = INITIAL_EVENTS();
        let auth_events = events
            .values()
            .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
            .collect::<StateMap<_>>();
        let requester = to_pdu_event(
            "HELLO",
            alice(),
            TimelineEventType::RoomMember,
            Some(charlie().as_str()),
            member_content_ban(),
            &[],
            &["IMC"],
        );
        let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
        let target_user = charlie();
        let sender = alice();
        assert!(
            valid_membership_change(
                &RoomVersion::V6,
                target_user,
                fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
                sender,
                fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
                &requester,
                None::<&PduEvent>,
                fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
                fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
                None,
                &MembershipState::Leave,
                &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
            )
            .unwrap()
        );
    }
    /// A non-creator self-joining straight off the create event must fail.
    #[test]
    fn test_join_non_creator() {
        let _ = tracing::subscriber::set_default(
            tracing_subscriber::fmt().with_test_writer().finish(),
        );
        let events = INITIAL_EVENTS_CREATE_ROOM();
        let auth_events = events
            .values()
            .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
            .collect::<StateMap<_>>();
        let requester = to_pdu_event(
            "HELLO",
            charlie(),
            TimelineEventType::RoomMember,
            Some(charlie().as_str()),
            member_content_join(),
            &["CREATE"],
            &["CREATE"],
        );
        let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
        let target_user = charlie();
        let sender = charlie();
        assert!(
            !valid_membership_change(
                &RoomVersion::V6,
                target_user,
                fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
                sender,
                fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
                &requester,
                None::<&PduEvent>,
                fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
                fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
                None,
                &MembershipState::Leave,
                &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
            )
            .unwrap()
        );
    }
    /// The room creator self-joining straight off the create event is allowed.
    #[test]
    fn test_join_creator() {
        let _ = tracing::subscriber::set_default(
            tracing_subscriber::fmt().with_test_writer().finish(),
        );
        let events = INITIAL_EVENTS_CREATE_ROOM();
        let auth_events = events
            .values()
            .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
            .collect::<StateMap<_>>();
        let requester = to_pdu_event(
            "HELLO",
            alice(),
            TimelineEventType::RoomMember,
            Some(alice().as_str()),
            member_content_join(),
            &["CREATE"],
            &["CREATE"],
        );
        let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
        let target_user = alice();
        let sender = alice();
        assert!(
            valid_membership_change(
                &RoomVersion::V6,
                target_user,
                fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
                sender,
                fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
                &requester,
                None::<&PduEvent>,
                fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
                fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
                None,
                &MembershipState::Leave,
                &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
            )
            .unwrap()
        );
    }
    /// Charlie (no power) banning Alice (creator, PL 100) must fail.
    #[test]
    fn test_ban_fail() {
        let _ = tracing::subscriber::set_default(
            tracing_subscriber::fmt().with_test_writer().finish(),
        );
        let events = INITIAL_EVENTS();
        let auth_events = events
            .values()
            .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
            .collect::<StateMap<_>>();
        let requester = to_pdu_event(
            "HELLO",
            charlie(),
            TimelineEventType::RoomMember,
            Some(alice().as_str()),
            member_content_ban(),
            &[],
            &["IMC"],
        );
        let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
        let target_user = alice();
        let sender = charlie();
        assert!(
            !valid_membership_change(
                &RoomVersion::V6,
                target_user,
                fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
                sender,
                fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
                &requester,
                None::<&PduEvent>,
                fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
                fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
                None,
                &MembershipState::Leave,
                &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
            )
            .unwrap()
        );
    }
    /// Restricted join: allowed when authorised by a member with invite power
    /// (Alice), rejected when "authorised" by the joiner herself (Ella).
    #[test]
    fn test_restricted_join_rule() {
        let _ = tracing::subscriber::set_default(
            tracing_subscriber::fmt().with_test_writer().finish(),
        );
        let mut events = INITIAL_EVENTS();
        *events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
            "IJR",
            alice(),
            TimelineEventType::RoomJoinRules,
            Some(""),
            to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Restricted(
                Restricted::new(vec![AllowRule::RoomMembership(RoomMembership::new(
                    room_id().to_owned(),
                ))]),
            )))
            .unwrap(),
            &["CREATE", "IMA", "IPOWER"],
            &["IPOWER"],
        );
        let mut member = RoomMemberEventContent::new(MembershipState::Join);
        member.join_authorized_via_users_server = Some(alice().to_owned());
        let auth_events = events
            .values()
            .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
            .collect::<StateMap<_>>();
        let requester = to_pdu_event(
            "HELLO",
            ella(),
            TimelineEventType::RoomMember,
            Some(ella().as_str()),
            to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap(),
            &["CREATE", "IJR", "IPOWER", "new"],
            &["new"],
        );
        let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
        let target_user = ella();
        let sender = ella();
        assert!(
            valid_membership_change(
                &RoomVersion::V9,
                target_user,
                fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
                sender,
                fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
                &requester,
                None::<&PduEvent>,
                fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
                fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
                Some(alice()),
                &MembershipState::Join,
                &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
            )
            .unwrap()
        );
        assert!(
            !valid_membership_change(
                &RoomVersion::V9,
                target_user,
                fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
                sender,
                fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
                &requester,
                None::<&PduEvent>,
                fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
                fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
                Some(ella()),
                &MembershipState::Leave,
                &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
            )
            .unwrap()
        );
    }
    /// Knocking on a knock-rule room must be allowed (room version 7+).
    #[test]
    fn test_knock() {
        let _ = tracing::subscriber::set_default(
            tracing_subscriber::fmt().with_test_writer().finish(),
        );
        let mut events = INITIAL_EVENTS();
        *events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
            "IJR",
            alice(),
            TimelineEventType::RoomJoinRules,
            Some(""),
            to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Knock)).unwrap(),
            &["CREATE", "IMA", "IPOWER"],
            &["IPOWER"],
        );
        let auth_events = events
            .values()
            .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
            .collect::<StateMap<_>>();
        let requester = to_pdu_event(
            "HELLO",
            ella(),
            TimelineEventType::RoomMember,
            Some(ella().as_str()),
            to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Knock)).unwrap(),
            &[],
            &["IMC"],
        );
        let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
        let target_user = ella();
        let sender = ella();
        assert!(
            valid_membership_change(
                &RoomVersion::V7,
                target_user,
                fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
                sender,
                fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
                &requester,
                None::<&PduEvent>,
                fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
                fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
                None,
                &MembershipState::Leave,
                &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
            )
            .unwrap()
        );
    }
}

View File

@@ -0,0 +1,422 @@
//! Auth checks relevant to the `m.room.member` event specifically.
//!
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
use ruma::{
EventId, OwnedUserId, UserId,
events::{
StateEventType,
room::{
join_rules::{JoinRule, RoomJoinRulesEventContent},
third_party_invite::{PublicKey, RoomThirdPartyInviteEventContent},
},
},
serde::Base64,
signatures::{PublicKeyMap, PublicKeySet, verify_json},
};
use crate::{
Event, EventTypeExt, Pdu, RoomVersion,
matrix::StateKey,
state_res::{
Error,
event_auth::context::{UserPower, get_rank},
},
utils::to_canonical_object,
};
/// Subset of the `m.room.member` content fields needed by these auth checks.
///
/// Deserialised leniently: every field is optional so malformed or partial
/// content can still be inspected, and callers decide how to treat a missing
/// property (e.g. `fetch_membership` substitutes a "leave" membership when no
/// member event exists).
#[derive(serde::Deserialize, Default)]
struct PartialMembershipObject {
    /// Membership state string, e.g. "join", "invite", "leave", "ban", "knock".
    membership: Option<String>,
    /// User whose homeserver authorised a join under a restricted join rule.
    join_authorized_via_users_server: Option<OwnedUserId>,
    /// Raw `third_party_invite` content object; validated separately in
    /// `check_third_party_invite`.
    third_party_invite: Option<serde_json::Value>,
}
/// Fetches the membership *content* of `target` from the room state.
///
/// If the target has no `m.room.member` state event, a synthetic "leave"
/// membership is returned instead, matching the spec's default.
async fn fetch_membership<FS>(
    fetch_state: &FS,
    target: &UserId,
) -> Result<PartialMembershipObject, Error>
where
    FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
    let member_pdu =
        fetch_state(StateEventType::RoomMember.with_state_key(target.as_str())).await?;
    match member_pdu {
        | Some(ev) => ev.get_content::<PartialMembershipObject>().map_err(|e| {
            Error::InvalidPdu(format!("m.room.member event has invalid content: {}", e))
        }),
        | None => Ok(PartialMembershipObject {
            membership: Some("leave".to_owned()),
            ..Default::default()
        }),
    }
}
/// Checks the auth rules for an `m.room.member` event whose `membership` is
/// `join` (rule 3 of the membership authorization rules).
///
/// `membership` is the already-parsed content of `event`; `target` is the user
/// named by its `state_key` (equal to the sender once rule 3.2 has passed).
async fn check_join_event<FE, FS>(
    room_version: &RoomVersion,
    event: &Pdu,
    membership: &PartialMembershipObject,
    target: &UserId,
    fetch_event: &FE,
    fetch_state: &FS,
) -> Result<(), Error>
where
    FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
    FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
    // 3.1: If the only previous event is an m.room.create and the state_key is
    // the sender of the m.room.create, allow.
    if event.prev_events.len() == 1 {
        if let Some(prev_event) = fetch_event(&event.prev_events[0]).await? {
            // BUGFIX: this previously compared the create event's state_key
            // (always "") against the sender, so the rule could never match.
            // The correct check is that the join's state_key (`target`) is the
            // sender of the m.room.create event.
            let is_create =
                prev_event.event_type().with_state_key("").0 == StateEventType::RoomCreate;
            if is_create && prev_event.sender() == target {
                return Ok(());
            }
        } else {
            return Err(Error::DependencyFailed(
                event.prev_events[0].to_owned(),
                "Previous event not found when checking join event".to_owned(),
            ));
        }
    }
    // 3.2: If the sender does not match state_key, reject.
    if event.sender() != target {
        return Err(Error::AuthConditionFailed(
            "m.room.member join event sender does not match state_key".to_owned(),
        ));
    }
    // Current membership of the target, used for rules 3.3-3.5.
    let prev_membership = if let Some(ev) =
        fetch_state(StateEventType::RoomMember.with_state_key(target.as_str())).await?
    {
        Some(ev.get_content::<PartialMembershipObject>().map_err(|e| {
            Error::InvalidPdu(format!("Previous m.room.member event has invalid content: {}", e))
        })?)
    } else {
        None
    };
    let join_rule_content =
        if let Some(jr) = fetch_state(StateEventType::RoomJoinRules.with_state_key("")).await? {
            jr.get_content::<RoomJoinRulesEventContent>().map_err(|e| {
                Error::InvalidPdu(format!("m.room.join_rules event has invalid content: {}", e))
            })?
        } else {
            // Default to invite if no join rules event is present.
            RoomJoinRulesEventContent { join_rule: JoinRule::Private }
        };
    // 3.3: If the sender is banned, reject.
    let prev_member = if let Some(prev_content) = &prev_membership {
        if let Some(membership) = &prev_content.membership {
            if membership == "ban" {
                return Err(Error::AuthConditionFailed(
                    "m.room.member join event sender is banned".to_owned(),
                ));
            }
            membership
        } else {
            "leave"
        }
    } else {
        "leave"
    };
    // 3.4: If the join_rule is invite or knock then allow if membership
    // state is invite or join.
    // 3.5: If the join_rule is restricted or knock_restricted:
    // 3.5.1: If membership state is join or invite, allow.
    match join_rule_content.join_rule {
        | JoinRule::Invite | JoinRule::Knock => {
            if prev_member == "invite" || prev_member == "join" {
                return Ok(());
            }
            Err(Error::AuthConditionFailed(
                "m.room.member join event not invited under invite/knock join rule".to_owned(),
            ))
        },
        | JoinRule::Restricted(_) | JoinRule::KnockRestricted(_) => {
            if prev_member == "invite" || prev_member == "join" {
                return Ok(());
            }
            // 3.5.2: If the join_authorised_via_users_server key in content is
            // not a user with sufficient permission to invite other users, or
            // is not a joined member of the room, reject.
            let Some(user_id) = membership.join_authorized_via_users_server.as_ref() else {
                return Err(Error::AuthConditionFailed(
                    "m.room.member join event missing join_authorised_via_users_server"
                        .to_owned(),
                ));
            };
            let (power, level, power_levels) =
                get_rank(room_version, fetch_state, user_id).await?;
            if power == UserPower::Standard {
                // The authorising user is not a creator, so they must meet the
                // invite power level. BUGFIX: default power levels apply when
                // the room has no m.room.power_levels event; this previously
                // `unwrap()`ed the Option and panicked in that case.
                if level < power_levels.unwrap_or_default().invite {
                    return Err(Error::AuthConditionFailed(
                        "m.room.member join event join_authorised_via_users_server does not \
                         have sufficient power level to invite"
                            .to_owned(),
                    ));
                }
            }
            // The authorising user must themselves be a joined member of the
            // room. BUGFIX: a non-joined authoriser previously fell through to
            // the "otherwise, allow" branch, accepting the join.
            let authoriser = fetch_membership(fetch_state, user_id).await?;
            if authoriser.membership.as_deref() == Some("join") {
                // 3.5.3: Otherwise, allow.
                return Ok(());
            }
            Err(Error::AuthConditionFailed(
                "m.room.member join event join_authorised_via_users_server is not a joined \
                 member of the room"
                    .to_owned(),
            ))
        },
        | JoinRule::Public => Ok(()),
        | _ => Err(Error::AuthConditionFailed(format!(
            "unknown join rule: {:?}",
            join_rule_content.join_rule
        ))),
    }
}
/// Checks a third-party invite is valid.
async fn check_third_party_invite(
target_current_membership: PartialMembershipObject,
raw_third_party_invite: &serde_json::Value,
target: &UserId,
event: &Pdu,
fetch_state: impl AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
) -> Result<(), Error> {
// 4.1.1: If target user is banned, reject.
if target_current_membership
.membership
.is_some_and(|m| m == "ban")
{
return Err(Error::AuthConditionFailed("invite target is banned".to_owned()));
}
// 4.1.2: If content.third_party_invite does not have a signed property, reject.
let signed = raw_third_party_invite.get("signed").ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite missing signed property".to_owned(),
)
})?;
// 4.2.3: If signed does not have mxid and token properties, reject.
let mxid = signed.get("mxid").and_then(|v| v.as_str()).ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing/invalid mxid property".to_owned(),
)
})?;
let token = signed
.get("token")
.and_then(|v| v.as_str())
.ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing token property".to_owned(),
)
})?;
// 4.2.4: If mxid does not match state_key, reject.
if mxid != target.as_str() {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite signed mxid does not match state_key".to_owned(),
));
}
// 4.2.5: If there is no m.room.third_party_invite event in the room
// state matching the token, reject.
let Some(third_party_invite_event) =
fetch_state(StateEventType::RoomThirdPartyInvite.with_state_key(token)).await?
else {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite token has no matching m.room.third_party_invite"
.to_owned(),
));
};
// 4.2.6: If sender does not match sender of the m.room.third_party_invite,
// reject.
if third_party_invite_event.sender() != event.sender() {
return Err(Error::AuthConditionFailed(
"invite event sender does not match m.room.third_party_invite sender".to_owned(),
));
}
// 4.2.7: If any signature in signed matches any public key in the
// m.room.third_party_invite event, allow. The public keys are in
// content of m.room.third_party_invite as:
// 1. A single public key in the public_key property.
// 2. A list of public keys in the public_keys property.
let tpi_content = third_party_invite_event
.get_content::<RoomThirdPartyInviteEventContent>()
.or_else(|_| {
Err(Error::InvalidPdu(
"m.room.third_party_invite event has invalid content".to_owned(),
))
})?;
let mut public_keys = tpi_content.public_keys.unwrap_or_default();
public_keys.push(PublicKey {
public_key: tpi_content.public_key,
key_validity_url: None,
});
let signatures = signed
.get("signatures")
.and_then(|v| v.as_object())
.ok_or_else(|| {
Error::InvalidPdu(
"invite event third_party_invite signed missing/invalid signatures".to_owned(),
)
})?;
let mut public_key_map = PublicKeyMap::new();
for (server_name, sig_map) in signatures {
let mut pk_set = PublicKeySet::new();
if let Some(sig_map) = sig_map.as_object() {
for (key_id, sig) in sig_map {
let sig_b64 = Base64::parse(sig.as_str().ok_or(Error::InvalidPdu(
"invite event third_party_invite signature is not a string".to_owned(),
))?)
.map_err(|_| {
Error::InvalidPdu(
"invite event third_party_invite signature is not valid Base64"
.to_owned(),
)
})?;
pk_set.insert(key_id.clone(), sig_b64);
}
}
public_key_map.insert(server_name.clone(), pk_set);
}
verify_json(
&public_key_map,
to_canonical_object(signed).expect("signed was already validated"),
)
.map_err(|e| {
Error::AuthConditionFailed(format!(
"invite event third_party_invite signature verification failed: {e}"
))
})?;
// If there was no error, there was a valid signature, so allow.
Ok(())
}
/// Checks the auth rules for an `m.room.member` event whose `membership` is
/// `invite` (rule 4 of the membership authorization rules).
async fn check_invite_event<FS>(
    room_version: &RoomVersion,
    event: &Pdu,
    membership: &PartialMembershipObject,
    target: &UserId,
    fetch_state: &FS,
) -> Result<(), Error>
where
    FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
    let target_current_membership = fetch_membership(fetch_state, target).await?;

    // 4.1: If content has a third_party_invite property, its validation alone
    // decides the outcome.
    if let Some(tpi) = &membership.third_party_invite {
        return check_third_party_invite(target_current_membership, tpi, target, event, fetch_state)
            .await;
    }

    // 4.2: If the sender's current membership state is not join, reject.
    let sender_membership = fetch_membership(fetch_state, event.sender()).await?;
    if sender_membership.membership.as_deref() != Some("join") {
        return Err(Error::AuthConditionFailed("invite sender is not joined".to_owned()));
    }

    // 4.3: If target user's current membership state is join or ban, reject.
    if matches!(target_current_membership.membership.as_deref(), Some("join" | "ban")) {
        return Err(Error::AuthConditionFailed(
            "invite target is already joined or banned".to_owned(),
        ));
    }

    // 4.4: If the sender's power level is greater than or equal to the invite
    // level, allow; 4.5: otherwise, reject.
    let (rank, pl, pl_evt) = get_rank(&room_version, fetch_state, event.sender()).await?;
    if rank == UserPower::Creator || pl >= pl_evt.unwrap_or_default().invite {
        Ok(())
    } else {
        Err(Error::AuthConditionFailed(
            "invite sender does not have sufficient power level to invite".to_owned(),
        ))
    }
}
/// Top-level auth check for `m.room.member` events.
///
/// Validates the event's shape (rule 1), rejects joins carrying a
/// `join_authorised_via_users_server` key until signature checking is wired up
/// (rule 2), then dispatches on the `membership` value in content.
pub async fn check_member_event<FE, FS>(
    room_version: &RoomVersion,
    event: &Pdu,
    fetch_event: FE,
    fetch_state: FS,
) -> Result<(), Error>
where
    FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
    FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
    // 1. If there is no state_key property, or no membership property in content,
    // reject.
    let Some(state_key) = event.state_key() else {
        return Err(Error::InvalidPdu("m.room.member event missing state_key".to_owned()));
    };
    let target = UserId::parse(state_key)
        .map_err(|_| Error::InvalidPdu("m.room.member event has invalid state_key".to_owned()))?
        .to_owned();
    let content = event
        .get_content::<PartialMembershipObject>()
        .map_err(|e| {
            Error::InvalidPdu(format!("m.room.member event has invalid content: {}", e))
        })?;
    let Some(membership) = content.membership.as_deref() else {
        return Err(Error::InvalidPdu(
            "m.room.member event missing membership in content".to_owned(),
        ));
    };
    // 2: If content has a join_authorised_via_users_server key
    //
    // 2.1: If the event is not validly signed by the homeserver of the user ID
    // denoted by the key, reject.
    if content.join_authorized_via_users_server.is_some() {
        // We need to check the signature here, but don't have the means to do so yet.
        todo!("Implement join_authorised_via_users_server check");
    }
    match membership {
        | "join" =>
            check_join_event(room_version, event, &content, &target, &fetch_event, &fetch_state)
                .await,
        | "invite" =>
            check_invite_event(room_version, event, &content, &target, &fetch_state).await,
        | _ => todo!(),
    }
}

View File

@@ -0,0 +1,6 @@
pub mod auth_events;
mod context;
pub mod create_event;
pub mod iterative_auth_checks;
pub mod member_event;
mod power_levels;

View File

@@ -0,0 +1,157 @@
use ruma::{OwnedUserId, events::room::power_levels::RoomPowerLevelsEventContent};
use crate::{
Event, Pdu, RoomVersion,
state_res::{Error, event_auth::context::UserPower},
};
/// Verifies that a m.room.power_levels event is well-formed according to the
/// Matrix specification.
///
/// Creators must contain the m.room.create sender and any additional creators.
///
/// `current_power_levels` is the content of the previous m.room.power_levels
/// state event, if any; when absent the event is allowed outright.
pub async fn check_power_levels(
    room_version: &RoomVersion,
    event: &Pdu,
    current_power_levels: Option<&RoomPowerLevelsEventContent>,
    creators: Vec<OwnedUserId>,
) -> Result<(), Error> {
    let content = event
        .get_content::<RoomPowerLevelsEventContent>()
        .map_err(|e| {
            Error::InvalidPdu(format!("m.room.power_levels event has invalid content: {}", e))
        })?;
    // If any of the properties users_default, events_default, state_default, ban,
    // redact, kick, or invite in content are present and not an integer, reject.
    //
    // If either of the properties events or notifications in content are present
    // and not an object with values that are integers, reject.
    //
    // NOTE: Deserialisation fails if this is not the case, so we don't need to
    // check these here.

    // If the users property in content is not an object with keys that are valid
    // user IDs with values that are integers (or a string that is an integer),
    // reject.
    //
    // BUGFIX: this was `while let Some(user_id) = content.users.keys().next()`,
    // which builds a fresh iterator every pass and never advances — an infinite
    // loop for any non-rejected key. A `for` loop visits each key exactly once.
    for user_id in content.users.keys() {
        // NOTE: Deserialisation fails if the power level is not an integer, so we
        // don't need to check that here.
        if let Err(e) = user_id.validate_historical() {
            return Err(Error::InvalidPdu(format!(
                "m.room.power_levels event has invalid user ID in users map: {}",
                e
            )));
        }
        // Since v12: if the users property in content contains the sender of the
        // m.room.create event or any of the additional_creators array (if
        // present) from the content of the m.room.create event, reject.
        if room_version.explicitly_privilege_room_creators && creators.contains(user_id) {
            return Err(Error::InvalidPdu(
                "m.room.power_levels event users map contains a room creator".to_string(),
            ));
        }
    }
    // If there is no previous m.room.power_levels event in the room, allow.
    let Some(current_power_levels) = current_power_levels else {
        return Ok(());
    };
    // For the properties users_default, events_default, state_default, ban, redact,
    // kick, invite check if they were added, changed or removed. For each found
    // alteration:
    // If the current value is higher than the sender's current power level, reject.
    // If the new value is higher than the sender's current power level, reject.
    let sender = event.sender();
    // In room versions that explicitly privilege creators (v12+), creators are
    // exempt from the level-comparison checks below.
    let rank = if room_version.explicitly_privilege_room_creators
        && creators.contains(&sender.to_owned())
    {
        UserPower::Creator
    } else {
        UserPower::Standard
    };
    let sender_pl = current_power_levels
        .users
        .get(sender)
        .unwrap_or(&current_power_levels.users_default);
    if rank != UserPower::Creator {
        let checks = [
            ("users_default", current_power_levels.users_default, content.users_default),
            ("events_default", current_power_levels.events_default, content.events_default),
            ("state_default", current_power_levels.state_default, content.state_default),
            ("ban", current_power_levels.ban, content.ban),
            ("redact", current_power_levels.redact, content.redact),
            ("kick", current_power_levels.kick, content.kick),
            ("invite", current_power_levels.invite, content.invite),
        ];
        for (name, old_value, new_value) in checks.iter() {
            if old_value != new_value {
                if *old_value > *sender_pl {
                    return Err(Error::AuthConditionFailed(format!(
                        "sender cannot change level for {}",
                        name
                    )));
                }
                if *new_value > *sender_pl {
                    return Err(Error::AuthConditionFailed(format!(
                        "sender cannot raise level for {} to {}",
                        name, new_value
                    )));
                }
            }
        }
        // For each entry being changed in, or removed from, the events
        // property:
        // If the current value is greater than the sender's current power level,
        // reject.
        for (event_type, new_value) in content.events.iter() {
            let old_value = current_power_levels.events.get(event_type);
            if old_value != Some(new_value) {
                let old_pl = old_value.unwrap_or(&current_power_levels.events_default);
                if *old_pl > *sender_pl {
                    return Err(Error::AuthConditionFailed(format!(
                        "sender cannot change event level for {}",
                        event_type
                    )));
                }
                if *new_value > *sender_pl {
                    return Err(Error::AuthConditionFailed(format!(
                        "sender cannot raise event level for {} to {}",
                        event_type, new_value
                    )));
                }
            }
        }
        // For each entry being changed in, or removed from, the events or
        // notifications properties:
        // If the current value is greater than the sender's current power
        // level, reject.
        // If the new value is greater than the sender's current power level,
        // reject.
        // TODO after making ruwuma's notifications value a BTreeMap
        // For each entry being added to, or changed in, the users property:
        // If the new value is greater than the sender's current power level, reject.
        //
        // TODO(review): the spec additionally rejects changing/removing *other
        // users'* entries whose current value is >= the sender's power level;
        // entries removed from `users` are never visited by this loop — confirm
        // and implement.
        for (user_id, new_value) in content.users.iter() {
            let old_value = current_power_levels.users.get(user_id);
            if old_value != Some(new_value) && *new_value > *sender_pl {
                return Err(Error::AuthConditionFailed(format!(
                    "sender cannot raise user level for {} to {}",
                    user_id, new_value
                )));
            }
        }
    }
    Ok(())
}

View File

@@ -8,9 +8,6 @@
#[cfg(test)]
mod test_utils;
#[cfg(test)]
mod benches;
use std::{
borrow::Borrow,
cmp::{Ordering, Reverse},
@@ -18,30 +15,31 @@
hash::{BuildHasher, Hash},
};
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future};
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use itertools::Itertools;
use ruma::{
EventId, Int, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
events::{
StateEventType, TimelineEventType,
room::member::{MembershipState, RoomMemberEventContent},
},
int,
room::member::{MembershipState, RoomMemberEventContent}, StateEventType,
TimelineEventType,
}, int, EventId, Int, MilliSecondsSinceUnixEpoch,
OwnedEventId,
RoomVersionId,
};
use serde_json::from_str as from_json_str;
pub(crate) use self::error::Error;
use self::power_levels::PowerLevelsContentFields;
pub use self::{
event_auth::{auth_check, auth_types_for_event},
room_version::RoomVersion,
};
pub use self::{event_auth::iterative_auth_checks::auth_check, room_version::RoomVersion};
use crate::utils::TryFutureExtExt;
use crate::{
debug, debug_error, err,
matrix::{Event, StateKey},
state_res::room_version::StateResolutionVersion,
debug, err, error as log_error, matrix::{Event, StateKey},
state_res::{
event_auth::auth_events::auth_types_for_event, room_version::StateResolutionVersion,
},
trace,
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, WidebandExt},
warn,
Pdu,
};
/// A mapping of event type and state_key to some value `T`, usually an
@@ -75,23 +73,20 @@
/// event is part of the same room.
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
//#[tracing::instrument(level event_fetch))]
pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
pub async fn resolve<'a, Sets, SetIter, Hasher, FE, FR, Exists>(
room_version: &RoomVersionId,
state_sets: Sets,
auth_chain_sets: &'a [HashSet<OwnedEventId, Hasher>],
event_fetch: &Fetch,
event_fetch: &FE,
event_exists: &Exists,
) -> Result<StateMap<OwnedEventId>>
where
Fetch: Fn(OwnedEventId) -> FetchFut + Sync,
FetchFut: Future<Output = Option<Pdu>> + Send,
Exists: Fn(OwnedEventId) -> ExistsFut + Sync,
ExistsFut: Future<Output = bool> + Send,
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
Exists: AsyncFn(OwnedEventId) -> bool + Sync,
Sets: IntoIterator<IntoIter = SetIter> + Send,
SetIter: Iterator<Item = &'a StateMap<OwnedEventId>> + Clone + Send,
Hasher: BuildHasher + Send + Sync,
Pdu: Event + Clone + Send + Sync,
for<'b> &'b Pdu: Event + Send,
{
use RoomVersionId::*;
let stateres_version = match room_version {
@@ -169,7 +164,7 @@ pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, Ex
// Sequentially auth check each control event.
let resolved_control = iterative_auth_check(
&room_version,
sorted_control_levels.iter().stream().map(AsRef::as_ref),
sorted_control_levels.iter().stream().map(ToOwned::to_owned),
initial_state,
&event_fetch,
)
@@ -209,7 +204,7 @@ pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, Ex
let mut resolved_state = iterative_auth_check(
&room_version,
sorted_left_events.iter().stream().map(AsRef::as_ref),
sorted_left_events.iter().stream(),
resolved_control, // The control events are added to the final resolved state
&event_fetch,
)
@@ -273,14 +268,12 @@ fn separate<'a, Id>(
}
/// Calculate the conflicted subgraph
async fn calculate_conflicted_subgraph<F, Fut, E>(
async fn calculate_conflicted_subgraph<FE>(
conflicted: &StateMap<Vec<OwnedEventId>>,
fetch_event: &F,
fetch_event: &FE,
) -> Option<HashSet<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
{
let conflicted_events: HashSet<_> = conflicted.values().flatten().cloned().collect();
let mut subgraph: HashSet<OwnedEventId> = HashSet::new();
@@ -312,7 +305,17 @@ async fn calculate_conflicted_subgraph<F, Fut, E>(
continue;
}
trace!(event_id = event_id.as_str(), "fetching event for its auth events");
let evt = fetch_event(event_id.clone()).await;
let evt = fetch_event(event_id.clone())
.await
.inspect_err(|e| {
log_error!(
"error fetching event {} for conflicted state subgraph: {}",
event_id,
e
)
})
.ok()
.flatten();
if evt.is_none() {
err!("could not fetch event {} to calculate conflicted subgraph", event_id);
path.pop();
@@ -359,15 +362,14 @@ fn get_auth_chain_diff<Id, Hasher>(
/// The power level is negative because a higher power level is equated to an
/// earlier (further back in time) origin server timestamp.
#[tracing::instrument(level = "debug", skip_all)]
async fn reverse_topological_power_sort<E, F, Fut>(
async fn reverse_topological_power_sort<FE, FR>(
events_to_sort: Vec<OwnedEventId>,
auth_diff: &HashSet<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<Vec<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
{
debug!("reverse topological sort of power events");
@@ -404,8 +406,8 @@ async fn reverse_topological_power_sort<E, F, Fut>(
.get(&event_id)
.ok_or_else(|| Error::NotFound(String::new()))?;
let ev = fetch_event(event_id)
.await
let ev = fetch_event(&event_id)
.await?
.ok_or_else(|| Error::NotFound(String::new()))?;
Ok((pl, ev.origin_server_ts()))
@@ -544,18 +546,14 @@ fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other))
/// Do NOT use this any where but topological sort, we find the power level for
/// the eventId at the eventId's generation (we walk backwards to `EventId`s
/// most recent previous power level event).
async fn get_power_level_for_sender<E, F, Fut>(
event_id: &EventId,
fetch_event: &F,
) -> serde_json::Result<Int>
async fn get_power_level_for_sender<FE, FR>(event_id: &EventId, fetch_event: &FE) -> Result<Int>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
{
debug!("fetch event ({event_id}) senders power level");
let event = fetch_event(event_id.to_owned()).await;
let event = fetch_event(event_id).await?;
let auth_events = event.as_ref().map(Event::auth_events);
@@ -563,7 +561,7 @@ async fn get_power_level_for_sender<E, F, Fut>(
.into_iter()
.flatten()
.stream()
.broadn_filter_map(5, |aid| fetch_event(aid.to_owned()))
.broad_filter_map(|aid| fetch_event(aid).unwrap_or_default())
.ready_find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, ""))
.await;
@@ -594,27 +592,24 @@ async fn get_power_level_for_sender<E, F, Fut>(
/// the the `fetch_event` closure and verify each event using the
/// `event_auth::auth_check` function.
#[tracing::instrument(level = "trace", skip_all)]
async fn iterative_auth_check<'a, E, F, Fut, S>(
async fn iterative_auth_check<FE, FR, S>(
room_version: &RoomVersion,
events_to_check: S,
unconflicted_state: StateMap<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<StateMap<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
S: Stream<Item = &'a EventId> + Send + 'a,
E: Event + Clone + Send + Sync,
for<'b> &'b E: Event + Send,
FE: Fn(&EventId) -> FR,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send + Sync,
S: Stream<Item = OwnedEventId> + Send,
{
debug!("starting iterative auth check");
let events_to_check: Vec<_> = events_to_check
.map(Result::Ok)
.broad_and_then(async |event_id| {
fetch_event(event_id.to_owned())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))
.map(Ok::<OwnedEventId, Error>)
.broad_and_then(async |event_id| match fetch_event(&event_id).await {
| Ok(Some(e)) => Ok(e),
| _ => Err(Error::NotFound(format!("could not find {event_id}")))?,
})
.try_collect()
.boxed()
@@ -627,16 +622,20 @@ async fn iterative_auth_check<'a, E, F, Fut, S>(
let auth_event_ids: HashSet<OwnedEventId> = events_to_check
.iter()
.flat_map(|event: &E| event.auth_events().map(ToOwned::to_owned))
.flat_map(|event: &Pdu| event.auth_events().map(ToOwned::to_owned))
.collect();
trace!(set = ?auth_event_ids, "auth event IDs to fetch");
let auth_events: HashMap<OwnedEventId, E> = auth_event_ids
let auth_events: HashMap<OwnedEventId, Pdu> = auth_event_ids
.into_iter()
.stream()
.broad_filter_map(fetch_event)
.map(|auth_event| (auth_event.event_id().to_owned(), auth_event))
.broad_filter_map(async |event_id| {
fetch_event(&event_id)
.await
.map(|ev_opt| ev_opt.map(|ev| (event_id.clone(), ev)))
.unwrap_or_default()
})
.collect()
.boxed()
.await;
@@ -655,29 +654,23 @@ async fn iterative_auth_check<'a, E, F, Fut, S>(
.state_key()
.ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
let member_event_content = match event.kind() {
| TimelineEventType::RoomMember =>
Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
})?),
| _ => None,
};
let auth_types = auth_types_for_event(
event.event_type(),
event.sender(),
Some(state_key),
event.content(),
room_version,
event.kind(),
event.state_key().map(StateKey::from_str).as_ref(),
event.sender(),
member_event_content,
)?;
trace!(list = ?auth_types, event_id = event.event_id().as_str(), "auth types for event");
let mut auth_state = StateMap::new();
if room_version.room_ids_as_hashes {
trace!("room version uses hashed IDs, manually fetching create event");
let create_event_id_raw = event.room_id_or_hash().as_str().replace('!', "$");
let create_event_id = EventId::parse(&create_event_id_raw).map_err(|e| {
Error::InvalidPdu(format!(
"Failed to parse create event ID from room ID/hash: {e}"
))
})?;
let create_event = fetch_event(create_event_id.into())
.await
.ok_or_else(|| Error::NotFound("Failed to find create event".into()))?;
auth_state.insert(create_event.event_type().with_state_key(""), create_event);
}
let mut auth_state = StateMap::with_capacity(event.auth_events.len());
for aid in event.auth_events() {
if let Some(ev) = auth_events.get(aid) {
//TODO: synapse checks "rejected_reason" which is most likely related to
@@ -703,7 +696,13 @@ async fn iterative_auth_check<'a, E, F, Fut, S>(
if let Some(event) = auth_events.get(ev_id) {
Some((key, event.clone()))
} else {
Some((key, fetch_event(ev_id.clone()).await?))
match fetch_event(ev_id).await {
| Ok(Some(event)) => Some((key, event)),
| _ => {
warn!(event_id = ev_id.as_str(), "unable to fetch auth event");
None
},
}
}
})
.ready_for_each(|(key, event)| {
@@ -715,30 +714,16 @@ async fn iterative_auth_check<'a, E, F, Fut, S>(
debug!(event_id = event.event_id().as_str(), "Running auth checks");
// The key for this is (eventType + a state_key of the signed token not sender)
// so search for it
let current_third_party = auth_state.iter().find_map(|(_, pdu)| {
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
});
let fetch_state = |ty: &StateEventType, key: &str| {
future::ready(
auth_state
.get(&ty.with_state_key(key))
.map(ToOwned::to_owned),
)
let fetch_state = async |t: (StateEventType, StateKey)| {
Ok(auth_state
.get(&t.0.with_state_key(t.1.as_str()))
.map(ToOwned::to_owned))
};
let auth_result = auth_check(
room_version,
&event,
current_third_party,
fetch_state,
&fetch_state(&StateEventType::RoomCreate, "")
.await
.expect("create event must exist"),
)
.await;
let create_event = fetch_state((StateEventType::RoomCreate, StateKey::new())).await?;
let auth_result =
auth_check(room_version, &event, fetch_event, &fetch_state, create_event.as_ref())
.await;
match auth_result {
| Ok(true) => {
@@ -758,7 +743,7 @@ async fn iterative_auth_check<'a, E, F, Fut, S>(
warn!("event {} failed the authentication check", event.event_id());
},
| Err(e) => {
debug_error!("event {} failed the authentication check: {e}", event.event_id());
log_error!("event {} failed the authentication check: {e}", event.event_id());
return Err(e);
},
}
@@ -777,15 +762,14 @@ async fn iterative_auth_check<'a, E, F, Fut, S>(
/// after the most recent are depth 0, the events before (with the first power
/// level as a parent) will be marked as depth 1. depth 1 is "older" than depth
/// 0.
async fn mainline_sort<E, F, Fut>(
async fn mainline_sort<FE, FR>(
to_sort: &[OwnedEventId],
resolved_power_level: Option<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<Vec<OwnedEventId>>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Clone + Send + Sync,
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
{
debug!("mainline sort of events");
@@ -799,14 +783,14 @@ async fn mainline_sort<E, F, Fut>(
while let Some(p) = pl {
mainline.push(p.clone());
let event = fetch_event(p.clone())
.await
let event = fetch_event(&p)
.await?
.ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
pl = None;
for aid in event.auth_events() {
let ev = fetch_event(aid.to_owned())
.await
let ev = fetch_event(aid)
.await?
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
@@ -827,7 +811,11 @@ async fn mainline_sort<E, F, Fut>(
.iter()
.stream()
.broad_filter_map(async |ev_id| {
fetch_event(ev_id.clone()).await.map(|event| (event, ev_id))
fetch_event(ev_id)
.await
.ok()
.flatten()
.map(|event| (event, ev_id))
})
.broad_filter_map(|(event, ev_id)| {
get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event)
@@ -849,15 +837,14 @@ async fn mainline_sort<E, F, Fut>(
/// Get the mainline depth from the `mainline_map` or finds a power_level event
/// that has an associated mainline depth.
async fn get_mainline_depth<E, F, Fut>(
mut event: Option<E>,
async fn get_mainline_depth<FE, FR>(
mut event: Option<Pdu>,
mainline_map: &HashMap<OwnedEventId, usize>,
fetch_event: &F,
fetch_event: &FE,
) -> Result<usize>
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
{
while let Some(sort_ev) = event {
debug!(event_id = sort_ev.event_id().as_str(), "mainline");
@@ -869,8 +856,8 @@ async fn get_mainline_depth<E, F, Fut>(
event = None;
for aid in sort_ev.auth_events() {
let aev = fetch_event(aid.to_owned())
.await
let aev = fetch_event(aid)
.await?
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
@@ -883,20 +870,19 @@ async fn get_mainline_depth<E, F, Fut>(
Ok(0)
}
async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
async fn add_event_and_auth_chain_to_graph<FE, FR>(
graph: &mut HashMap<OwnedEventId, HashSet<OwnedEventId>>,
event_id: OwnedEventId,
auth_diff: &HashSet<OwnedEventId>,
fetch_event: &F,
fetch_event: &FE,
) where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
{
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default();
let event = fetch_event(eid.clone()).await;
let event = fetch_event(&eid).await.ok().flatten();
let auth_events = event.as_ref().map(Event::auth_events).into_iter().flatten();
// Prefer the store to event as the store filters dedups the events
@@ -915,14 +901,13 @@ async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
}
}
async fn is_power_event_id<E, F, Fut>(event_id: &EventId, fetch: &F) -> bool
async fn is_power_event_id<FE, FR>(event_id: &EventId, fetch: &FE) -> bool
where
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
{
match fetch(event_id.to_owned()).await.as_ref() {
| Some(state) => is_power_event(state),
match fetch(event_id).await.as_ref() {
| Ok(Some(state)) => is_power_event(state),
| _ => false,
}
}
@@ -979,26 +964,27 @@ fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, Stat
mod tests {
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use maplit::{hashmap, hashset};
use rand::seq::SliceRandom;
use ruma::{
MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
events::{
StateEventType, TimelineEventType,
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
},
int, uint,
room::join_rules::{JoinRule, RoomJoinRulesEventContent}, StateEventType,
TimelineEventType,
}, int, uint,
MilliSecondsSinceUnixEpoch,
OwnedEventId, RoomVersionId,
};
use serde_json::{json, value::to_raw_value as to_raw_json_value};
use super::{
StateMap, is_power_event,
room_version::RoomVersion,
is_power_event, room_version::RoomVersion,
test_utils::{
INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id,
member_content_ban, member_content_join, room_id, to_init_pdu_event, to_pdu_event,
zara,
alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join,
room_id, to_init_pdu_event, to_pdu_event, zara, TestStore,
INITIAL_EVENTS,
},
StateMap,
};
use crate::{
debug,
@@ -1028,13 +1014,13 @@ async fn test_event_sort() {
.map(|pdu| pdu.event_id.clone())
.collect::<Vec<_>>();
let fetcher = |id| ready(events.get(&id).cloned());
let fetcher = |id| ready(Ok(events.get(id).cloned()));
let sorted_power_events =
super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher)
.await
.unwrap();
let resolved_power = super::iterative_auth_check(
let resolved_power = super::auth_check(
&RoomVersion::V6,
sorted_power_events.iter().map(AsRef::as_ref).stream(),
HashMap::new(), // unconflicted events
@@ -1046,7 +1032,7 @@ async fn test_event_sort() {
// don't remove any events so we know it sorts them all correctly
let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
events_to_sort.shuffle(&mut rand::rng());
events_to_sort.shuffle(&mut rand::thread_rng());
let power_level = resolved_power
.get(&(StateEventType::RoomPowerLevels, "".into()))

View File

@@ -28,7 +28,7 @@ fn init_argon() -> Argon2<'static> {
}
pub(super) fn password(password: &str) -> Result<String> {
let salt = SaltString::generate(rand_core::OsRng);
let salt = SaltString::generate(rand::thread_rng());
ARGON
.get_or_init(init_argon)
.hash_password(password.as_bytes(), &salt)

View File

@@ -11,7 +11,6 @@
pub mod math;
pub mod mutex_map;
pub mod rand;
pub mod response;
pub mod result;
pub mod set;
pub mod stream;

View File

@@ -4,16 +4,16 @@
};
use arrayvec::ArrayString;
use rand::{RngExt, seq::SliceRandom};
use rand::{Rng, seq::SliceRandom, thread_rng};
pub fn shuffle<T>(vec: &mut [T]) {
let mut rng = rand::rng();
let mut rng = thread_rng();
vec.shuffle(&mut rng);
}
pub fn string(length: usize) -> String {
rand::rng()
.sample_iter(&rand::distr::Alphanumeric)
thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(length)
.map(char::from)
.collect()
@@ -22,8 +22,8 @@ pub fn string(length: usize) -> String {
#[inline]
pub fn string_array<const LENGTH: usize>() -> ArrayString<LENGTH> {
let mut ret = ArrayString::<LENGTH>::new();
rand::rng()
.sample_iter(&rand::distr::Alphanumeric)
thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(LENGTH)
.map(char::from)
.for_each(|c| ret.push(c));
@@ -40,4 +40,7 @@ pub fn time_from_now_secs(range: Range<u64>) -> SystemTime {
}
#[must_use]
pub fn secs(range: Range<u64>) -> Duration { Duration::from_secs(rand::random_range(range)) }
pub fn secs(range: Range<u64>) -> Duration {
let mut rng = thread_rng();
Duration::from_secs(rng.gen_range(range))
}

View File

@@ -1,51 +0,0 @@
use futures::StreamExt;
use num_traits::ToPrimitive;
use crate::Err;
/// Reads the response body while enforcing a maximum size limit to prevent
/// memory exhaustion.
pub async fn limit_read(response: reqwest::Response, max_size: u64) -> crate::Result<Vec<u8>> {
if response.content_length().is_some_and(|len| len > max_size) {
return Err!(BadServerResponse("Response too large"));
}
let mut data = Vec::new();
let mut reader = response.bytes_stream();
while let Some(chunk) = reader.next().await {
let chunk = chunk?;
data.extend_from_slice(&chunk);
if data.len() > max_size.to_usize().expect("max_size must fit in usize") {
return Err!(BadServerResponse("Response too large"));
}
}
Ok(data)
}
/// Reads the response body as text while enforcing a maximum size limit to
/// prevent memory exhaustion.
pub async fn limit_read_text(
response: reqwest::Response,
max_size: u64,
) -> crate::Result<String> {
let text = String::from_utf8(limit_read(response, max_size).await?)?;
Ok(text)
}
#[allow(async_fn_in_trait)]
pub trait LimitReadExt {
async fn limit_read(self, max_size: u64) -> crate::Result<Vec<u8>>;
async fn limit_read_text(self, max_size: u64) -> crate::Result<String>;
}
impl LimitReadExt for reqwest::Response {
async fn limit_read(self, max_size: u64) -> crate::Result<Vec<u8>> {
limit_read(self, max_size).await
}
async fn limit_read_text(self, max_size: u64) -> crate::Result<String> {
limit_read_text(self, max_size).await
}
}

View File

@@ -3,17 +3,19 @@
stream::{Stream, TryStream},
};
use crate::{Error, Result};
pub trait IterStream<I: IntoIterator + Send> {
/// Convert an Iterator into a Stream
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send;
/// Convert an Iterator into a TryStream with a generic error type
fn try_stream<E>(
/// Convert an Iterator into a TryStream
fn try_stream(
self,
) -> impl TryStream<
Ok = <I as IntoIterator>::Item,
Error = E,
Item = Result<<I as IntoIterator>::Item, E>,
Error = Error,
Item = Result<<I as IntoIterator>::Item, Error>,
> + Send;
}
@@ -26,12 +28,12 @@ impl<I> IterStream<I> for I
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send { stream::iter(self) }
#[inline]
fn try_stream<E>(
fn try_stream(
self,
) -> impl TryStream<
Ok = <I as IntoIterator>::Item,
Error = E,
Item = Result<<I as IntoIterator>::Item, E>,
Error = Error,
Item = Result<<I as IntoIterator>::Item, Error>,
> + Send {
self.stream().map(Ok)
}

View File

@@ -1,10 +1,9 @@
//! Synchronous combinator extensions to futures::TryStream
use std::result::Result;
use futures::{TryFuture, TryStream, TryStreamExt};
use super::automatic_width;
use crate::Result;
/// Concurrency extensions to augment futures::TryStreamExt. broad_ combinators
/// produce out-of-order

View File

@@ -362,10 +362,6 @@ pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
name: "userid_blurhash",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_dehydrateddevice",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_devicelistversion",
..descriptor::RANDOM_SMALL

View File

@@ -530,12 +530,7 @@ async fn handle_response_error(
Ok(())
}
pub async fn is_admin_command<E>(
&self,
event: &E,
body: &str,
sent_locally: bool,
) -> Option<InvocationSource>
pub async fn is_admin_command<E>(&self, event: &E, body: &str) -> Option<InvocationSource>
where
E: Event + Send + Sync,
{
@@ -585,15 +580,6 @@ pub async fn is_admin_command<E>(
return None;
}
// Escaped commands must be sent locally (via client API), not via federation
if !sent_locally {
conduwuit::warn!(
"Ignoring escaped admin command from {} that arrived via federation",
event.sender()
);
return None;
}
// Looks good
Some(InvocationSource::EscapedCommand)
}

View File

@@ -18,8 +18,9 @@
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use conduwuit::{Result, Server, debug, error, utils::response::LimitReadExt, warn};
use conduwuit::{Result, Server, debug, error, warn};
use database::{Deserialized, Map};
use rand::Rng;
use ruma::events::{Mentions, room::message::RoomMessageEventContent};
use serde::Deserialize;
use tokio::{
@@ -99,7 +100,8 @@ async fn worker(self: Arc<Self>) -> Result<()> {
}
let first_check_jitter = {
let jitter_percent = rand::random_range(-50.0..=10.0);
let mut rng = rand::thread_rng();
let jitter_percent = rng.gen_range(-50.0..=10.0);
self.interval.mul_f64(1.0 + jitter_percent / 100.0)
};
@@ -137,7 +139,7 @@ async fn check(&self) -> Result<()> {
.get(CHECK_FOR_ANNOUNCEMENTS_URL)
.send()
.await?
.limit_read_text(1024 * 1024)
.text()
.await?;
let response = serde_json::from_str::<CheckForAnnouncementsResponse>(&response)?;

View File

@@ -39,7 +39,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let url_preview_user_agent = config
.url_preview_user_agent
.clone()
.unwrap_or_else(|| conduwuit::version::user_agent_media().to_owned());
.unwrap_or_else(|| conduwuit::version::user_agent().to_owned());
Ok(Arc::new(Self {
default: base(config)?

View File

@@ -2,8 +2,8 @@
use bytes::Bytes;
use conduwuit::{
Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, implement,
trace, utils::response::LimitReadExt,
Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err,
error::inspect_debug_log, implement, trace,
};
use http::{HeaderValue, header::AUTHORIZATION};
use ipaddress::IPAddress;
@@ -133,22 +133,7 @@ async fn handle_response<T>(
where
T: OutgoingRequest + Send,
{
const HUGE_ENDPOINTS: [&str; 2] =
["/_matrix/federation/v2/send_join/", "/_matrix/federation/v2/state/"];
let size_limit: u64 = if HUGE_ENDPOINTS.iter().any(|e| url.path().starts_with(e)) {
// Some federation endpoints can return huge response bodies, so we'll bump the
// limit for those endpoints specifically.
self.services
.server
.config
.max_request_size
.saturating_mul(10)
} else {
self.services.server.config.max_request_size
}
.try_into()
.expect("size_limit (usize) should fit within a u64");
let response = into_http_response(dest, actual, method, url, response, size_limit).await?;
let response = into_http_response(dest, actual, method, url, response).await?;
T::IncomingResponse::try_from_http_response(response)
.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
@@ -160,7 +145,6 @@ async fn into_http_response(
method: &Method,
url: &Url,
mut response: Response,
max_size: u64,
) -> Result<http::Response<Bytes>> {
let status = response.status();
trace!(
@@ -183,14 +167,14 @@ async fn into_http_response(
);
trace!("Waiting for response body...");
let body = response
.bytes()
.await
.inspect_err(inspect_debug_log)
.unwrap_or_else(|_| Vec::new().into());
let http_response = http_response_builder
.body(
response
.limit_read(max_size)
.await
.unwrap_or_default()
.into(),
)
.body(body)
.expect("reqwest body is valid http body");
debug!("Got {status:?} for {method} {url}");

View File

@@ -170,8 +170,6 @@ pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> {
Ok(())
}
pub(super) async fn clear_url_previews(&self) { self.url_previews.clear().await; }
pub(super) fn set_url_preview(
&self,
url: &str,

View File

@@ -7,7 +7,7 @@
use std::time::SystemTime;
use conduwuit::{Err, Result, debug, err, utils::response::LimitReadExt};
use conduwuit::{Err, Result, debug, err};
use conduwuit_core::implement;
use ipaddress::IPAddress;
use serde::Serialize;
@@ -37,9 +37,6 @@ pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
self.db.remove_url_preview(url)
}
#[implement(Service)]
pub async fn clear_url_previews(&self) { self.db.clear_url_previews().await; }
#[implement(Service)]
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
let now = SystemTime::now()
@@ -112,22 +109,8 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
use image::ImageReader;
use ruma::Mxc;
let image = self
.services
.client
.url_preview
.get(url)
.send()
.await?
.limit_read(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.await?;
let image = self.services.client.url_preview.get(url).send().await?;
let image = image.bytes().await?;
let mxc = Mxc {
server_name: self.services.globals.server_name(),
media_id: &random_string(super::MXC_LENGTH),
@@ -165,20 +148,24 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
use webpage::HTML;
let client = &self.services.client.url_preview;
let body = client
.get(url)
.send()
.await?
.limit_read_text(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.await?;
let Ok(html) = HTML::from_string(body.clone(), Some(url.to_owned())) else {
let mut response = client.get(url).send().await?;
let mut bytes: Vec<u8> = Vec::new();
while let Some(chunk) = response.chunk().await? {
bytes.extend_from_slice(&chunk);
if bytes.len() > self.services.globals.url_preview_max_spider_size() {
debug!(
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not \
processing the rest of the response body and assuming our necessary data is in \
this range.",
url,
self.services.globals.url_preview_max_spider_size()
);
break;
}
}
let body = String::from_utf8_lossy(&bytes);
let Ok(html) = HTML::from_string(body.to_string(), Some(url.to_owned())) else {
return Err!(Request(Unknown("Failed to parse HTML")));
};

View File

@@ -2,7 +2,7 @@
use conduwuit::{
Err, Error, Result, debug_warn, err, implement,
utils::{content_disposition::make_content_disposition, response::LimitReadExt},
utils::content_disposition::make_content_disposition,
};
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE, HeaderValue};
use ruma::{
@@ -286,15 +286,10 @@ async fn location_request(&self, location: &str) -> Result<FileMeta> {
.and_then(Result::ok);
response
.limit_read(
self.services
.server
.config
.max_request_size
.try_into()
.expect("u64 should fit in usize"),
)
.bytes()
.await
.map(Vec::from)
.map_err(Into::into)
.map(|content| FileMeta {
content: Some(content),
content_type: content_type.clone(),

View File

@@ -31,7 +31,7 @@
pub mod sending;
pub mod server_keys;
pub mod sync;
pub mod transactions;
pub mod transaction_ids;
pub mod uiaa;
pub mod users;

View File

@@ -1,7 +1,6 @@
use std::{fmt::Debug, mem, sync::Arc};
use bytes::BytesMut;
use conduwuit::utils::response::LimitReadExt;
use conduwuit_core::{
Err, Event, Result, debug_warn, err, trace,
utils::{stream::TryIgnore, string_from_bytes},
@@ -31,7 +30,7 @@
uint,
};
use crate::{Dep, client, config, globals, rooms, sending, users};
use crate::{Dep, client, globals, rooms, sending, users};
pub struct Service {
db: Data,
@@ -40,7 +39,6 @@ pub struct Service {
struct Services {
globals: Dep<globals::Service>,
config: Dep<config::Service>,
client: Dep<client::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
@@ -63,7 +61,6 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
services: Services {
globals: args.depend::<globals::Service>("globals"),
client: args.depend::<client::Service>("client"),
config: args.depend::<config::Service>("config"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
@@ -248,15 +245,7 @@ pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::Incomin
.expect("http::response::Builder is usable"),
);
let body = response
.limit_read(
self.services
.config
.max_request_size
.try_into()
.expect("usize fits into u64"),
)
.await?;
let body = response.bytes().await?;
if !status.is_success() {
debug_warn!("Push gateway response body: {:?}", string_from_bytes(&body));

View File

@@ -1,6 +1,4 @@
use conduwuit::{
Result, debug, debug_error, debug_info, implement, trace, utils::response::LimitReadExt,
};
use conduwuit::{Result, debug, debug_error, debug_info, debug_warn, implement, trace};
#[implement(super::Service)]
#[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))]
@@ -26,8 +24,12 @@ pub(super) async fn request_well_known(&self, dest: &str) -> Result<Option<Strin
return Ok(None);
}
let text = response.limit_read_text(8192).await?;
let text = response.text().await?;
trace!("response text: {text:?}");
if text.len() >= 12288 {
debug_warn!("response contains junk");
return Ok(None);
}
let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();

View File

@@ -142,7 +142,7 @@ async fn get_auth_chain_outer(
let chunk_cache: Vec<_> = chunk
.into_iter()
.try_stream::<conduwuit::Error>()
.try_stream()
.broad_and_then(|(shortid, event_id)| async move {
if let Ok(cached) = self.get_cached_eventid_authchain(&[shortid]).await {
return Ok(cached.to_vec());

View File

@@ -63,9 +63,7 @@ pub(super) async fn fetch_state<Pdu>(
},
| hash_map::Entry::Occupied(_) => {
return Err!(Database(
"State event's type and state_key combination exists multiple times: {}, {}",
pdu.kind(),
state_key
"State event's type and state_key combination exists multiple times.",
));
},
}

View File

@@ -162,9 +162,7 @@ pub(super) async fn handle_outlier_pdu<'a, Pdu>(
},
| hash_map::Entry::Occupied(_) => {
return Err!(Request(InvalidParam(
"Auth event's type and state_key combination exists multiple times: {}, {}",
auth_event.kind,
auth_event.state_key().unwrap_or("")
"Auth event's type and state_key combination exists multiple times.",
)));
},
}

View File

@@ -72,26 +72,6 @@ pub async fn append_incoming_pdu<'a, Leaves>(
.append_pdu(pdu, pdu_json, new_room_leaves, state_lock, room_id)
.await?;
// Process admin commands for federation events
if *pdu.kind() == TimelineEventType::RoomMessage {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
if let Some(source) = self
.services
.admin
.is_admin_command(pdu, &body, false)
.await
{
self.services.admin.command_with_sender(
body,
Some(pdu.event_id().into()),
source,
pdu.sender.clone().into(),
)?;
}
}
}
Ok(Some(pdu_id))
}
@@ -354,6 +334,15 @@ pub async fn append_pdu<'a, Leaves>(
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
if let Some(source) = self.services.admin.is_admin_command(pdu, &body).await {
self.services.admin.command_with_sender(
body,
Some((pdu.event_id()).into()),
source,
pdu.sender.clone().into(),
)?;
}
}
},
| _ => {},

View File

@@ -18,7 +18,7 @@
},
};
use super::{ExtractBody, RoomMutexGuard};
use super::RoomMutexGuard;
/// Creates a new persisted data unit and adds it to a room. This function
/// takes a roomid_mutex_state, meaning that only this function is able to
@@ -126,26 +126,6 @@ pub async fn build_and_append_pdu(
.boxed()
.await?;
// Process admin commands for locally sent events
if *pdu.kind() == TimelineEventType::RoomMessage {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
if let Some(source) = self
.services
.admin
.is_admin_command(&pdu, &body, true)
.await
{
self.services.admin.command_with_sender(
body,
Some(pdu.event_id().into()),
source,
pdu.sender.clone().into(),
)?;
}
}
}
// We set the room state after inserting the pdu, so that we never have a moment
// in time where events in the current room state do not exist
trace!("Setting room state for room {room_id}");
@@ -187,8 +167,6 @@ pub async fn build_and_append_pdu(
Ok(pdu.event_id().to_owned())
}
/// Assert invariants about the admin room, to prevent (for example) all admins
/// from leaving or being banned from the room
#[implement(super::Service)]
#[tracing::instrument(skip_all, level = "debug")]
async fn check_pdu_for_admin_room<Pdu>(&self, pdu: &Pdu, sender: &UserId) -> Result

View File

@@ -1,7 +1,7 @@
use std::{fmt::Debug, mem};
use bytes::BytesMut;
use conduwuit::{Err, Result, debug_error, err, utils, utils::response::LimitReadExt, warn};
use conduwuit::{Err, Result, debug_error, err, utils, warn};
use reqwest::Client;
use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
@@ -38,7 +38,7 @@ pub(crate) async fn send_antispam_request<T>(
.expect("http::response::Builder is usable"),
);
let body = response.limit_read(65535).await?; // TODO: handle timeout
let body = response.bytes().await?; // TODO: handle timeout
if !status.is_success() {
debug_error!("Antispam response bytes: {:?}", utils::string_from_bytes(&body));

View File

@@ -1,9 +1,7 @@
use std::{fmt::Debug, mem};
use bytes::BytesMut;
use conduwuit::{
Err, Result, debug_error, err, implement, trace, utils, utils::response::LimitReadExt, warn,
};
use conduwuit::{Err, Result, debug_error, err, implement, trace, utils, warn};
use ruma::api::{
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, appservice::Registration,
};
@@ -79,15 +77,7 @@ pub async fn send_appservice_request<T>(
.expect("http::response::Builder is usable"),
);
let body = response
.limit_read(
self.server
.config
.max_request_size
.try_into()
.expect("usize fits into u64"),
)
.await?;
let body = response.bytes().await?;
if !status.is_success() {
debug_error!("Appservice response bytes: {:?}", utils::string_from_bytes(&body));

View File

@@ -385,13 +385,11 @@ fn num_senders(args: &crate::Args<'_>) -> usize {
const MIN_SENDERS: usize = 1;
// Limit the number of senders to the number of workers threads or number of
// cores, conservatively.
let mut max_senders = args.server.metrics.num_workers();
// Work around some platforms not returning the number of cores.
let num_cores = available_parallelism();
if num_cores > 0 {
max_senders = max_senders.min(num_cores);
}
let max_senders = args
.server
.metrics
.num_workers()
.min(available_parallelism());
// If the user doesn't override the default 0, this is intended to then default
// to 1 for now as multiple senders is experimental.

View File

@@ -10,7 +10,7 @@
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use conduwuit_core::{
Error, Event, Result, at, debug, err, error,
Error, Event, Result, debug, err, error,
result::LogErr,
trace,
utils::{
@@ -175,7 +175,7 @@ async fn handle_response_ok<'a>(
if !new_events.is_empty() {
self.db.mark_as_active(new_events.iter());
let new_events_vec = new_events.into_iter().map(at!(1)).collect();
let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect();
futures.push(self.send_events(dest.clone(), new_events_vec));
} else {
statuses.remove(dest);

View File

@@ -14,7 +14,7 @@
media, moderation, presence, pusher, registration_tokens, resolver, rooms, sending,
server_keys,
service::{self, Args, Map, Service},
sync, transactions, uiaa, users,
sync, transaction_ids, uiaa, users,
};
pub struct Services {
@@ -37,7 +37,7 @@ pub struct Services {
pub sending: Arc<sending::Service>,
pub server_keys: Arc<server_keys::Service>,
pub sync: Arc<sync::Service>,
pub transactions: Arc<transactions::Service>,
pub transaction_ids: Arc<transaction_ids::Service>,
pub uiaa: Arc<uiaa::Service>,
pub users: Arc<users::Service>,
pub moderation: Arc<moderation::Service>,
@@ -110,7 +110,7 @@ macro_rules! build {
sending: build!(sending::Service),
server_keys: build!(server_keys::Service),
sync: build!(sync::Service),
transactions: build!(transactions::Service),
transaction_ids: build!(transaction_ids::Service),
uiaa: build!(uiaa::Service),
users: build!(users::Service),
moderation: build!(moderation::Service),

View File

@@ -0,0 +1,54 @@
use std::sync::Arc;
use conduwuit::{Result, implement};
use database::{Handle, Map};
use ruma::{DeviceId, TransactionId, UserId};
pub struct Service {
db: Data,
}
struct Data {
userdevicetxnid_response: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data {
userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
#[implement(Service)]
pub fn add_txnid(
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
data: &[u8],
) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.db.userdevicetxnid_response.insert(&key, data);
}
// If there's no entry, this is a new transaction
#[implement(Service)]
pub async fn existing_txnid(
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
) -> Result<Handle<'_>> {
let key = (user_id, device_id, txn_id);
self.db.userdevicetxnid_response.qry(&key).await
}

View File

@@ -1,326 +0,0 @@
use std::{
collections::HashMap,
fmt,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::{Duration, SystemTime},
};
use async_trait::async_trait;
use conduwuit::{Error, Result, SyncRwLock, debug_warn, warn};
use database::{Handle, Map};
use ruma::{
DeviceId, OwnedServerName, OwnedTransactionId, TransactionId, UserId,
api::{
client::error::ErrorKind::LimitExceeded,
federation::transactions::send_transaction_message,
},
};
use tokio::sync::watch::{Receiver, Sender};
use crate::{Dep, config};
pub type TxnKey = (OwnedServerName, OwnedTransactionId);
pub type WrappedTransactionResponse =
Option<Result<send_transaction_message::v1::Response, TransactionError>>;
/// Errors that can occur during federation transaction processing.
#[derive(Debug, Clone)]
pub enum TransactionError {
/// Server is shutting down - the sender should retry the entire
/// transaction.
ShuttingDown,
}
impl fmt::Display for TransactionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
| Self::ShuttingDown => write!(f, "Server is shutting down"),
}
}
}
impl std::error::Error for TransactionError {}
/// Minimum interval between cache cleanup runs.
/// Exists to prevent thrashing when the cache is full of things that can't be
/// cleared
const CLEANUP_INTERVAL_SECS: u64 = 30;
#[derive(Clone, Debug)]
pub struct CachedTxnResponse {
pub response: send_transaction_message::v1::Response,
pub created: SystemTime,
}
/// Internal state for a federation transaction.
/// Either actively being processed or completed and cached.
#[derive(Clone)]
enum TxnState {
/// Transaction is currently being processed.
Active(Receiver<WrappedTransactionResponse>),
/// Transaction completed and response is cached.
Cached(CachedTxnResponse),
}
/// Result of atomically checking or starting a federation transaction.
pub enum FederationTxnState {
/// Transaction already completed and cached
Cached(send_transaction_message::v1::Response),
/// Transaction is currently being processed by another request.
/// Wait on this receiver for the result.
Active(Receiver<WrappedTransactionResponse>),
/// This caller should process the transaction (first to request it).
Started {
receiver: Receiver<WrappedTransactionResponse>,
sender: Sender<WrappedTransactionResponse>,
},
}
pub struct Service {
services: Services,
db: Data,
federation_txn_state: Arc<SyncRwLock<HashMap<TxnKey, TxnState>>>,
last_cleanup: AtomicU64,
}
struct Services {
config: Dep<config::Service>,
}
struct Data {
userdevicetxnid_response: Arc<Map>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
config: args.depend::<config::Service>("config"),
},
db: Data {
userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(),
},
federation_txn_state: Arc::new(SyncRwLock::new(HashMap::new())),
last_cleanup: AtomicU64::new(0),
}))
}
async fn clear_cache(&self) {
let mut state = self.federation_txn_state.write();
// Only clear cached entries, preserve active transactions
state.retain(|_, v| matches!(v, TxnState::Active(_)));
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
	/// Returns the count of currently active (in-progress) transactions.
	#[must_use]
	pub fn txn_active_handle_count(&self) -> usize {
		let state = self.federation_txn_state.read();
		state
			.values()
			.filter(|v| matches!(v, TxnState::Active(_)))
			.count()
	}

	/// Stores a client-API response body under (user, device, txn_id) so a
	/// retried request with the same transaction ID can be answered with the
	/// identical response.
	pub fn add_client_txnid(
		&self,
		user_id: &UserId,
		device_id: Option<&DeviceId>,
		txn_id: &TransactionId,
		data: &[u8],
	) {
		// Key layout: user \xFF device \xFF txn_id. A missing device ID
		// becomes an empty segment; the 0xFF separators keep the key
		// unambiguous.
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
		key.push(0xFF);
		key.extend_from_slice(txn_id.as_bytes());
		self.db.userdevicetxnid_response.insert(&key, data);
	}

	/// Fetches the stored client-API response for this (user, device,
	/// txn_id) triple, if one was recorded via `add_client_txnid`.
	pub async fn get_client_txn(
		&self,
		user_id: &UserId,
		device_id: Option<&DeviceId>,
		txn_id: &TransactionId,
	) -> Result<Handle<'_>> {
		let key = (user_id, device_id, txn_id);
		self.db.userdevicetxnid_response.qry(&key).await
	}

	/// Atomically gets a cached response, joins an active transaction, or
	/// starts a new one.
	///
	/// Returns `Err(LimitExceeded)` when the origin already has a different
	/// transaction in flight, or when the server is at its configured
	/// concurrent-transaction limit.
	pub fn get_or_start_federation_txn(&self, key: TxnKey) -> Result<FederationTxnState> {
		// Only one upgradable lock can be held at a time, and there aren't any
		// read-only locks, so no point being upgradable
		let mut state = self.federation_txn_state.write();
		// Check existing state for this key
		if let Some(txn_state) = state.get(&key) {
			return Ok(match txn_state {
				| TxnState::Cached(cached) => FederationTxnState::Cached(cached.response.clone()),
				| TxnState::Active(receiver) => FederationTxnState::Active(receiver.clone()),
			});
		}
		// Check if another transaction from this origin is already running
		// (key.0 is the origin component of the key)
		let has_active_from_origin = state
			.iter()
			.any(|(k, v)| k.0 == key.0 && matches!(v, TxnState::Active(_)));
		if has_active_from_origin {
			debug_warn!(
				origin = ?key.0,
				"Got concurrent transaction request from an origin with an active transaction"
			);
			return Err(Error::BadRequest(
				LimitExceeded { retry_after: None },
				"Still processing another transaction from this origin",
			));
		}
		let max_active_txns = self.services.config.max_concurrent_inbound_transactions;
		// Check if we're at capacity: cheap total-length test first, then the
		// exact count of Active entries (only active ones count toward the cap)
		if state.len() >= max_active_txns
			&& let active_count = state
				.values()
				.filter(|v| matches!(v, TxnState::Active(_)))
				.count() && active_count >= max_active_txns
		{
			warn!(
				active = active_count,
				max = max_active_txns,
				"Server is overloaded, dropping incoming transaction"
			);
			return Err(Error::BadRequest(
				LimitExceeded { retry_after: None },
				"Server is overloaded, try again later",
			));
		}
		// Start new transaction; the stored receiver lets concurrent duplicate
		// requests for this key await the same result
		let (sender, receiver) = tokio::sync::watch::channel(None);
		state.insert(key, TxnState::Active(receiver.clone()));
		Ok(FederationTxnState::Started { receiver, sender })
	}

	/// Finishes a transaction by transitioning it from active to cached state.
	/// Additionally may trigger cleanup of old entries.
	pub fn finish_federation_txn(
		&self,
		key: TxnKey,
		sender: Sender<WrappedTransactionResponse>,
		response: send_transaction_message::v1::Response,
	) {
		// Check if cleanup might be needed before acquiring the lock
		let should_try_cleanup = self.should_try_cleanup();
		let mut state = self.federation_txn_state.write();
		// Explicitly set cached first so there is no gap where receivers get a closed
		// channel
		state.insert(
			key,
			TxnState::Cached(CachedTxnResponse {
				response: response.clone(),
				created: SystemTime::now(),
			}),
		);
		if let Err(e) = sender.send(Some(Ok(response))) {
			debug_warn!("Failed to send transaction response to waiting receivers: {e}");
		}
		// Explicitly close
		drop(sender);
		// This task is dangling, we can try clean caches now
		if should_try_cleanup {
			self.cleanup_entries_locked(&mut state);
		}
	}

	/// Forgets all state (active or cached) for this transaction key.
	pub fn remove_federation_txn(&self, key: &TxnKey) {
		let mut state = self.federation_txn_state.write();
		state.remove(key);
	}

	/// Checks if enough time has passed since the last cleanup to consider
	/// running another. Updates the last cleanup time if returning true.
	fn should_try_cleanup(&self) -> bool {
		let now = SystemTime::now()
			.duration_since(SystemTime::UNIX_EPOCH)
			.expect("SystemTime before UNIX_EPOCH")
			.as_secs();
		let last = self.last_cleanup.load(Ordering::Relaxed);
		if now.saturating_sub(last) >= CLEANUP_INTERVAL_SECS {
			// CAS: only update if no one else has updated it since we read;
			// the loser of a race returns false and skips cleanup
			self.last_cleanup
				.compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
				.is_ok()
		} else {
			false
		}
	}

	/// Cleans up cached entries based on age and count limits.
	///
	/// First removes all cached entries older than the configured max age.
	/// Then, if the cache still exceeds the max entry count, removes the oldest
	/// cached entries until the count is within limits.
	///
	/// Must be called with write lock held on the state map.
	fn cleanup_entries_locked(&self, state: &mut HashMap<TxnKey, TxnState>) {
		let max_age_secs = self.services.config.transaction_id_cache_max_age_secs;
		let max_entries = self.services.config.transaction_id_cache_max_entries;
		// First pass: remove all cached entries older than max age
		let cutoff = SystemTime::now()
			.checked_sub(Duration::from_secs(max_age_secs))
			.unwrap_or(SystemTime::UNIX_EPOCH);
		state.retain(|_, v| match v {
			| TxnState::Active(_) => true, // Never remove active transactions
			| TxnState::Cached(cached) => cached.created > cutoff,
		});
		// Count cached entries
		let cached_count = state
			.values()
			.filter(|v| matches!(v, TxnState::Cached(_)))
			.count();
		// Second pass: if still over max entries, remove oldest cached entries
		if cached_count > max_entries {
			let excess = cached_count.saturating_sub(max_entries);
			// Collect cached entries sorted by age (oldest first)
			let mut cached_entries: Vec<_> = state
				.iter()
				.filter_map(|(k, v)| match v {
					| TxnState::Cached(cached) => Some((k.clone(), cached.created)),
					| TxnState::Active(_) => None,
				})
				.collect();
			cached_entries.sort_by(|a, b| a.1.cmp(&b.1));
			// Remove the oldest cached entries to get under the limit
			for (key, _) in cached_entries.into_iter().take(excess) {
				state.remove(&key);
			}
		}
	}
}

View File

@@ -1,149 +0,0 @@
use conduwuit::{Err, Result, implement, trace};
use conduwuit_database::{Deserialized, Json};
use ruma::{
DeviceId, OwnedDeviceId, UserId,
api::client::dehydrated_device::{
DehydratedDeviceData, put_dehydrated_device::unstable::Request,
},
serde::Raw,
};
use serde::{Deserialize, Serialize};
/// A user's stored dehydrated device record, persisted as JSON in the
/// `userid_dehydrateddevice` table (one per user).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DehydratedDevice {
	/// Unique ID of the device.
	pub device_id: OwnedDeviceId,
	/// Contains serialized and encrypted private data.
	pub device_data: Raw<DehydratedDeviceData>,
}
/// Creates the user's dehydrated device, replacing any previous one.
///
/// Rejects the request when the chosen device ID already belongs to a
/// regular (hydrated) device; otherwise any prior dehydrated device is
/// removed, the new device is registered, its data blob is persisted, and
/// its device/one-time keys are stored.
#[implement(super::Service)]
#[tracing::instrument(
	level = "info",
	skip_all,
	fields(
		%user_id,
		device_id = %request.device_id,
		display_name = ?request.initial_device_display_name,
	)
)]
pub async fn set_dehydrated_device(&self, user_id: &UserId, request: Request) -> Result {
	// Invariant: callers must only pass existing users.
	assert!(
		self.exists(user_id).await,
		"Tried to create dehydrated device for non-existent user"
	);
	match self.get_dehydrated_device_id(user_id).await {
		// No dehydrated device yet: the requested ID must not collide with a
		// regular device.
		| Err(_) =>
			if self
				.get_device_metadata(user_id, &request.device_id)
				.await
				.is_ok()
			{
				return Err!("A hydrated device already exists with that ID.");
			},
		// Replace the previous dehydrated device.
		| Ok(previous) => self.remove_device(user_id, &previous).await,
	}
	self.create_device(
		user_id,
		&request.device_id,
		"",
		request.initial_device_display_name.clone(),
		None,
	)
	.await?;
	trace!(device_data = ?request.device_data);
	// Persist the encrypted device blob under the user.
	self.db.userid_dehydrateddevice.raw_put(
		user_id,
		Json(&DehydratedDevice {
			device_id: request.device_id.clone(),
			device_data: request.device_data,
		}),
	);
	trace!(device_keys = ?request.device_keys);
	self.add_device_keys(user_id, &request.device_id, &request.device_keys)
		.await;
	trace!(one_time_keys = ?request.one_time_keys);
	for (key_id, key) in &request.one_time_keys {
		self.add_one_time_key(user_id, &request.device_id, key_id, key)
			.await?;
	}
	Ok(())
}
/// Removes a user's dehydrated device.
///
/// Calling this directly will remove the dehydrated data but leak the frontage
/// device. Thus this is called by the regular device interface such that the
/// dehydrated data will not leak instead.
///
/// If device_id is given, the user's dehydrated device must match or this is a
/// no-op, but an Err is still returned to indicate that. Otherwise returns the
/// removed dehydrated device_id.
#[implement(super::Service)]
#[tracing::instrument(
level = "debug",
skip_all,
fields(
%user_id,
device_id = ?maybe_device_id,
)
)]
pub(super) async fn remove_dehydrated_device(
&self,
user_id: &UserId,
maybe_device_id: Option<&DeviceId>,
) -> Result<OwnedDeviceId> {
let Ok(device_id) = self.get_dehydrated_device_id(user_id).await else {
return Err!(Request(NotFound("No dehydrated device for this user.")));
};
if let Some(maybe_device_id) = maybe_device_id {
if maybe_device_id != device_id {
return Err!(Request(NotFound("Not the user's dehydrated device.")));
}
}
self.db.userid_dehydrateddevice.remove(user_id);
Ok(device_id)
}
/// Returns the device_id of the user's dehydrated device, or the lookup
/// error when none is stored.
#[implement(super::Service)]
#[tracing::instrument(
	level = "debug",
	skip_all,
	fields(%user_id)
)]
pub async fn get_dehydrated_device_id(&self, user_id: &UserId) -> Result<OwnedDeviceId> {
	let device = self.get_dehydrated_device(user_id).await?;
	Ok(device.device_id)
}
/// Loads the user's stored dehydrated device record (including its
/// encrypted private data) from the database.
#[implement(super::Service)]
#[tracing::instrument(
	level = "debug",
	skip_all,
	fields(%user_id),
	ret,
)]
pub async fn get_dehydrated_device(&self, user_id: &UserId) -> Result<DehydratedDevice> {
	let stored = self.db.userid_dehydrateddevice.get(user_id).await;
	stored.deserialized()
}

View File

@@ -1,5 +1,3 @@
pub(super) mod dehydrated_device;
#[cfg(feature = "ldap")]
use std::collections::HashMap;
use std::{collections::BTreeMap, mem, net::IpAddr, sync::Arc};
@@ -7,7 +5,7 @@
#[cfg(feature = "ldap")]
use conduwuit::result::LogErr;
use conduwuit::{
Err, Error, Result, Server, debug_warn, err, is_equal_to, trace,
Err, Error, Result, Server, at, debug_warn, err, is_equal_to, trace,
utils::{self, ReadyExt, stream::TryIgnore, string::Unquoted},
};
#[cfg(feature = "ldap")]
@@ -72,7 +70,6 @@ struct Data {
userfilterid_filter: Arc<Map>,
userid_avatarurl: Arc<Map>,
userid_blurhash: Arc<Map>,
userid_dehydrateddevice: Arc<Map>,
userid_devicelistversion: Arc<Map>,
userid_displayname: Arc<Map>,
userid_lastonetimekeyupdate: Arc<Map>,
@@ -113,7 +110,6 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
userfilterid_filter: args.db["userfilterid_filter"].clone(),
userid_avatarurl: args.db["userid_avatarurl"].clone(),
userid_blurhash: args.db["userid_blurhash"].clone(),
userid_dehydrateddevice: args.db["userid_dehydrateddevice"].clone(),
userid_devicelistversion: args.db["userid_devicelistversion"].clone(),
userid_displayname: args.db["userid_displayname"].clone(),
userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(),
@@ -188,12 +184,6 @@ pub async fn create(
password: Option<&str>,
origin: Option<&str>,
) -> Result<()> {
if !self.services.globals.user_is_local(user_id)
&& (password.is_some() || origin.is_some())
{
return Err!("Cannot create a nonlocal user with a set password or origin");
}
self.db
.userid_origin
.insert(user_id, origin.unwrap_or("password"));
@@ -484,11 +474,6 @@ pub async fn create_device(
/// Removes a device from a user.
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
// Remove dehydrated device if this is the dehydrated device
let _: Result<_> = self
.remove_dehydrated_device(user_id, Some(device_id))
.await;
let userdeviceid = (user_id, device_id);
// Remove tokens
@@ -1012,7 +997,7 @@ pub fn get_to_device_events<'a>(
device_id: &'a DeviceId,
since: Option<u64>,
to: Option<u64>,
) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
) -> impl Stream<Item = Raw<AnyToDeviceEvent>> + Send + 'a {
type Key<'a> = (&'a UserId, &'a DeviceId, u64);
let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
@@ -1026,7 +1011,7 @@ pub fn get_to_device_events<'a>(
&& device_id == *device_id_
&& to.is_none_or(|to| *count <= to)
})
.map(|((_, _, count), event)| (count, event))
.map(at!(1))
}
pub async fn remove_to_device_events<Until>(