diff --git a/changelog.d/19701.misc b/changelog.d/19701.misc new file mode 100644 index 0000000000..4663e8b961 --- /dev/null +++ b/changelog.d/19701.misc @@ -0,0 +1 @@ +Port the python Event classes to Rust. diff --git a/rust/src/duration.rs b/rust/src/duration.rs index a3dbe919b2..6c2e2653d1 100644 --- a/rust/src/duration.rs +++ b/rust/src/duration.rs @@ -29,19 +29,41 @@ fn duration_module(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> { } /// Mirrors the `synapse.util.duration.Duration` Python class. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct SynapseDuration { - microseconds: u64, + milliseconds: u64, } impl SynapseDuration { - /// For now we only need to create durations from milliseconds. - pub fn from_milliseconds(milliseconds: u64) -> Self { + /// Creates a `SynapseDuration` from a number of milliseconds. + pub const fn from_milliseconds(milliseconds: u64) -> Self { + Self { milliseconds } + } + + /// Creates a `SynapseDuration` from a number of hours. + pub const fn from_hours(hours: u32) -> Self { + // We take a u32 here so that we know the multiplication won't overflow. + // We could instead panic, but that is unstable in a const context (for + // the current MSRV 1.82). 
Self { - microseconds: milliseconds * 1_000, + milliseconds: (hours as u64) * 3_600_000, } } } +impl<'py> IntoPyObject<'py> for SynapseDuration { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let duration_module = duration_module(py)?; + let kwargs = [("milliseconds", self.milliseconds)].into_py_dict(py)?; + let duration_instance = duration_module.call_method("Duration", (), Some(&kwargs))?; + Ok(duration_instance.into_bound()) + } +} + impl<'py> IntoPyObject<'py> for &SynapseDuration { type Target = PyAny; type Output = Bound<'py, Self::Target>; @@ -49,7 +71,7 @@ impl<'py> IntoPyObject<'py> for &SynapseDuration { fn into_pyobject(self, py: Python<'py>) -> Result { let duration_module = duration_module(py)?; - let kwargs = [("microseconds", self.microseconds)].into_py_dict(py)?; + let kwargs = [("milliseconds", self.milliseconds)].into_py_dict(py)?; let duration_instance = duration_module.call_method("Duration", (), Some(&kwargs))?; Ok(duration_instance.into_bound()) } diff --git a/rust/src/events/constants.rs b/rust/src/events/constants.rs new file mode 100644 index 0000000000..811794d48c --- /dev/null +++ b/rust/src/events/constants.rs @@ -0,0 +1,143 @@ +//! Matrix Events +//! +//! Contains types and utilities for working with Matrix events. 
+ +/// Maximum size of a PDU +pub const MAX_PDU_SIZE_BYTES: usize = 65_536; + +/// Event Types +pub mod event_type { + /// Event type: m.room.member + pub const M_ROOM_MEMBER: &str = "m.room.member"; + /// Event type: m.room.create + pub const M_ROOM_CREATE: &str = "m.room.create"; + /// Event type: m.room.join_rules + pub const M_ROOM_JOIN_RULES: &str = "m.room.join_rules"; + /// Event type: m.room.power_levels + pub const M_ROOM_POWER_LEVELS: &str = "m.room.power_levels"; + /// Event type: m.room.aliases + pub const M_ROOM_ALIASES: &str = "m.room.aliases"; + /// Event type: m.room.history_visibility + pub const M_ROOM_HISTORY_VISIBILITY: &str = "m.room.history_visibility"; + /// Event type: m.room.redaction + pub const M_ROOM_REDACTION: &str = "m.room.redaction"; +} + +/// Event Fields +pub mod event_field { + /// Event field: auth_events + pub const AUTH_EVENTS: &str = "auth_events"; + /// Event field: content + pub const CONTENT: &str = "content"; + /// Event field: depth + pub const DEPTH: &str = "depth"; + /// Event field: hashes + pub const HASHES: &str = "hashes"; + /// Event field: origin_server_ts + pub const ORIGIN_SERVER_TS: &str = "origin_server_ts"; + /// Event field: prev_events + pub const PREV_EVENTS: &str = "prev_events"; + /// Event field: room_id + pub const ROOM_ID: &str = "room_id"; + /// Event field: sender + pub const SENDER: &str = "sender"; + /// Event field: signatures + pub const SIGNATURES: &str = "signatures"; + /// Event field: state_key + pub const STATE_KEY: &str = "state_key"; + /// Event field: type + pub const TYPE: &str = "type"; + /// Event field: unsigned + pub const UNSIGNED: &str = "unsigned"; + /// Event field: event_id + pub const EVENT_ID: &str = "event_id"; + /// Event field: origin + pub const ORIGIN: &str = "origin"; + /// Event field: prev_state + pub const PREV_STATE: &str = "prev_state"; + /// Event field: membership + pub const MEMBERSHIP: &str = "membership"; + /// Event field: replaces_state + pub const 
REPLACES_STATE: &str = "replaces_state"; + /// Event field: msc4354_sticky + pub const MSC4354_STICKY: &str = "msc4354_sticky"; + /// Event field: prev_state_events + pub const PREV_STATE_EVENTS: &str = "prev_state_events"; + /// Event field: m.relates_to + pub const M_RELATES_TO: &str = "m.relates_to"; +} + +pub mod unsigned_field { + /// Unsigned field: age + pub const AGE: &str = "age"; + /// Unsigned field: age_ts + pub const AGE_TS: &str = "age_ts"; + /// Unsigned field: redacted_because + pub const REDACTED_BECAUSE: &str = "redacted_because"; +} + +/// Membership Event Fields +pub mod membership_field { + /// Membership event field: membership + pub const MEMBERSHIP: &str = "membership"; + /// Membership event field: join_authorised_via_users_server + pub const JOIN_AUTHORISED_VIA_USERS_SERVER: &str = "join_authorised_via_users_server"; + /// Membership event field: third_party_invite + pub const THIRD_PARTY_INVITE: &str = "third_party_invite"; + /// Membership event field: signed + pub const SIGNED: &str = "signed"; +} + +/// Create Event Fields +pub mod create_field { + /// Create event field: creator + pub const CREATOR: &str = "creator"; +} + +/// Join Rules Event Fields +pub mod join_rules_field { + /// Join Rules event field: join_rule + pub const JOIN_RULE: &str = "join_rule"; + /// Join Rules event field: allow + pub const ALLOW: &str = "allow"; +} + +/// Power Levels Event Fields +pub mod power_levels_field { + /// Power Levels event field: users + pub const USERS: &str = "users"; + /// Power Levels event field: users_default + pub const USERS_DEFAULT: &str = "users_default"; + /// Power Levels event field: events + pub const EVENTS: &str = "events"; + /// Power Levels event field: events_default + pub const EVENTS_DEFAULT: &str = "events_default"; + /// Power Levels event field: state_default + pub const STATE_DEFAULT: &str = "state_default"; + /// Power Levels event field: ban + pub const BAN: &str = "ban"; + /// Power Levels event field: kick + pub 
const KICK: &str = "kick"; + /// Power Levels event field: redact + pub const REDACT: &str = "redact"; + /// Power Levels event field: invite + pub const INVITE: &str = "invite"; +} + +/// Aliases Event Fields +pub mod aliases_field { + /// Aliases event field: aliases + pub const ALIASES: &str = "aliases"; +} + +/// History Visibility Event Fields +pub mod history_visibility_field { + /// History Visibility event field: history_visibility + pub const HISTORY_VISIBILITY: &str = "history_visibility"; +} + +/// Redaction Event Fields +pub mod redaction_field { + /// Redacts event field: redacts + pub const REDACTS: &str = "redacts"; +} diff --git a/rust/src/events/formats/mod.rs b/rust/src/events/formats/mod.rs new file mode 100644 index 0000000000..10b16a5436 --- /dev/null +++ b/rust/src/events/formats/mod.rs @@ -0,0 +1,268 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +//! Over-the-wire representations of Matrix events, parameterised by event +//! format version (i.e. the structure we parse the event JSON into). +//! +//! # Design +//! +//! The shape of an event's JSON varies with the room version, but there are +//! many common fields across all of them — `type`, `sender`, `content`, etc. +//! +//! We model this with a single [`FormattedEvent`] container that is generic +//! over the format-specific tail `E`. Serde `#[serde(flatten)]` merges the +//! common and specific halves into a single JSON object over the wire, while +//! keeping them as distinct structs in Rust. This lets version-agnostic code +//! 
(field getters, the `unsigned` accessor, …) read [`EventCommonFields`] +//! directly, and only the small amount of version-aware logic (auth-event +//! derivation, room-ID lookup, validation) needs to match on the format. +//! +//! The default `E` parameter is the type-erased [`EventFormatEnum`], which can +//! contain any known format. The [`FormattedEvent::into_general`] method allows +//! converting from a specific format to the general enum. +//! +//! The `signatures` and `unsigned` fields are kept distinct from the +//! common/specific fields as they allow mutation. When copying an event they need to +//! be deep-copied, but the common/specific fields (which are immutable) can be +//! shared. +//! +//! # Format variants +//! +//! Different room versions have different over-the-wire formats, which is +//! tracked by the [`crate::room_versions::RoomVersion::event_format`] field. +//! +//! Each format struct owns only its version-specific fields and any +//! validation/derivation logic; the rest lives in [`EventCommonFields`]. The +//! [`EventFormatEnum`] sum type erases the generic parameter when an `Event` +//! needs to be stored alongside others of unknown room version. +//! +//! Note that any fields not recognised by the format-specific struct or by the +//! common fields are captured into [`EventCommonFields::other_fields`] and +//! round-tripped losslessly. This is useful for capturing optional fields that +//! don't need to be parsed up front. Generally, optional fields should be +//! handled via `other_fields`, as this saves space when they are not present. +//! +//! # Serialization and deserialization +//! +//! Deserializing a Matrix Event from JSON is done by specifying the expected +//! format struct (e.g. [`FormattedEvent<EventFormatV1>`]), which enforces the +//! invariants of that format at parse time. This can then be converted into the +//! version-agnostic [`FormattedEvent`] with the +//! [`FormattedEvent::into_general`] method, which erases the format-specific +//! 
type but keeps the parsed fields intact. +//! +//! Serializing a [`FormattedEvent`] produces the correct Matrix Event JSON +//! shape for the format variant it contains. + +use std::{collections::HashMap, sync::Arc}; + +use anyhow::Error; +use serde::{Deserialize, Serialize}; + +use crate::{ + events::{json_object::JsonObject, signatures::Signatures, unsigned::Unsigned}, + json::AllowMissing, +}; + +mod v1; +mod v2v3; +mod v4; +mod vmsc4242; + +pub use v1::EventFormatV1; +pub use v2v3::EventFormatV2V3; +pub use v4::EventFormatV4; +pub use vmsc4242::EventFormatVMSC4242; + +/// A parsed Matrix event in its over-the-wire layout. +/// +/// `E` is the format-specific tail. Code that deserialises a known +/// room version picks a concrete `E` (e.g. `FormattedEvent<EventFormatV1>`); +/// the default `Arc<EventFormatEnum>` is used once the event has been +/// boxed into the version-agnostic [`Event`](crate::events::Event) +/// pyclass. +/// +/// The `signatures` and `unsigned` fields are kept separate from the other +/// fields as they are mutable (and must be deep-copied if the event is cloned). +/// `common_fields` and `specific_fields` are both `#[serde(flatten)]`ed so that +/// the serialised JSON is a single flat object matching the Matrix spec. +#[derive(Serialize, Deserialize)] +pub struct FormattedEvent> { + /// The event's signatures. + /// + /// Kept separate from common/specific fields as this is a mutable + /// field. + #[serde(default)] + pub signatures: Signatures, + + /// The event's unsigned data. + /// + /// Kept separate from common/specific fields as this is a mutable + /// field. + #[serde(default)] + pub unsigned: Unsigned, + + /// The format-specific fields of the event. This is an immutable field. + #[serde(flatten)] + pub specific_fields: E, + + /// The fields common to all event formats. This is an immutable field. 
+ #[serde(flatten)] + pub common_fields: Arc, +} + +impl FormattedEvent { + /// Creates a deep copy of this event, allowing the signatures and unsigned + /// to be mutated without affecting the original. + /// + /// The common and specific fields are shared between the copy and the + /// original, as they are immutable. + pub fn deep_copy(&self) -> FormattedEvent { + FormattedEvent { + signatures: self.signatures.deep_copy(), + unsigned: self.unsigned.deep_copy(), + // These fields can safely be shared among all of the copies as they + // are immutable (they're behind an Arc and so you can't get a + // mutable reference and they have no interior mutability) and these + // write protections extend into Python land as well (i.e. you can't + // accidentally do the wrong thing and mutate) + specific_fields: Arc::clone(&self.specific_fields), + common_fields: Arc::clone(&self.common_fields), + } + } + + pub fn validate(&self) -> Result<(), Error> { + match &*self.specific_fields { + EventFormatEnum::V1(format) => format.validate(&self.common_fields), + EventFormatEnum::V2V3(format) => format.validate(&self.common_fields), + EventFormatEnum::V4(format) => format.validate(&self.common_fields), + EventFormatEnum::VMSC4242(format) => format.validate(&self.common_fields), + } + } +} + +impl FormattedEvent +where + E: Into, +{ + /// Transforms a container of a specific event format into a container of + /// the enum type. + pub fn into_general(self) -> FormattedEvent { + let format: Arc = Arc::new(self.specific_fields.into()); + FormattedEvent { + signatures: self.signatures, + unsigned: self.unsigned, + specific_fields: format, + common_fields: self.common_fields, + } + } +} + +impl From> for FormattedEvent +where + E: Into, +{ + fn from(container: FormattedEvent) -> Self { + container.into_general() + } +} + +/// Fields that appear in every supported event format. 
+/// +/// Anything not recognised by the format-specific tail or by the fields +/// named here is captured into `other_fields` so events round-trip +/// losslessly even when they carry experimental or future-version +/// keys. +#[derive(Serialize, Deserialize)] +pub struct EventCommonFields { + pub content: JsonObject, + pub depth: i64, + pub hashes: HashMap, Box>, + pub origin_server_ts: i64, + pub sender: Box, + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + pub state_key: AllowMissing>, + + /// The `type` field of the event (we use `type_` in Rust to avoid the + /// reserved keyword). + #[serde(rename = "type")] + pub type_: Box, + + /// All other fields that are not required/parsed by the specific/common + /// fields. This allows us to round-trip events that contain extra fields. + /// + /// Generally, optional fields should be handled via `other_fields`, as this + /// saves space when they are not present. However, that does mean we don't + /// do any type-checking until they get used. + #[serde(flatten)] + pub other_fields: HashMap, serde_json::Value>, +} + +impl EventCommonFields { + /// Helper method to check if the event is a state event and return the + /// tuple of `(type, state_key)` if so. + fn type_state_key_tuple(&self) -> Option<(&str, &str)> { + if let AllowMissing::Some(state_key) = &self.state_key { + Some((&self.type_, state_key)) + } else { + None + } + } +} + +/// Type-erased version-specific tail. +/// +/// Used as the default `E` parameter on [`FormattedEvent`] so the +/// pyclass [`Event`](crate::events::Event) can hold any room version +/// behind a single type. The enum is `#[serde(untagged)]` because the +/// discriminator (the room version) lives outside the JSON; in +/// practice the only direction this is serialised in is `Event -> +/// JSON`, where the chosen variant alone determines the shape. 
+#[derive(Serialize)] +#[serde(untagged)] +pub enum EventFormatEnum { + V1(EventFormatV1), + V2V3(EventFormatV2V3), + V4(EventFormatV4), + VMSC4242(EventFormatVMSC4242), +} + +impl From for EventFormatEnum { + fn from(format: EventFormatV1) -> Self { + EventFormatEnum::V1(format) + } +} + +impl From for EventFormatEnum { + fn from(format: EventFormatV2V3) -> Self { + EventFormatEnum::V2V3(format) + } +} + +impl From for EventFormatEnum { + fn from(format: EventFormatV4) -> Self { + EventFormatEnum::V4(format) + } +} + +impl From for EventFormatEnum { + fn from(format: EventFormatVMSC4242) -> Self { + EventFormatEnum::VMSC4242(format) + } +} diff --git a/rust/src/events/formats/v1.rs b/rust/src/events/formats/v1.rs new file mode 100644 index 0000000000..09caabc597 --- /dev/null +++ b/rust/src/events/formats/v1.rs @@ -0,0 +1,54 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +//! Event format v1 (room versions 1 and 2). +//! +//! Distinguishing features compared to later formats: +//! +//! - `auth_events` and `prev_events` are `[event_id, hashes]` pairs +//! rather than flat lists of IDs. +//! - `event_id` is carried explicitly in the event JSON, rather than +//! being derived from the canonical-JSON hash. +//! - `room_id` is always present. + +use std::{collections::HashMap, sync::Arc}; + +use anyhow::Error; +use serde::{Deserialize, Serialize}; + +use crate::events::formats::EventCommonFields; + +/// Version-specific fields for room versions 1 and 2. 
+#[derive(Serialize, Deserialize)] +pub struct EventFormatV1 { + pub auth_events: Vec<(String, HashMap)>, + pub prev_events: Vec<(String, HashMap)>, + pub room_id: Arc, + pub event_id: Arc, +} + +impl EventFormatV1 { + pub fn validate(&self, _common_fields: &EventCommonFields) -> Result<(), Error> { + Ok(()) + } + + pub fn auth_event_ids(&self) -> Vec { + self.auth_events.iter().map(|(id, _)| id.clone()).collect() + } + + pub fn prev_event_ids(&self) -> Vec { + self.prev_events.iter().map(|(id, _)| id.clone()).collect() + } +} diff --git a/rust/src/events/formats/v2v3.rs b/rust/src/events/formats/v2v3.rs new file mode 100644 index 0000000000..a62c68a4e4 --- /dev/null +++ b/rust/src/events/formats/v2v3.rs @@ -0,0 +1,60 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +//! Event format v2/v3 (room versions 3 through 10). +//! +//! Differences from v1: +//! +//! - `auth_events` and `prev_events` are flat `Vec` lists of event IDs +//! rather than `[id, hashes]` pairs. +//! - `event_id` is no longer in the event JSON; it is derived from the +//! canonical-JSON hash at parse time. +//! +//! Note that the difference between event format v2 and v3 is purely in the +//! base64 encoding of the event ID, so the same struct can be used for both +//! formats. +//! +//! [`SimpleAuthPrevEvents`] is shared with [`v4`](super::v4) since the +//! flat-list encoding carries forward unchanged. 
+ +use std::sync::Arc; + +use anyhow::{bail, Error}; +use serde::{Deserialize, Serialize}; + +use crate::events::formats::EventCommonFields; + +/// Version-specific fields for room versions 3-10. +#[derive(Serialize, Deserialize)] +pub struct EventFormatV2V3 { + pub room_id: Arc, + pub auth_events: Vec, + pub prev_events: Vec, +} + +impl EventFormatV2V3 { + pub fn validate(&self, common_fields: &EventCommonFields) -> Result<(), Error> { + // Ensure that we don't have an event_id set. + if common_fields.other_fields.contains_key("event_id") { + bail!("v2/v3 events must not have an explicit event_id"); + } + + Ok(()) + } + + pub fn auth_event_ids(&self) -> Vec { + self.auth_events.clone() + } +} diff --git a/rust/src/events/formats/v4.rs b/rust/src/events/formats/v4.rs new file mode 100644 index 0000000000..22a1a677f2 --- /dev/null +++ b/rust/src/events/formats/v4.rs @@ -0,0 +1,156 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +//! Event format v4 (room version 11). +//! +//! The main change from v2/v3 is that `room_id` becomes optional: an +//! `m.room.create` event no longer carries an explicit room ID, and the +//! room ID is instead *derived* from the create event's ID by replacing +//! the leading `$` with `!`. Conversely, every non-create event still +//! has an explicit `room_id`, and the create event is implicitly +//! included in the auth chain of every non-create event (so it does not +//! need to appear in `auth_events`). +//! +//! [`EventFormatV4::validate`] enforces these invariants at parse time; +//! 
[`EventFormatV4::room_id`] and [`EventFormatV4::auth_event_ids`] +//! expose the derived values to callers. + +use std::sync::Arc; + +use anyhow::{bail, ensure, Error}; +use serde::{Deserialize, Serialize}; + +use crate::{ + events::{constants::event_type::M_ROOM_CREATE, formats::EventCommonFields}, + json::AllowMissing, +}; + +/// Version-specific fields for room version 11. +#[derive(Serialize, Deserialize)] +pub struct EventFormatV4 { + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + pub room_id: AllowMissing>, + pub auth_events: Vec, + pub prev_events: Vec, +} + +impl EventFormatV4 { + pub fn validate(&self, common_fields: &EventCommonFields) -> Result<(), Error> { + // Ensure that we don't have an event_id set. + if common_fields.other_fields.contains_key("event_id") { + bail!("v4 events must not have an explicit event_id"); + } + + Ok(()) + } + + pub fn room_id( + &self, + event_id: &str, + common_fields: &EventCommonFields, + ) -> Result, Error> { + get_room_id_for_optional_room_id(self.room_id.as_ref_opt(), event_id, common_fields) + } + + pub fn auth_event_ids(&self, common_fields: &EventCommonFields) -> Result, Error> { + let is_create = common_fields.type_state_key_tuple() == Some((M_ROOM_CREATE, "")); + + if is_create { + // The create event itself has no implicit auth events. + return Ok(self.auth_events.clone()); + } + + // For non-create events, the create event is implicitly part of + // the auth chain. Derive its event ID from the room ID by + // replacing the leading '!' with '$'. 
+ let room_id = self + .room_id + .as_deref_opt() + .ok_or_else(|| anyhow::anyhow!("non-create event has no room_id"))?; + + // Check the sigil first so the `[1..]` slice below cannot panic on a malformed room ID. + ensure!(room_id.starts_with('!'), "room_id must start with '!': {}", room_id); + let create_event_id = format!("${}", &room_id[1..]); + + ensure!( + !self.auth_events.contains(&create_event_id), + "The create event ID is implicitly part of the auth chain and should not be explicitly be in the auth_events" + ); + + let mut auth_events = self.auth_events.clone(); + auth_events.push(create_event_id); + Ok(auth_events) + } +} + +/// Validation helper for v4+ events that can have an optional room ID. +/// +/// Returns the validated room ID (which will be `None` for create events). +pub fn validate_optional_room_id( + room_id: Option<&Arc>, + common_fields: &'_ EventCommonFields, +) -> Result>, Error> { + let is_create_event = common_fields.type_state_key_tuple() == Some((M_ROOM_CREATE, "")); + + match (is_create_event, room_id) { + // For non-create events, room_id must be present. + (false, None) => bail!("non-create event must have a room ID"), + (false, Some(room_id)) => { + // We later derive the create event ID from the room ID by replacing + // the leading '!' with '$', so we require the room ID to start with + // '!'. + ensure!( + room_id.starts_with("!"), + "room_id must start with '!': {}", + room_id + ); + + Ok(Some(Arc::clone(room_id))) + } + + // For create events, room_id must be absent. + (true, Some(_)) => bail!("create event must not have a room ID"), + (true, None) => Ok(None), + } +} + +/// Room ID derivation helper for v4+ events, which can have an optional room +/// ID. +pub fn get_room_id_for_optional_room_id( + room_id: Option<&Arc>, + event_id: &str, + common_fields: &EventCommonFields, +) -> Result, Error> { + match validate_optional_room_id(room_id, common_fields)? 
{ + Some(room_id) => Ok(room_id), + None => { + // This is the create event, where the room ID is derived from the + // event ID by replacing the leading '$' with '!'. + if !event_id.starts_with('$') { + bail!("Create event ID does not start with '$': {}", event_id); + } + + let mut room_id = String::with_capacity(event_id.len()); + room_id.push('!'); + room_id.push_str(&event_id[1..]); + + Ok(room_id.into()) + } + } +} diff --git a/rust/src/events/formats/vmsc4242.rs b/rust/src/events/formats/vmsc4242.rs new file mode 100644 index 0000000000..45d9a25068 --- /dev/null +++ b/rust/src/events/formats/vmsc4242.rs @@ -0,0 +1,94 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +//! Event format for [MSC4242] (prev-state events). +//! +//! Adds `prev_state_events` and removes `auth_events` from the v4 layout +//! — auth chains are derived implicitly from the state DAG rather than +//! carried on each event. `room_id`, `prev_events` and the create-event +//! derivation rules carry over unchanged from v4; the room ID derivation +//! is delegated to the [`get_room_id_for_optional_room_id`] helper that +//! is shared with the v4 format. +//! +//! 
[MSC4242]: https://github.com/matrix-org/matrix-spec-proposals/pull/4242 + +use std::sync::Arc; + +use anyhow::bail; +use anyhow::Error; +use pyo3::exceptions::PyAssertionError; +use pyo3::PyResult; +use serde::{Deserialize, Serialize}; + +use crate::events::constants::event_type::M_ROOM_CREATE; +use crate::events::formats::v4::get_room_id_for_optional_room_id; +use crate::events::formats::EventCommonFields; +use crate::events::Event; +use crate::json::AllowMissing; + +/// Version-specific fields for the MSC4242 event format. +#[derive(Serialize, Deserialize)] +pub struct EventFormatVMSC4242 { + pub prev_state_events: Vec, + pub prev_events: Vec, + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + pub room_id: AllowMissing>, +} + +impl EventFormatVMSC4242 { + pub fn validate(&self, common_fields: &EventCommonFields) -> Result<(), Error> { + // Ensure that we don't have any `auth_events` or `event_id` fields + // set. + if common_fields.other_fields.contains_key("auth_events") { + bail!("MSC4242 events must not have explicit auth_events"); + } + if common_fields.other_fields.contains_key("event_id") { + bail!("MSC4242 events must not have an explicit event_id"); + } + + Ok(()) + } + + pub fn room_id( + &self, + event_id: &str, + common_fields: &EventCommonFields, + ) -> Result, Error> { + get_room_id_for_optional_room_id(self.room_id.as_ref_opt(), event_id, common_fields) + } + + pub fn auth_event_ids(&self, event: &Event) -> PyResult> { + // In the MSC4242 format, the auth events are calculated and stored in + // internal metadata. + let auth_event_ids = event.internal_metadata.get_calculated_auth_event_ids()?; + + // Catches cases where we accidentally call auth_event_ids() prior to calculating what they + // actually are. The exception being the m.room.create event which has no auth events. 
+ if event.parsed_event.common_fields.type_state_key_tuple() != Some((M_ROOM_CREATE, "")) + && auth_event_ids.is_empty() + { + return Err(PyAssertionError::new_err(format!( + "auth_event_ids has not been calculated for event_id='{}'. This is most likely a Synapse programming error.", + event.event_id + ))); + } + + Ok(auth_event_ids) + } +} diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index 4084b8442d..93f13f655e 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -510,7 +510,7 @@ fn attr_err(val: Option, name: &str) -> PyResult { #[pymethods] impl EventInternalMetadata { #[new] - fn new(dict: &Bound<'_, PyDict>) -> PyResult { + pub fn new(dict: &Bound<'_, PyDict>) -> PyResult { let mut data = Vec::with_capacity(dict.len()); for (key, value) in dict.iter() { @@ -536,7 +536,10 @@ impl EventInternalMetadata { }) } - fn copy(&self) -> PyResult { + /// Create a deep copy of this `EventInternalMetadata` to allow modification + /// without affecting other references to the same metadata. This is needed + /// when we clone an event. + pub fn deep_copy(&self) -> PyResult { let guard = self.read_inner()?; Ok(EventInternalMetadata { inner: Arc::new(RwLock::new(guard.clone())), @@ -723,7 +726,7 @@ impl EventInternalMetadata { attr_err(self.read_inner()?.get_redacted(), "redacted") } #[setter] - fn set_redacted(&self, obj: bool) -> PyResult<()> { + pub fn set_redacted(&self, obj: bool) -> PyResult<()> { self.write_inner()?.set_redacted(obj); Ok(()) } @@ -742,7 +745,7 @@ impl EventInternalMetadata { /// The calculated auth event IDs, if it was set when the event was created. 
#[getter] - fn get_calculated_auth_event_ids(&self) -> PyResult> { + pub fn get_calculated_auth_event_ids(&self) -> PyResult> { let guard = self.read_inner()?; attr_err( guard.get_calculated_auth_event_ids().cloned(), diff --git a/rust/src/events/json_object.rs b/rust/src/events/json_object.rs index 0ab54e8dc5..bb4877d482 100644 --- a/rust/src/events/json_object.rs +++ b/rust/src/events/json_object.rs @@ -193,6 +193,12 @@ impl JsonObject { } } +impl JsonObject { + pub fn get_field(&self, key: &str) -> Option<&serde_json::Value> { + self.object.get(key) + } +} + /// Helper class returned by `JsonObject.keys()` to act as a view into the keys /// of the object. /// diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index e60cdb7078..ad3f61e2fd 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -18,18 +18,77 @@ * */ -//! Classes for representing Events. +//! Classes for representing Matrix events. +//! +//! # Overview +//! +//! A Matrix event has a JSON shape that varies by *room version*. The +//! per-room-version shape is captured in the [`formats`] module, where +//! [`FormattedEvent`] is a generic container parametrised by the +//! room-version-specific portion (`EventFormatV1`, `EventFormatV2V3`, +//! `EventFormatV4`, `EventFormatVMSC4242`). See [`formats`] for the layout +//! of the over-the-wire JSON and how the room-version-agnostic fields are +//! split from the version-specific ones. +//! +//! [`Event`] is the `pyclass` exposed to Python. It bundles a fully parsed +//! [`FormattedEvent`] (with the version-specific part type-erased as +//! [`formats::EventFormatEnum`]) together with the pieces of state that +//! live alongside the event JSON in Synapse: +//! +//! - `event_id` — either taken from the event JSON (format v1) or derived +//! from the canonical-JSON hash (v2+); computed once at construction +//! time and cached. +//! - `room_version` — a `'static` reference into the global room-version +//! 
table, used to drive format-dependent behaviour (e.g. where the +//! `redacts` field lives, which redaction rules apply). +//! - `internal_metadata` — Synapse-internal flags that are *not* part of +//! the federated event (outlier status, soft-failure, stream positions, +//! …). These come from a separate dict at construction time. +//! - `rejected_reason` — `None` for accepted events; otherwise a short +//! string describing why auth rejected the event. +//! +use std::sync::Arc; + +use anyhow::Error; use pyo3::{ - types::{PyAnyMethods, PyMapping, PyModule, PyModuleMethods}, - wrap_pyfunction, Bound, PyResult, Python, + exceptions::{PyAttributeError, PyKeyError, PyValueError}, + pyclass, pyfunction, pymethods, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyMapping, PyModule, PyModuleMethods}, + wrap_pyfunction, Bound, IntoPyObject, PyAny, PyResult, Python, +}; +use pythonize::{depythonize, pythonize}; + +use crate::events::{ + formats::{ + EventFormatEnum, EventFormatV1, EventFormatV2V3, EventFormatV4, EventFormatVMSC4242, + FormattedEvent, + }, + signatures::Signatures, + unsigned::Unsigned, + utils::redact, +}; +use crate::{ + duration::SynapseDuration, + events::{ + constants::event_field::{HASHES, MSC4354_STICKY, SIGNATURES, UNSIGNED}, + constants::membership_field::MEMBERSHIP, + constants::redaction_field::REDACTS, + constants::unsigned_field::{AGE, AGE_TS, REDACTED_BECAUSE}, + internal_metadata::EventInternalMetadata, + utils::calculate_event_id, + }, + room_versions::{EventFormatVersions, RoomVersion}, }; +pub mod constants; pub mod filter; -mod internal_metadata; -mod json_object; +pub mod formats; +pub mod internal_metadata; +pub mod json_object; pub mod signatures; pub mod unsigned; +pub mod utils; use json_object::JsonObject; @@ -39,14 +98,17 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> PyMapping::register::(py)?; let child_module = PyModule::new(py, "events")?; - child_module.add_class::()?; - 
child_module.add_class::()?; - child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?; + child_module.add_function(wrap_pyfunction!(redact_event_py, m)?)?; + child_module.add_function(wrap_pyfunction!(redact_event_dict, m)?)?; m.add_submodule(&child_module)?; @@ -58,3 +120,813 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> Ok(()) } + +/// The Rust-side representation of a Matrix event, exposed to Python. +/// +/// Wraps a parsed [`FormattedEvent`] together with the per-event state +/// that Synapse tracks outside the event JSON (event ID, internal +/// metadata, rejection reason, and a reference to the room version that +/// produced this event). See the module-level docs for the high-level +/// design. +#[pyclass(frozen, weakref)] +pub struct Event { + /// The parsed event JSON. + parsed_event: FormattedEvent, + + /// The event ID. For format v1 this is read directly from the JSON; + /// for v2+ it is computed from the canonical-JSON hash at + /// construction time and cached here. + event_id: Arc, + + /// The calculated room ID. + /// + /// For some room versions, this may be derived, e.g. for create events in + /// v4. + room_id: Arc, + + /// Synapse-internal per-event state that lives outside the federated + /// JSON (e.g. outlier flag, soft-failure, stream positions). + #[pyo3(get)] + internal_metadata: EventInternalMetadata, + + /// The room version this event was parsed for. + #[pyo3(get)] + room_version: &'static RoomVersion, + + /// `None` for accepted events; otherwise a short reason set by auth + /// when the event was rejected. 
+ rejected_reason: Option>, +} + +#[pymethods] +impl Event { + #[new] + fn new_from_py<'a, 'py>( + py: Python<'py>, + event_dict: &'a Bound<'py, PyAny>, + room_version: &'a Bound<'py, PyAny>, + internal_metadata_dict: &'a Bound<'py, PyDict>, + rejected_reason: Option, + ) -> PyResult { + let room_version: &RoomVersion = { + let r = room_version.getattr("identifier")?; + let room_version_str = r.extract::<&str>()?; + room_version_str + .parse() + .map_err(|e| PyValueError::new_err(format!("Unsupported room version: {}", e)))? + }; + + let rejected_reason = rejected_reason.map(String::into_boxed_str); + + // Parse the event dict into a FormattedEvent, converting any failures to + // a `ValueError`. + let parsed_event = depythonize_event_dict(room_version, event_dict).map_err(|err| { + let new_err = PyValueError::new_err(format!( + "Failed to parse event for room version {}", + room_version + )); + new_err.set_cause(py, Some(err)); + new_err + })?; + + let internal_metadata = EventInternalMetadata::new(internal_metadata_dict)?; + + let event_id = match &*parsed_event.specific_fields { + EventFormatEnum::V1(format) => { + // V1/V2 events have the event_id in the event dict. + Arc::clone(&format.event_id) + } + _ => { + // Calculate the event ID by hashing the event JSON. This can + // fail if the event can't be serialized to canonical JSON (e.g. + // having out-of-range integers), which we report as + // `ValueError` as it indicates the event is invalid. + let event_value = serde_json::to_value(&parsed_event).map_err(|err| { + PyValueError::new_err(format!("Failed to serialize event: {}", err)) + })?; + calculate_event_id(&event_value, room_version) + .map_err(|err| { + PyValueError::new_err(format!("Failed to calculate event_id: {}", err)) + })? 
+ .into() + } + }; + + let room_id = match &*parsed_event.specific_fields { + EventFormatEnum::V1(format) => Arc::clone(&format.room_id), + EventFormatEnum::V2V3(format) => Arc::clone(&format.room_id), + EventFormatEnum::V4(format) => format + .room_id(&event_id, &parsed_event.common_fields) + .map_err(|err| { + PyValueError::new_err(format!( + "Failed to calculate room_id for event {}: {}", + event_id, err + )) + })?, + EventFormatEnum::VMSC4242(format) => format + .room_id(&event_id, &parsed_event.common_fields) + .map_err(|err| { + PyValueError::new_err(format!( + "Failed to calculate room_id for event {}: {}", + event_id, err + )) + })?, + }; + + Ok(Self { + parsed_event, + + event_id, + room_id, + room_version, + rejected_reason, + internal_metadata, + }) + } + + /// Convert the event to a dictionary suitable for serialisation. + fn get_dict<'py>(&self, py: Python<'py>) -> PyResult> { + Ok(pythonize(py, &self.parsed_event)?) + } + + /// Like `get_dict`, but serializes `unsigned` in a form suitable for + /// persistence. + fn get_dict_for_persistence<'py>(&self, py: Python<'py>) -> PyResult> { + let binding = self.get_dict(py)?; + let dict = binding.cast::()?; + + dict.set_item("unsigned", self.parsed_event.unsigned.for_persistence(py)?)?; + + Ok(binding) + } + + /// Like [`Event::get_dict`], but serializes `unsigned` in a form suitable + /// for sending over federation. 
+ #[pyo3(signature = (time_now = None))] + fn get_pdu_json<'py>( + &self, + py: Python<'py>, + time_now: Option, + ) -> PyResult> { + let obj = self.get_dict(py)?; + let dict = obj.cast::()?; + + // Only adjust the unsigned dict if the event has one; it is never + // created here. + if let Ok(Some(unsigned)) = dict.get_item(UNSIGNED) { + let unsigned = unsigned.cast::()?; + + // When the caller supplies `time_now`, replace the absolute + // `age_ts` with a relative `age` (`time_now - age_ts`). + if let Some(time_now) = time_now { + if let Ok(Some(age_ts)) = unsigned.get_item(AGE_TS) { + let age = time_now - age_ts.extract::()?; + unsigned.set_item(AGE, age)?; + unsigned.del_item(AGE_TS)?; + } + } + + // Strip `REDACTED_BECAUSE` from the outgoing dict. `.ok()` discards + // the error if the deletion fails (presumably because the key is + // absent). + unsigned.del_item(REDACTED_BECAUSE).ok(); + } + + Ok(obj) + } + + /// Like [`Event::get_dict`], except strips fields like `signatures`, + /// `hashes` and `unsigned` so that the result is suitable as a template for + /// creating new events. Used in make_{join,leave,knock} flows. + fn get_templated_pdu_json<'py>(&self, py: Python<'py>) -> PyResult> { + // Use get_dict but strip signatures, unsigned, and hashes — the + // joining/leaving/knocking server will re-sign and recalculate hashes. + let obj = self.get_dict(py)?; + let dict = obj.cast::()?; + dict.del_item(SIGNATURES).ok(); + dict.del_item(UNSIGNED).ok(); + dict.del_item(HASHES).ok(); + + Ok(obj) + } + + #[getter] + fn rejected_reason(&self) -> Option<&str> { + self.rejected_reason.as_deref() + } + + /// Returns the list of prev event IDs. The order matches the order + /// specified in the event, though there is no meaning to it. + fn prev_event_ids(&self) -> Vec { + match &*self.parsed_event.specific_fields { + EventFormatEnum::V1(format) => format.prev_event_ids(), + EventFormatEnum::V2V3(format) => format.prev_events.clone(), + EventFormatEnum::V4(format) => format.prev_events.clone(), + EventFormatEnum::VMSC4242(format) => format.prev_events.clone(), + } + } + + /// Returns the list of auth event IDs. The order matches the order + /// specified in the event, though there is no meaning to it. 
+ fn auth_event_ids(&self) -> PyResult> { + match &*self.parsed_event.specific_fields { + EventFormatEnum::V1(format) => Ok(format.auth_event_ids()), + EventFormatEnum::V2V3(format) => Ok(format.auth_event_ids()), + EventFormatEnum::V4(format) => { + Ok(format.auth_event_ids(&self.parsed_event.common_fields)?) + } + EventFormatEnum::VMSC4242(format) => Ok(format.auth_event_ids(self)?), + } + } + + #[getter] + fn membership<'py>(&self, py: Python<'py>) -> PyResult> { + let content = self.content(); + let value = content.get_field(MEMBERSHIP); + match value { + Some(value) => Ok(pythonize(py, value)?), + None => Err(PyKeyError::new_err(MEMBERSHIP)), + } + } + + fn is_state(&self) -> bool { + self.parsed_event.common_fields.state_key.is_some() + } + + /// Get the state key of this event, or None if it's not a state event. + fn get_state_key(&self) -> Option<&str> { + self.parsed_event.common_fields.state_key.as_deref_opt() + } + + /// The EventFormatVersion implemented by this event. + #[getter] + fn format_version(&self) -> i32 { + self.room_version.event_format + } + + /// Returns a deep copy of this object, such that modifying the copy will + /// not affect the original. + fn deep_copy(&self) -> PyResult { + let internal_metadata = self.internal_metadata.deep_copy()?; + + let new_event = Event { + parsed_event: self.parsed_event.deep_copy(), + internal_metadata, + room_version: self.room_version, + rejected_reason: self.rejected_reason.clone(), + event_id: self.event_id.clone(), + room_id: self.room_id.clone(), + }; + Ok(new_event) + } + + /// If this event has the `msc4354_sticky` top-level field, returns a + /// `SynapseDuration` representing the sticky duration. Otherwise returns + /// `None`. 
+ fn sticky_duration(&self) -> Option { + const MAX_DURATION: SynapseDuration = SynapseDuration::from_hours(1); + + let sticky_obj = self + .parsed_event + .common_fields + .other_fields + .get(MSC4354_STICKY); + + let sticky_obj = match sticky_obj { + Some(serde_json::Value::Object(obj)) => obj, + _ => return None, + }; + + // Check for a valid duration field. The MSC requires `duration_ms` to + // be a non-negative integer. If it's missing or invalid, we treat the + // event as non-sticky by returning `None`. + let duration_ms = sticky_obj.get("duration_ms")?.as_u64()?; + + let duration = SynapseDuration::from_milliseconds(duration_ms); + + let duration = std::cmp::min(duration, MAX_DURATION); + + Some(duration) + } + + // Below are the methods for interacting with the event as a mapping. + // + // These are rarely used, so we take the easy approach of re-serializing the + // event to a Python dict and then delegating to the standard dict methods. + // We can't remove these functions as third-party modules may rely on them. + + fn __contains__<'py>(&self, py: Python<'py>, key: &str) -> PyResult { + let dict = self.get_dict(py)?; + dict.contains(key) + } + + /// This is deprecated in favor of `get`, but we still need to support it + /// for backwards compatibility with modules. This is therefore not exposed + /// in the type stubs. + fn __getitem__<'py>(&self, py: Python<'py>, key: &str) -> PyResult> { + let dict = self.get_dict(py)?; + if dict.contains(key)? { + dict.get_item(key) + } else { + Err(PyKeyError::new_err(key.to_owned())) + } + } + + #[pyo3(signature = (key, default=None))] + fn get<'py>( + &self, + py: Python<'py>, + key: &str, + default: Option>, + ) -> PyResult> { + let dict = self.get_dict(py)?; + if dict.contains(key)? { + dict.get_item(key) + } else { + Ok(default.into_pyobject(py)?) 
+ } + } + + fn items<'py>(&self, py: Python<'py>) -> PyResult> { + let dict = self.get_dict(py)?; + let dict = dict.cast::()?; + Ok(dict.items()) + } + + fn keys<'py>(&self, py: Python<'py>) -> PyResult> { + let dict = self.get_dict(py)?; + let dict = dict.cast::()?; + Ok(dict.keys()) + } + + // Below are the getters for the top-level fields on Matrix events. + + #[getter] + fn event_id(&self) -> &str { + &self.event_id + } + + #[getter] + fn room_id(&self) -> &str { + &self.room_id + } + + #[getter] + fn signatures(&self) -> Signatures { + self.parsed_event.signatures.clone() + } + + #[getter] + fn content(&self) -> JsonObject { + self.parsed_event.common_fields.content.clone() + } + + #[getter] + fn depth(&self) -> i64 { + self.parsed_event.common_fields.depth + } + + #[getter] + fn hashes<'py>(&self, py: Python<'py>) -> PyResult> { + let dict = PyDict::new(py); + for (key, value) in &self.parsed_event.common_fields.hashes { + dict.set_item(&**key, &**value)?; + } + Ok(dict) + } + + #[getter] + fn origin_server_ts(&self) -> i64 { + self.parsed_event.common_fields.origin_server_ts + } + + #[getter] + fn sender(&self) -> &str { + &self.parsed_event.common_fields.sender + } + + /// Deprecated alias for `sender`. Kept for backwards compatibility with + /// modules and tests that still read `event.user_id`. This is therefore not + /// exposed in the type stubs. + #[getter] + fn user_id(&self) -> &str { + &self.parsed_event.common_fields.sender + } + + #[getter(state_key)] + // We can't call this `state_key` because that would generate a + // `get_state_key` method which already exists. 
+ fn state_key_attr(&self) -> PyResult<&str> { + let Some(state_key) = self.parsed_event.common_fields.state_key.as_deref_opt() else { + return Err(PyAttributeError::new_err("state_key")); + }; + Ok(state_key) + } + + #[getter] + fn r#type(&self) -> &str { + &self.parsed_event.common_fields.type_ + } + + #[getter] + fn unsigned(&self) -> Unsigned { + self.parsed_event.unsigned.clone() + } + + #[getter] + fn prev_state_events(&self) -> PyResult> { + // `prev_state_events` should only be called after validating the event + // is of a format that supports MSC4242, so we return an AttributeError + // for formats that don't support it. + match &*self.parsed_event.specific_fields { + EventFormatEnum::V1(_) | EventFormatEnum::V2V3(_) | EventFormatEnum::V4(_) => { + Err(PyAttributeError::new_err("prev_state_events")) + } + EventFormatEnum::VMSC4242(format) => Ok(format.prev_state_events.clone()), + } + } + + #[getter] + fn redacts<'py>(&self, py: Python<'py>) -> PyResult>> { + let common = &self.parsed_event.common_fields; + let value = if self.room_version.updated_redaction_rules { + common.content.get_field(REDACTS) + } else { + common.other_fields.get(REDACTS) + }; + value + .map(|v| pythonize(py, v).map_err(Into::into)) + .transpose() + } +} + +fn depythonize_event_dict( + room_version: &RoomVersion, + event_dict: &Bound<'_, PyAny>, +) -> PyResult { + let formatted_event: FormattedEvent = match room_version.event_format { + EventFormatVersions::ROOM_V1_V2 => { + let event_format: FormattedEvent = depythonize(event_dict)?; + + event_format.into() + } + EventFormatVersions::ROOM_V3 | EventFormatVersions::ROOM_V4_PLUS => { + let event_format: FormattedEvent = depythonize(event_dict)?; + event_format.into() + } + EventFormatVersions::ROOM_V11_HYDRA_PLUS => { + let event_format: FormattedEvent = depythonize(event_dict)?; + event_format.into() + } + EventFormatVersions::ROOM_VMSC4242 => { + let event_format: FormattedEvent = depythonize(event_dict)?; + event_format.into() + 
 } + _ => { + return Err(PyValueError::new_err(format!( + "Unsupported room version: {}", + room_version + ))) + } + }; + + formatted_event.validate()?; + + Ok(formatted_event) +} + +/// Converts an event dict as [`serde_json::Value`] into a [`FormattedEvent`]. +/// +/// The concrete event format to deserialize into is selected from +/// `room_version.event_format`, and the parsed event is then checked with +/// `validate()`. An event format version not handled below is reported as an +/// error rather than a panic. +fn event_dict_from_json_value( + room_version: &RoomVersion, + event_dict: serde_json::Value, +) -> Result { + let formatted_event: FormattedEvent = match room_version.event_format { + EventFormatVersions::ROOM_V1_V2 => { + let event_format: FormattedEvent = serde_json::from_value(event_dict)?; + + event_format.into() + } + EventFormatVersions::ROOM_V3 | EventFormatVersions::ROOM_V4_PLUS => { + let event_format: FormattedEvent = serde_json::from_value(event_dict)?; + event_format.into() + } + EventFormatVersions::ROOM_V11_HYDRA_PLUS => { + let event_format: FormattedEvent = serde_json::from_value(event_dict)?; + event_format.into() + } + EventFormatVersions::ROOM_VMSC4242 => { + let event_format: FormattedEvent = + serde_json::from_value(event_dict)?; + event_format.into() + } + _ => { + return Err(anyhow::anyhow!( + "Unsupported room version: {}", + room_version + )); + } + }; + + formatted_event.validate()?; + + Ok(formatted_event) +} + +/// Returns a pruned version of the given event, which removes all keys we don't +/// know about or think could potentially be dodgy. +/// +/// Returns the redacted event as a new [`Event`] that keeps the original's +/// event ID, room ID, room version and rejected reason. Its internal metadata +/// is a deep copy marked as redacted, so the event passed in is not modified. 
+#[pyfunction(name = "redact_event")] +fn redact_event_py(event: &Event) -> PyResult { + let event_value = serde_json::to_value(&event.parsed_event).map_err(|err| { + PyValueError::new_err(format!("Failed to serialize event for redaction: {}", err)) + })?; + + let redacted_value = redact(&event_value, event.room_version)?; + let redacted_formatted_event = event_dict_from_json_value(event.room_version, redacted_value) + .map_err(|err| { + PyValueError::new_err(format!("Failed to deserialize redacted event: {}", err)) + })?; + + let redacted_event = Event { + parsed_event: redacted_formatted_event, + event_id: Arc::clone(&event.event_id), + room_id: Arc::clone(&event.room_id), + room_version: event.room_version, + rejected_reason: event.rejected_reason.clone(), + internal_metadata: event.internal_metadata.deep_copy()?, + }; + + // Mark event as redacted + redacted_event.internal_metadata.set_redacted(true)?; + + Ok(redacted_event) +} + +/// Returns a pruned version of the given event dict, which removes all keys we +/// don't know about or think could potentially be dodgy. +/// +/// Returns the redacted event as a dict. 
+#[pyfunction(name = "redact_event_dict")] +fn redact_event_dict<'py>( + py: Python<'py>, + room_version: &RoomVersion, + event_dict: &'py Bound<'py, PyAny>, +) -> PyResult> { + let event_value = depythonize(event_dict)?; + + let redacted = redact(&event_value, room_version)?; + + let redacted_py = pythonize(py, &redacted)?; + + Ok(redacted_py) +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::events::{ + constants::event_type::M_ROOM_CREATE, + formats::{EventFormatV1, EventFormatV2V3, EventFormatV4, EventFormatVMSC4242}, + }; + + #[test] + fn test_basic_v3_roundtrip() { + let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@anon-20260225_142731-20:localhost:8800","content":{"room_version":"10","creator":"@anon-20260225_142731-20:localhost:8800"},"depth":1,"room_id":"!qVoJSympOqdUQRUfiC:localhost:8800","state_key":"","origin_server_ts":1772029657149,"hashes":{"sha256":"RIDkn4CrExGMOfRZlHl//1weAro5QC/q2D76YcyAUqk"},"signatures":{"localhost:8800":{"ed25519:a_GMSl":"GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg"}},"unsigned":{"age_ts":1772029657149}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + let parsed_value = serde_json::to_value(&event).unwrap(); + + // Check a couple of fields are as expected as a sanity check. 
+ assert_eq!(&*event.common_fields.type_, M_ROOM_CREATE); + assert_eq!( + &*event.specific_fields.room_id, + "!qVoJSympOqdUQRUfiC:localhost:8800" + ); + + assert_eq!(event_value, parsed_value); + } + + #[test] + fn test_room_id_for_create_event_format_v4() { + let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@erikj:jki.re","content":{"room_version":"12","predecessor":{"room_id":"!VuNGkDTdbMOOxSmuDa:jki.re"}},"depth":1,"state_key":"","origin_server_ts":1775568141481,"hashes":{"sha256":"qBX+glsKvogXFrvsEN0eh13pO2kpuE6o/b4yREPtOqw"},"signatures":{"jki.re":{"ed25519:auto":"n/4gHQRagk3r1r24L/7a+oaMMf9cysVfQRYdjpDZcf4ppkVym33rhTW18Vy4zMa1L5nsWLkxsBvbrRRDYUOhBQ"}},"unsigned":{"age_ts":1775568141481}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + + let event_id = calculate_event_id(&event_value, &RoomVersion::V12).unwrap(); + + assert_eq!( + &*event + .specific_fields + .room_id(&event_id, &event.common_fields) + .unwrap(), + "!BeXKh925K_M46DwsuJFR0EyBpE1P7CFUDGuWW4xw55Y" + ); + } + + #[test] + fn test_basic_v1_roundtrip() { + let json = r#"{"auth_events":[["$auth1:localhost",{"sha256":"abc"}],["$auth2:localhost",{"sha256":"def"}]],"prev_events":[["$prev1:localhost",{"sha256":"ghi"}]],"type":"m.room.message","sender":"@user:localhost","content":{"body":"hello","msgtype":"m.text"},"depth":5,"room_id":"!room:localhost","event_id":"$event1:localhost","origin_server_ts":1234567890,"hashes":{"sha256":"base64hash"},"signatures":{"localhost":{"ed25519:key":"sig"}},"unsigned":{}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + let parsed_value = serde_json::to_value(&event).unwrap(); + + // Check a few fields are as expected as a sanity check. 
+ assert_eq!(&*event.common_fields.type_, "m.room.message"); + assert!(event.common_fields.state_key.is_absent()); + assert_eq!(&*event.specific_fields.room_id, "!room:localhost"); + assert_eq!(&*event.specific_fields.event_id, "$event1:localhost"); + + // Check auth/prev event extraction + let auth_ids = event.specific_fields.auth_event_ids(); + assert_eq!(auth_ids, vec!["$auth1:localhost", "$auth2:localhost"]); + + let prev_ids = event.specific_fields.prev_event_ids(); + assert_eq!(prev_ids, vec!["$prev1:localhost"]); + + assert_eq!(event_value, parsed_value); + } + + #[test] + fn test_basic_v4_roundtrip_with_room_id() { + // A regular (non-create) V4 event has an explicit room_id. + let json = r#"{"auth_events":["$auth1","$auth2"],"prev_events":["$prev1"],"type":"m.room.message","sender":"@user:localhost","content":{"body":"hello","msgtype":"m.text"},"depth":5,"room_id":"!room:localhost","origin_server_ts":1234567890,"hashes":{"sha256":"base64hash"},"signatures":{"localhost":{"ed25519:key":"sig"}},"unsigned":{}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + let parsed_value = serde_json::to_value(&event).unwrap(); + + // Check a few fields are as expected as a sanity check. + assert_eq!(&*event.common_fields.type_, "m.room.message"); + assert_eq!( + event.specific_fields.room_id.as_deref_opt(), + Some("!room:localhost") + ); + assert_eq!( + event.specific_fields.auth_events, + vec!["$auth1".to_string(), "$auth2".to_string()] + ); + assert_eq!( + event.specific_fields.prev_events, + vec!["$prev1".to_string()] + ); + + assert_eq!(event_value, parsed_value); + } + + #[test] + fn test_basic_v4_roundtrip_create_event() { + // A V4 create event for a v12 room has no room_id field. 
+ let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@erikj:jki.re","content":{"room_version":"12"},"depth":1,"state_key":"","origin_server_ts":1775568141481,"hashes":{"sha256":"qBX+glsKvogXFrvsEN0eh13pO2kpuE6o/b4yREPtOqw"},"signatures":{"jki.re":{"ed25519:auto":"sig"}},"unsigned":{}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + let parsed_value = serde_json::to_value(&event).unwrap(); + + // Check a few fields are as expected as a sanity check. + assert!(event.specific_fields.room_id.is_absent()); + assert_eq!(&*event.common_fields.type_, M_ROOM_CREATE); + + // Create events have no implicit auth events. + assert!(event + .specific_fields + .auth_event_ids(&event.common_fields) + .unwrap() + .is_empty()); + + assert_eq!(event_value, parsed_value); + } + + #[test] + fn test_v4_auth_event_ids_implicit_create() { + // Non-create events implicitly include the create event (derived from + // the room ID) in their auth chain. + let json = r#"{"auth_events":["$auth1"],"prev_events":["$prev1"],"type":"m.room.message","sender":"@user:localhost","content":{"body":"hi","msgtype":"m.text"},"depth":5,"room_id":"!BeXKh925K_M46DwsuJFR0EyBpE1P7CFUDGuWW4xw55Y","origin_server_ts":1234567890,"hashes":{"sha256":"h"},"signatures":{"localhost":{"ed25519:k":"s"}},"unsigned":{}}"#; + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + + let auth_ids = event + .specific_fields + .auth_event_ids(&event.common_fields) + .unwrap(); + assert_eq!( + auth_ids, + vec![ + "$auth1".to_string(), + "$BeXKh925K_M46DwsuJFR0EyBpE1P7CFUDGuWW4xw55Y".to_string(), + ] + ); + } + + #[test] + fn test_v4_validate_rejects_missing_room_id_for_non_create() { + // A v12 non-create event without a room_id must fail validation. 
+ let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.message","sender":"@u:l","content":{},"depth":2,"state_key":"","origin_server_ts":1,"hashes":{"sha256":"h"},"signatures":{"l":{"ed25519:k":"s"}},"unsigned":{}}"#; + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + assert!(event + .specific_fields + .validate(&event.common_fields) + .is_err()); + } + + #[test] + fn test_v4_validate_accepts_create_without_room_id() { + let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@u:l","content":{"room_version":"12"},"depth":1,"state_key":"","origin_server_ts":1,"hashes":{"sha256":"h"},"signatures":{"l":{"ed25519:k":"s"}},"unsigned":{}}"#; + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + event + .specific_fields + .validate(&event.common_fields) + .unwrap(); + } + + #[test] + fn test_basic_vmsc4242_roundtrip() { + // VMSC4242 introduces a `prev_state_events` field on top of V4. + let json = r#"{"auth_events":["$auth1"],"prev_events":["$prev1"],"prev_state_events":["$pstate1","$pstate2"],"type":"m.room.member","sender":"@user:localhost","content":{"membership":"join"},"depth":5,"room_id":"!room:localhost","state_key":"@user:localhost","origin_server_ts":1234567890,"hashes":{"sha256":"h"},"signatures":{"localhost":{"ed25519:k":"s"}},"unsigned":{}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + let parsed_value = serde_json::to_value(&event).unwrap(); + + assert_eq!( + event.specific_fields.prev_state_events, + vec!["$pstate1".to_string(), "$pstate2".to_string()] + ); + assert_eq!( + event.specific_fields.room_id.as_deref_opt(), + Some("!room:localhost") + ); + assert_eq!( + event.common_fields.state_key.as_deref_opt(), + Some("@user:localhost") + ); + + assert_eq!(event_value, parsed_value); + } + + #[test] + fn test_vmsc4242_room_id_for_create_event() { + let json = 
r#"{"auth_events":[],"prev_events":[],"prev_state_events":[],"type":"m.room.create","sender":"@erikj:jki.re","content":{"room_version":"12","predecessor":{"room_id":"!VuNGkDTdbMOOxSmuDa:jki.re"}},"depth":1,"state_key":"","origin_server_ts":1775568141481,"hashes":{"sha256":"qBX+glsKvogXFrvsEN0eh13pO2kpuE6o/b4yREPtOqw"},"signatures":{"jki.re":{"ed25519:auto":"n/4gHQRagk3r1r24L/7a+oaMMf9cysVfQRYdjpDZcf4ppkVym33rhTW18Vy4zMa1L5nsWLkxsBvbrRRDYUOhBQ"}},"unsigned":{"age_ts":1775568141481}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + + // The event_id calculation is independent of the `prev_state_events` + // field not being present in V4, so the same event_id derivation works. + let event_id = calculate_event_id(&event_value, &RoomVersion::V12).unwrap(); + + assert_eq!( + &*event + .specific_fields + .room_id(&event_id, &event.common_fields) + .unwrap(), + "!BeXKh925K_M46DwsuJFR0EyBpE1P7CFUDGuWW4xw55Y" + ); + } + + #[test] + fn test_event_format_enum_untagged_roundtrip() { + // The untagged EventFormatEnum serialization/deserialization is + // driven by fields, so serializing any variant must match the + // original JSON exactly. 
+ let v2v3_json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@a:b","content":{},"depth":1,"room_id":"!r:b","state_key":"","origin_server_ts":1,"hashes":{"sha256":"h"},"signatures":{"b":{"ed25519:k":"s"}},"unsigned":{}}"#; + let v2v3_value: serde_json::Value = serde_json::from_str(v2v3_json).unwrap(); + let v2v3_container: FormattedEvent = + serde_json::from_str(v2v3_json).unwrap(); + assert_eq!(serde_json::to_value(&v2v3_container).unwrap(), v2v3_value); + assert_eq!( + serde_json::to_value(v2v3_container.into_general()).unwrap(), + v2v3_value + ); + + let v4_json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@a:b","content":{"room_version":"12"},"depth":1,"state_key":"","origin_server_ts":1,"hashes":{"sha256":"h"},"signatures":{"b":{"ed25519:k":"s"}},"unsigned":{}}"#; + let v4_value: serde_json::Value = serde_json::from_str(v4_json).unwrap(); + let v4_container: FormattedEvent = serde_json::from_str(v4_json).unwrap(); + assert_eq!(serde_json::to_value(&v4_container).unwrap(), v4_value); + assert_eq!( + serde_json::to_value(v4_container.into_general()).unwrap(), + v4_value + ); + } + + #[test] + fn test_unknown_top_level_fields_preserved_roundtrip() { + // Extra top-level fields (e.g. unknown or experimental) are captured + // via `other_fields` and must round-trip losslessly. 
+ let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.message","sender":"@a:b","content":{"body":"hi","msgtype":"m.text"},"depth":1,"room_id":"!r:b","origin_server_ts":1,"hashes":{"sha256":"h"},"signatures":{"b":{"ed25519:k":"s"}},"unsigned":{},"msc4354_sticky":{"duration_ms":5000},"some_unknown_field":"some_value"}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: FormattedEvent = serde_json::from_str(json).unwrap(); + let parsed_value = serde_json::to_value(&event).unwrap(); + + assert!(event + .common_fields + .other_fields + .contains_key(MSC4354_STICKY)); + assert!(event + .common_fields + .other_fields + .contains_key("some_unknown_field")); + + assert_eq!(event_value, parsed_value); + } +} diff --git a/rust/src/events/signatures.rs b/rust/src/events/signatures.rs index 0f2acd5c9b..fa07dd056b 100644 --- a/rust/src/events/signatures.rs +++ b/rust/src/events/signatures.rs @@ -30,12 +30,25 @@ use serde::{Deserialize, Serialize}; /// A class representing the signatures on an event. #[pyclass(frozen, skip_from_py_object)] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[serde(transparent)] pub struct Signatures { inner: Arc>>>, } +impl Signatures { + /// Create a deep copy of this `Signatures` to allow modification without + /// affecting other references to the same signatures. This is needed when + /// we clone an event. 
+ pub fn deep_copy(&self) -> Self { + let signatures = self.inner.read().expect("lock poisoned").clone(); // Deep copy the inner map + + Self { + inner: Arc::new(RwLock::new(signatures)), + } + } +} + #[pymethods] impl Signatures { #[new] diff --git a/rust/src/events/unsigned.rs b/rust/src/events/unsigned.rs index 5aa56812c0..0bb644ee90 100644 --- a/rust/src/events/unsigned.rs +++ b/rust/src/events/unsigned.rs @@ -101,6 +101,15 @@ impl Unsigned { .write() .map_err(|_| PyRuntimeError::new_err("Unsigned lock poisoned")) } + + /// Create a deep copy of this `Unsigned` to allow modification without + /// affecting other references to the same unsigned data. This is needed + /// when we clone an event. + pub fn deep_copy(&self) -> Self { + Self { + inner: Arc::new(RwLock::new(self.py_read().expect("lock poisoned").clone())), + } + } } #[pymethods] @@ -268,11 +277,11 @@ impl Unsigned { } } - fn for_persistence<'py>(&self, py: Python<'py>) -> PyResult> { + pub fn for_persistence<'py>(&self, py: Python<'py>) -> PyResult> { Ok(pythonize(py, &self.py_read()?.persisted_fields)?) } - fn for_event<'py>(&self, py: Python<'py>) -> PyResult> { + pub fn for_event<'py>(&self, py: Python<'py>) -> PyResult> { Ok(pythonize(py, &*self.py_read()?)?) } } diff --git a/rust/src/events/utils.rs b/rust/src/events/utils.rs new file mode 100644 index 0000000000..7d33c52b44 --- /dev/null +++ b/rust/src/events/utils.rs @@ -0,0 +1,1191 @@ +//! JSON +//! +//! Matrix event JSON utility functions. 
+ +use std::collections::BTreeSet; + +use anyhow::Context as _; +use base64::Engine as _; +use serde_json::Value; +use sha2::{Digest, Sha256}; + +use super::constants::{ + aliases_field, create_field, + event_field::{ + AUTH_EVENTS, CONTENT, DEPTH, EVENT_ID, HASHES, MEMBERSHIP, ORIGIN, ORIGIN_SERVER_TS, + PREV_EVENTS, PREV_STATE, REPLACES_STATE, ROOM_ID, SENDER, SIGNATURES, STATE_KEY, TYPE, + UNSIGNED, + }, + event_type::{ + M_ROOM_ALIASES, M_ROOM_CREATE, M_ROOM_HISTORY_VISIBILITY, M_ROOM_JOIN_RULES, M_ROOM_MEMBER, + M_ROOM_POWER_LEVELS, M_ROOM_REDACTION, + }, + history_visibility_field, join_rules_field, membership_field, power_levels_field, + redaction_field, + unsigned_field::AGE_TS, +}; +use crate::{ + canonical_json::CanonicalizationOptions, events::constants::event_field::PREV_STATE_EVENTS, +}; +use crate::{ + events::constants::event_field::M_RELATES_TO, + room_versions::{EventFormatVersions, RoomVersion}, +}; + +/// Calculates the event_id of an event. +/// +/// The event_id is the `reference_hash` of the redacted event json, preceded by a `$`. +/// `calculate_event_id` can be used to determine the `event_id` for events in room versions V3+. +pub fn calculate_event_id(event: &Value, room_version: &RoomVersion) -> anyhow::Result> { + match room_version.event_format { + EventFormatVersions::ROOM_V1_V2 => { + anyhow::bail!( + "Attempted to calculate event_id using reference hash for room version v1/v2" + ); + } + EventFormatVersions::ROOM_V3 + | EventFormatVersions::ROOM_V4_PLUS + | EventFormatVersions::ROOM_V11_HYDRA_PLUS + | EventFormatVersions::ROOM_VMSC4242 => { + let reference_hash = compute_event_reference_hash(event, room_version)?; + + Ok(format!("${reference_hash}").into_boxed_str()) + } + _ => { + unimplemented!( + "Unknown event format version {}. This is a Synapse Programming error.", + room_version.event_format + ); + } + } +} + +/// Computes the event reference hash. This is the hash of the redacted event. 
+pub fn compute_event_reference_hash( + event: &Value, + room_version: &RoomVersion, +) -> anyhow::Result { + let mut redacted_value = redact(event, room_version)?; + + let redacted_value_mut = redacted_value + .as_object_mut() + .context("Failed getting `redacted_value` as mutable object")?; + + redacted_value_mut.remove(SIGNATURES); + redacted_value_mut.remove(UNSIGNED); + redacted_value_mut.remove(AGE_TS); + + let canonicalization_options = if room_version.strict_canonicaljson { + CanonicalizationOptions::strict() + } else { + CanonicalizationOptions::relaxed() + }; + + let json = + crate::canonical_json::to_string_canonical(&redacted_value_mut, canonicalization_options) + .map_err(|err| anyhow::anyhow!(err))?; + + let hash = Sha256::digest(json.as_bytes()); + + let base64_alphabet = if room_version.event_format == EventFormatVersions::ROOM_V3 { + base64::alphabet::STANDARD + } else { + base64::alphabet::URL_SAFE + }; + let base64_engine = base64::engine::GeneralPurpose::new( + &base64_alphabet, + base64::engine::general_purpose::NO_PAD, + ); + + Ok(base64_engine.encode(hash)) +} + +/// Attempts to redact the provided event, returning a copy of the redacted +/// event if successful. +/// +/// Events redacted with this function are meant to be sent over federation. +pub fn redact(event: &Value, room_version: &RoomVersion) -> anyhow::Result { + let mut allowed_keys = BTreeSet::from([ + (EVENT_ID), + (SENDER), + (ROOM_ID), + (HASHES), + (SIGNATURES), + (CONTENT), + (TYPE), + (STATE_KEY), + (DEPTH), + (PREV_EVENTS), + (ORIGIN_SERVER_TS), + ]); + + // Earlier room versions had additional allowed keys + if !room_version.updated_redaction_rules { + allowed_keys.extend([PREV_STATE, MEMBERSHIP, ORIGIN]); + } + + // Room versions with MSC4242 have `prev_state_events` instead of + // `auth_events`. 
+ if room_version.msc4242_state_dags { + allowed_keys.insert(PREV_STATE_EVENTS); + } else { + allowed_keys.insert(AUTH_EVENTS); + } + + let event_type = event + .get(TYPE) + .with_context(|| format!("Missing {TYPE} field in json"))? + .as_str() + .with_context(|| format!("{TYPE} field is not a string"))?; + + let event_content = event + .get(CONTENT) + .with_context(|| format!("Missing {CONTENT} field in json"))?; + + let mut new_content = serde_json::json!({}); + let new_content_mut = new_content + .as_object_mut() + .context("Failed getting `new_content` as mutable object")?; + + let mut add_content_field = |field: &str| { + if let Some(existing_field) = event_content.get(field) { + new_content_mut.insert(field.to_string(), existing_field.clone()); + } + }; + + match event_type { + M_ROOM_MEMBER => { + add_content_field(membership_field::MEMBERSHIP); + if room_version.restricted_join_rule_fix { + add_content_field(membership_field::JOIN_AUTHORISED_VIA_USERS_SERVER); + } + if room_version.updated_redaction_rules { + // Preserve the signed field under third_party_invite. + if let Some(third_party_invite) = + event_content.get(membership_field::THIRD_PARTY_INVITE) + { + if third_party_invite.as_object().is_some() { + let mut new_third_party_invite = serde_json::json!({}); + if let Some(signed) = third_party_invite.get(membership_field::SIGNED) { + new_third_party_invite = + serde_json::json!({membership_field::SIGNED: signed.clone()}); + } + new_content_mut.insert( + membership_field::THIRD_PARTY_INVITE.to_string(), + new_third_party_invite, + ); + } + } + } + } + M_ROOM_CREATE => { + if room_version.updated_redaction_rules { + // MSC2176 rules state that create events cannot have their `content` redacted. 
+ if let Some(event_content_object) = event_content.as_object() { + for (field, _value) in event_content_object { + add_content_field(field); + } + } + } + if !room_version.implicit_room_creator { + // Some room versions give meaning to `creator` + add_content_field(create_field::CREATOR); + } + if room_version.msc4291_room_ids_as_hashes { + // room_id is not allowed on the create event as it's derived from the event ID + allowed_keys.remove(ROOM_ID); + } + } + M_ROOM_JOIN_RULES => { + add_content_field(join_rules_field::JOIN_RULE); + if room_version.restricted_join_rule { + add_content_field(join_rules_field::ALLOW); + } + } + M_ROOM_POWER_LEVELS => { + add_content_field(power_levels_field::USERS); + add_content_field(power_levels_field::USERS_DEFAULT); + add_content_field(power_levels_field::EVENTS); + add_content_field(power_levels_field::EVENTS_DEFAULT); + add_content_field(power_levels_field::STATE_DEFAULT); + add_content_field(power_levels_field::BAN); + add_content_field(power_levels_field::KICK); + add_content_field(power_levels_field::REDACT); + if room_version.updated_redaction_rules { + add_content_field(power_levels_field::INVITE); + } + } + M_ROOM_ALIASES if room_version.special_case_aliases_auth => { + add_content_field(aliases_field::ALIASES); + } + M_ROOM_HISTORY_VISIBILITY => { + add_content_field(history_visibility_field::HISTORY_VISIBILITY) + } + M_ROOM_REDACTION if room_version.updated_redaction_rules => { + add_content_field(redaction_field::REDACTS); + } + _ => (), + }; + + let mut allowed_fields = serde_json::json!({}); + let allowed_fields_mut = allowed_fields + .as_object_mut() + .context("Failed getting `allowed_fields` as mutable object")?; + + for (k, v) in event + .as_object() + .context("Event is not a JSON object")? 
+ .iter() + { + if allowed_keys.contains(&k.as_str()) { + allowed_fields_mut.insert(k.clone(), v.clone()); + } + } + + if room_version.msc3389_relation_redactions { + if let Some(relates_to) = event + .get(CONTENT) + .and_then(|content| content.get(M_RELATES_TO)) + { + if relates_to.is_object() { + let mut new_relates_to = serde_json::json!({}); + let new_relates_to_mut = new_relates_to + .as_object_mut() + .context("Failed getting `new_relates_to` as mutable object")?; + + for field in ["rel_type", "event_id"] { + if let Some(value) = relates_to.get(field) { + new_relates_to_mut.insert(field.to_string(), value.clone()); + } + } + + if !new_relates_to_mut.is_empty() { + new_content_mut.insert(M_RELATES_TO.to_string(), new_relates_to); + } + } + } + } + + allowed_fields_mut.insert(CONTENT.to_string(), new_content); + + // Copy over known good unsigned keys + let allowed_unsigned_keys = [AGE_TS, REPLACES_STATE]; + if let Some(unsigned) = event.get(UNSIGNED) { + let mut new_unsigned = serde_json::Map::new(); + for key in allowed_unsigned_keys { + if let Some(value) = unsigned.get(key) { + new_unsigned.insert(key.to_string(), value.clone()); + } + } + allowed_fields_mut.insert(UNSIGNED.to_string(), Value::Object(new_unsigned)); + } + + Ok(allowed_fields) +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::{calculate_event_id, redact}; + use crate::room_versions::RoomVersion; + + #[test] + fn test_calculate_event_id() { + let original = json!( + { + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + "type":"m.room.message", + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "content":{ + "msgtype":"m.text", + "body":"invited people can see history", + "m.mentions":{} + }, + "depth":24, + "origin":"localhost:8008", + 
"origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + }, + "unsigned":{ + "age_ts":1731769874137_i64 + } + } + ); + let redacted = json!( + { + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + "type":"m.room.message", + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "content":{}, + "depth":24, + "origin":"localhost:8008", + "origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + } + } + ); + + let expected = "$zRz9jjiT9wZc3Hl9ij_74aCmTjqV3YMlj9sj3Uqxg6o"; + let expected_v3 = "$zRz9jjiT9wZc3Hl9ij/74aCmTjqV3YMlj9sj3Uqxg6o"; + + let original_event_id = calculate_event_id(&original, &RoomVersion::V10).unwrap(); + let redacted_event_id = calculate_event_id(&redacted, &RoomVersion::V10).unwrap(); + let _ = calculate_event_id(&original, &RoomVersion::V2).unwrap_err(); + let v3_event_id = calculate_event_id(&original, &RoomVersion::V3).unwrap(); + + assert_eq!(expected, &*original_event_id); + assert_eq!(expected, &*redacted_event_id); + assert_eq!(expected_v3, &*v3_event_id); + } + + #[test] + /// Tests to ensure events with overly large values for `depth` are handled appropriately. + /// This was added in room version 6 . 
+ fn test_calculate_event_id_big_int_old_rooms() { + let original = json!( + { + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + "type":"m.room.message", + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "content":{ + "msgtype":"m.text", + "body":"invited people can see history", + "m.mentions":{} + }, + // NOTE: use the biggest acceptable number + "depth":u64::MAX, + "origin":"localhost:8008", + "origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + }, + "unsigned":{ + "age_ts":1731769874137_i64 + } + } + ); + + // These should succeed. + let _event_id = calculate_event_id(&original, &RoomVersion::V3).unwrap(); + let _event_id = calculate_event_id(&original, &RoomVersion::V4).unwrap(); + let _event_id = calculate_event_id(&original, &RoomVersion::V5).unwrap(); + + // These should not succeed. + let versions = [ + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + RoomVersion::V11, + RoomVersion::V12, + ]; + for version in versions { + let _event_id = calculate_event_id(&original, &version).unwrap_err(); + } + } + + #[test] + /// Tests that calling `redact` on invalid event json that is missing the `type` property + /// fails. + /// The `type` field is required by all versions of the spec. Any json encountered where + /// this field is missing must be considered invalid. 
+ fn test_redact_missing_type() { + let original = json!( + { + // "type": "missing_type" + "unknown_key":"unknown_value", + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "content":{ + "msgtype":"m.text", + "body":"invited people can see history", + "m.mentions":{} + }, + "depth":24, + "origin":"localhost:8008", + "origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + }, + "unsigned":{ + "age_ts":1731769874137_i64 + } + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + RoomVersion::V11, + RoomVersion::V12, + ]; + for version in versions { + let _ = redact(&original, &version).unwrap_err(); + } + } + + #[test] + /// Tests that calling `redact` on invalid event json that is missing the `content` property + /// fails. + /// The `content` field is required by all versions of the spec. Any json encountered where + /// this field is missing must be considered invalid. 
+ fn test_redact_missing_content() { + let original = json!( + { + "unknown_key":"unknown_value", + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + "type":"m.room.message", + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "depth":24, + "origin":"localhost:8008", + "origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + }, + "unsigned":{ + "age_ts":1731769874137_i64 + } + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + RoomVersion::V11, + RoomVersion::V12, + ]; + for version in versions { + let _ = redact(&original, &version).unwrap_err(); + } + } + + #[test] + /// Tests redaction logic for `m.room.message` events against latest room versions. + /// This is only testing v10+ as it is rather cumbersome to create the proper json for all + /// rooms versions. 
+ fn test_redact_m_room_message() { + let original = json!( + { + "unknown_key":"unknown_value", + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + "type":"m.room.message", + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "content":{ + "msgtype":"m.text", + "body":"invited people can see history", + "m.mentions":{} + }, + "depth":24, + "origin":"localhost:8008", + "origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + }, + "unsigned":{ + "age_ts":1731769874137_i64 + } + } + ); + let expected = json!( + { + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + "type":"m.room.message", + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "content":{}, + "depth":24, + "origin":"localhost:8008", + "origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + }, + "unsigned":{ + "age_ts":1731769874137_i64 + } + } + ); + let expected_v11 = json!( + { + "auth_events":[ + "$gbHO7IPUHybc7ULFnT7P0r3iWlZFHGmr6zBfEYCUyKw", + "$hy1eZFYcgNxMFNNBgCD5fyOzyWRcBkxfNcrUI5ZpZlE", + "$Nt_z68EwFfPqBeHjzEHGsp461Z4EfNEzR-KH5bOYdOY" + ], + "prev_events":[ + "$4FbpZrgPQTwoLD9H5y7jcikucCypUOn78mXhQX7WliY" + ], + 
"type":"m.room.message", + "room_id":"!DHmIIVvxFASSFgDGzr:localhost:8008", + "sender":"@tester_b:localhost:8008", + "content":{}, + "depth":24, + "origin_server_ts":1731769874137_i64, + "hashes":{ + "sha256":"FoYV1w3TW/B2mVT0gX2/BZKpCwrrvGXqXFdUhN9LZYU" + }, + "signatures":{ + "localhost:8008":{ + "ed25519:a_phSE":"G8cfk/m97sndxMNrEZ2nMMSXkVeJE05G7if4JiVzAwGfD3TwnF/jfSHt2acWrpNqv/aEhZug3WLofc2id+rVBw" + } + }, + "unsigned":{ + "age_ts":1731769874137_i64 + } + } + ); + + let redacted = redact(&original, &RoomVersion::V10).unwrap(); + + assert_eq!(expected, redacted); + + let versions = [RoomVersion::V11, RoomVersion::V12]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + + assert_eq!(expected_v11, redacted, "Room Version {version}"); + } + } + + #[test] + /// Tests redaction logic adheres to changes due to [super::use_updated_redaction_rules]. + /// Tests against all known room versions. + fn test_redact_updated_redaction_rules() { + let original = json!( + { + "type":"m.room.message", + "prev_state":{}, + "membership":{}, + "origin":"some_place", + "content":{}, + } + ); + + let expected_updated_redaction_rules = json!( + { + "type":"m.room.message", + "content":{}, + } + ); + let expected_pre_updated_redaction_rules = json!( + { + "type":"m.room.message", + "prev_state":{}, + "membership":{}, + "origin":"some_place", + "content":{}, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + ]; + for version in versions { + let redacted_1 = redact(&original, &version).unwrap(); + assert_eq!( + expected_pre_updated_redaction_rules, redacted_1, + "Room Version {version}" + ); + } + + let versions = [RoomVersion::V11, RoomVersion::V12]; + for version in versions { + let redacted_2 = redact(&original, &version).unwrap(); + assert_eq!( + expected_updated_redaction_rules, 
redacted_2, + "Room Version {version}" + ); + } + } + + #[test] + /// Tests redaction logic for `m.room.member` events against all known room versions. + fn test_redact_m_room_member() { + let original = json!( + { + "type":"m.room.member", + "content":{ + "unknown_key":"unknown_value", + "membership":"join", + "join_authorised_via_users_server":"server", + "third_party_invite":{ + "signed":{}, + }, + }, + } + ); + + let expected_pre_unrestricted_join_rule_fix = json!( + { + "type":"m.room.member", + "content":{ + "membership":"join", + }, + } + ); + + let expected_updated_redaction_rules = json!( + { + "type":"m.room.member", + "content":{ + "membership":"join", + "join_authorised_via_users_server":"server", + "third_party_invite":{ + "signed":{}, + }, + }, + } + ); + let expected_pre_updated_redaction_rules = json!( + { + "type":"m.room.member", + "content":{ + "membership":"join", + "join_authorised_via_users_server":"server", + }, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_pre_unrestricted_join_rule_fix, redacted, + "Room Version {version}" + ); + } + + let versions = [RoomVersion::V9, RoomVersion::V10]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_pre_updated_redaction_rules, redacted, + "Room Version {version}" + ); + } + + let versions = [RoomVersion::V11, RoomVersion::V12]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_updated_redaction_rules, redacted, + "Room Version {version}" + ); + } + } + + /// Tests redaction logic for `m.room.create` events against all known room versions. 
+ #[test] + fn test_redact_m_room_create() { + let original = json!( + { + "type":"m.room.create", + "room_id": "!roomid", + "content":{ + "unknown_key":"unknown_value", + "other_key":"value", + "creator":"user", + }, + } + ); + + let expected_implicit_room_creator = json!( + { + "type":"m.room.create", + "room_id": "!roomid", + "content":{ + "unknown_key":"unknown_value", + "other_key":"value", + "creator":"user", + }, + } + ); + let expected_room_ids_as_hashes = json!( + { + "type":"m.room.create", + "content":{ + "unknown_key":"unknown_value", + "other_key":"value", + "creator":"user", + }, + } + ); + let expected_pre_implicit_room_creator = json!( + { + "type":"m.room.create", + "room_id": "!roomid", + "content":{ + "creator":"user", + }, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_pre_implicit_room_creator, redacted, + "Room Version {version}" + ); + } + + let redacted = redact(&original, &RoomVersion::V11).unwrap(); + assert_eq!(expected_implicit_room_creator, redacted); + + let versions = [RoomVersion::HYDRA_V11, RoomVersion::V12]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_room_ids_as_hashes, redacted, + "Room Version {version}" + ); + } + } + + #[test] + /// Tests redaction logic for `m.room.join_rules` events against all known room versions. 
+ fn test_redact_m_room_join_rules() { + let original = json!( + { + "type":"m.room.join_rules", + "content":{ + "unknown_key":"unknown_value", + "join_rule":"invite", + "allow":"user", + }, + } + ); + + let expected_restricted_join_rule = json!( + { + "type":"m.room.join_rules", + "content":{ + "join_rule":"invite", + "allow":"user", + }, + } + ); + let expected_pre_restricted_join_rule = json!( + { + "type":"m.room.join_rules", + "content":{ + "join_rule":"invite", + }, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_pre_restricted_join_rule, redacted, + "Room Version {version}" + ); + } + + let versions = [ + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + RoomVersion::V11, + RoomVersion::V12, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_restricted_join_rule, redacted, + "Room Version {version}" + ); + } + } + + #[test] + /// Tests redaction logic for `m.room.power_levels` events against all known room versions. 
+ fn test_redact_m_room_power_levels() { + let original = json!( + { + "type":"m.room.power_levels", + "content":{ + "unknown_key":"unknown_value", + "users":{}, + "users_default":{}, + "events":{}, + "events_default":{}, + "state_default":{}, + "ban":{}, + "kick":{}, + "redact":{}, + "invite":{}, + }, + } + ); + + let expected_updated_redaction_rules = json!( + { + "type":"m.room.power_levels", + "content":{ + "users":{}, + "users_default":{}, + "events":{}, + "events_default":{}, + "state_default":{}, + "ban":{}, + "kick":{}, + "redact":{}, + "invite":{}, + }, + } + ); + let expected_pre_updated_redaction_rules = json!( + { + "type":"m.room.power_levels", + "content":{ + "users":{}, + "users_default":{}, + "events":{}, + "events_default":{}, + "state_default":{}, + "ban":{}, + "kick":{}, + "redact":{}, + }, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_pre_updated_redaction_rules, redacted, + "Room Version {version}" + ); + } + + let versions = [RoomVersion::V11, RoomVersion::V12]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_updated_redaction_rules, redacted, + "Room Version {version}" + ); + } + } + + #[test] + /// Tests redaction logic for `m.room.aliases` events against all known room versions. 
+ fn test_redact_m_room_aliases() { + let original = json!( + { + "type":"m.room.aliases", + "content":{ + "unknown_key":"unknown_value", + "aliases":{}, + }, + } + ); + + let expected_special_case_aliases = json!( + { + "type":"m.room.aliases", + "content":{ + "aliases":{}, + }, + } + ); + let expected_post_special_case_aliases = json!( + { + "type":"m.room.aliases", + "content":{}, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_special_case_aliases, redacted, + "Room Version {version}" + ); + } + + let versions = [ + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + RoomVersion::V11, + RoomVersion::V12, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_post_special_case_aliases, redacted, + "Room Version {version}" + ); + } + } + + #[test] + /// Tests redaction logic for `m.room.history_visibility` events against all known room versions. 
+ fn test_redact_m_room_history_visibility() { + let original = json!( + { + "type":"m.room.history_visibility", + "content":{ + "unknown_key":"unknown_value", + "history_visibility":"visibility", + }, + } + ); + + let expected = json!( + { + "type":"m.room.history_visibility", + "content":{ + "history_visibility":"visibility", + }, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + RoomVersion::V11, + RoomVersion::V12, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!(expected, redacted, "Room Version {version}"); + } + } + + #[test] + /// Tests redaction logic for `m.room.redaction` events against all known room versions. + fn test_redact_m_room_redaction() { + let original = json!( + { + "type":"m.room.redaction", + "content":{ + "unknown_key":"unknown_value", + "redacts":"event", + }, + } + ); + + let expected_updated_redaction_rules = json!( + { + "type":"m.room.redaction", + "content":{ + "redacts":"event", + }, + } + ); + let expected_pre_updated_redaction_rules = json!( + { + "type":"m.room.redaction", + "content":{}, + } + ); + + let versions = [ + RoomVersion::V1, + RoomVersion::V2, + RoomVersion::V3, + RoomVersion::V4, + RoomVersion::V5, + RoomVersion::V6, + RoomVersion::V7, + RoomVersion::V8, + RoomVersion::V9, + RoomVersion::V10, + ]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_pre_updated_redaction_rules, redacted, + "Room Version {version}" + ); + } + + let versions = [RoomVersion::V11, RoomVersion::V12]; + for version in versions { + let redacted = redact(&original, &version).unwrap(); + assert_eq!( + expected_updated_redaction_rules, redacted, + "Room Version {version}" + ); + } + } +} diff --git a/rust/src/json.rs b/rust/src/json.rs new file mode 100644 index 
0000000000..3e833c6707 --- /dev/null +++ b/rust/src/json.rs @@ -0,0 +1,218 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +/// A wrapper type that represents a value that may be missing. +/// +/// We can't necessarily use `Option` for this, as we want to distinguish +/// between a missing value and a value that is present but null (e.g. +/// `{"field": null}` vs `{}`). Serde by default treats missing fields as +/// `None`, so we need a custom type to capture this distinction. +/// +/// A plain `AllowMissing` is used for fields that are either present and of +/// type `T`, or absent. An `AllowMissing>` is used for fields that +/// are of type `T`, null, or absent. +/// +/// Note, to use this type correctly, the field **MUST** be annotated with: +/// +/// ```rust +/// #[serde( +/// default, +/// with = "crate::json::allow_missing", +/// skip_serializing_if = "AllowMissing::is_absent" +/// )] +/// ``` +/// +#[derive(Default, Debug, Clone)] +pub enum AllowMissing { + Some(T), + #[default] + Absent, +} + +impl AllowMissing { + /// Returns `true` if the value is present, even if it is null. + pub fn is_some(&self) -> bool { + matches!(self, AllowMissing::Some(_)) + } + + /// Returns `true` if the value is absent. + pub fn is_absent(&self) -> bool { + matches!(self, AllowMissing::Absent) + } + + /// Converts to `Option`. + /// + /// Useful for converting e.g. `AllowMissing` to `Option<&str>`. 
+ pub fn as_deref_opt(&self) -> Option<&T::Target> + where + T: std::ops::Deref, + { + match self { + AllowMissing::Some(inner) => Some(inner.deref()), + AllowMissing::Absent => None, + } + } + + /// Converts to `Option<&T>`. + pub fn as_ref_opt(&self) -> Option<&T> { + match self { + AllowMissing::Some(inner) => Some(inner), + AllowMissing::Absent => None, + } + } +} + +/// A module that provides the serialization and deserialization logic for +/// `AllowMissing`. +pub mod allow_missing { + use serde::ser::Error as _; + + use super::AllowMissing; + + pub fn deserialize<'de, T, D>(deserializer: D) -> Result, D::Error> + where + T: serde::Deserialize<'de>, + D: serde::Deserializer<'de>, + { + Ok(AllowMissing::Some(T::deserialize(deserializer)?)) + } + + pub fn serialize(value: &AllowMissing, serializer: S) -> Result + where + T: serde::Serialize, + S: serde::Serializer, + { + match value { + AllowMissing::Some(inner) => inner.serialize(serializer), + // We should never attempt to serialize an `AllowMissing::Absent`, as we + // should have skipped it with `skip_serializing_if`. 
+ AllowMissing::Absent => Err(S::Error::custom("cannot serialize AllowMissing::Absent")), + } + } +} + +#[cfg(test)] +mod tests { + use std::assert_matches; + + use serde::{Deserialize, Serialize}; + + use super::*; + + #[derive(Serialize, Deserialize)] + struct TestStruct { + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + value: AllowMissing, + } + + #[test] + fn test_deserialize() { + let json = r#"{"value":42}"#; + let deserialized: TestStruct = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_some()); + assert_matches!(deserialized.value, AllowMissing::Some(42)); + + let json = r#"{}"#; + let deserialized: TestStruct = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_absent()); + assert_matches!(deserialized.value, AllowMissing::Absent); + } + + #[test] + fn test_serialize() { + let value = TestStruct { + value: AllowMissing::Some(42), + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{"value":42}"#); + + let value = TestStruct { + value: AllowMissing::Absent, + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{}"#); + } + + /// Test that we get an error if we attempt to serialize an + /// `AllowMissing::Absent` without the skip_serializing_if annotation. 
+ #[test] + fn test_serialize_absent_error() { + #[derive(Serialize)] + struct TestStructWithoutSkip { + #[serde(default, with = "crate::json::allow_missing")] + value: AllowMissing, + } + + let value = TestStructWithoutSkip { + value: AllowMissing::Absent, + }; + + let err = serde_json::to_string(&value).unwrap_err(); + assert_eq!(err.to_string(), "cannot serialize AllowMissing::Absent"); + } + + #[derive(Serialize, Deserialize)] + struct TestStructOption { + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + value: AllowMissing>, + } + + #[test] + fn test_serialize_option() { + let value = TestStructOption { + value: AllowMissing::Some(Some(42)), + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{"value":42}"#); + + let value = TestStructOption { + value: AllowMissing::Some(None), + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{"value":null}"#); + + let value = TestStructOption { + value: AllowMissing::Absent, + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{}"#); + } + + #[test] + fn test_deserialize_option() { + let json = r#"{"value":42}"#; + let deserialized: TestStructOption = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_some()); + assert_matches!(deserialized.value, AllowMissing::Some(Some(42))); + + let json = r#"{"value":null}"#; + let deserialized: TestStructOption = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_some()); + assert_matches!(deserialized.value, AllowMissing::Some(None)); + + let json = r#"{}"#; + let deserialized: TestStructOption = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_absent()); + assert_matches!(deserialized.value, AllowMissing::Absent); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index cc89862e4e..8ed4e24b81 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs 
@@ -12,6 +12,7 @@ pub mod events; pub mod http; pub mod http_client; pub mod identifier; +pub mod json; pub mod matrix_const; pub mod msc4388_rendezvous; pub mod push; diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 823b6288e8..4f36740808 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -33,8 +33,9 @@ from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.events.utils import prune_event, prune_event_dict +from synapse.events.utils import prune_event from synapse.logging.opentracing import trace +from synapse.synapse_rust.events import redact_event_dict from synapse.types import JsonDict, UserID logger = logging.getLogger(__name__) @@ -157,7 +158,7 @@ def compute_event_signature( Returns: a dictionary in the same format of an event's signatures field. """ - redact_json = prune_event_dict(room_version, event_dict) + redact_json = redact_event_dict(room_version, event_dict) redact_json.pop("age_ts", None) redact_json.pop("unsigned", None) if logger.isEnabledFor(logging.DEBUG): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 36736b4559..a1633f881f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -46,9 +46,9 @@ from synapse.api.errors import ( ) from synapse.config.key import TrustedKeyServer from synapse.events import EventBase -from synapse.events.utils import prune_event_dict from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.keys import FetchKeyResult +from synapse.synapse_rust.events import redact_event from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.async_helpers import yieldable_gather_results @@ -136,7 +136,7 @@ class VerifyJsonRequest: server_name, # We defer creating the redacted json object, 
as it uses a lot more # memory than the Event object itself. - lambda: prune_event_dict(event.room_version, event.get_pdu_json()), + lambda: redact_event(event).get_pdu_json(), minimum_valid_until_ms, key_ids=key_ids, ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index d9750651fa..81a454b544 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -66,6 +66,7 @@ from synapse.events.py_protocol import supports_msc4242_state_dag from synapse.state import CREATE_KEY from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( + JsonMapping, MutableStateMap, StateKey, StateMap, @@ -856,6 +857,7 @@ def get_send_level( power level required to send this event. """ + power_levels_content: JsonMapping if power_levels_event: power_levels_content = power_levels_event.content else: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index fcb169e079..7abd91f072 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -20,49 +20,37 @@ # # -import abc import collections.abc from typing import ( TYPE_CHECKING, Any, - Generic, - Iterable, - Literal, - TypeVar, + TypeAlias, Union, - overload, ) import attr -from typing_extensions import deprecated -from unpaddedbase64 import encode_base64 from synapse.api.constants import ( EventContentFields, EventTypes, RelationTypes, - StickyEvent, -) -from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.synapse_rust.events import ( - EventInternalMetadata, - JsonObject, - Signatures, - Unsigned, ) +from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.synapse_rust.events import Event from synapse.types import ( JsonDict, - JsonMapping, StateKey, - StrCollection, ) -from synapse.util.caches import intern_dict -from synapse.util.duration import Duration -from synapse.util.frozenutils import freeze if TYPE_CHECKING: from 
synapse.events.builder import EventBuilder +# The base class for events used to be called EventBase, but it was renamed to +# Event when we switched to using the Rust implementation. We keep the old name +# around for backwards compatibility. +EventBase: TypeAlias = Event + USE_FROZEN_DICTS = False """ @@ -70,639 +58,29 @@ Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents bugs where we accidentally share e.g. signature dicts. However, converting a dict to frozen_dicts is expensive. -NOTE: This is overridden by the configuration by the Synapse worker apps, but -for the sake of tests, it is set here because it cannot be configured on the -homeserver object itself. - -FIXME: Because of how this option works (changing the underlying types), it causes -subtle downstream bugs that makes type comparisons brittle, tracked by -https://github.com/element-hq/synapse/issues/18117 +FIXME: Remove `USE_FROZEN_DICTS` and `use_frozen_dicts` config as this is no +longer used since we switched to using the Rust implementation, all events are +immutable already (and so don't benefit from freezing). """ -T = TypeVar("T") - - -# DictProperty (and DefaultDictProperty) require the classes they're used with to -# have a _dict property to pull properties from. -# -# TODO _DictPropertyInstance should not include EventBuilder but due to -# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and -# DefaultDictProperty get applied to EventBuilder when it is in a Union with -# EventBase. This is the least invasive hack to get mypy to comply. -# -# Note that DictProperty/DefaultDictProperty cannot actually be used with -# EventBuilder as it lacks a _dict property. 
-_DictPropertyInstance = Union["EventBase", "EventBuilder"] - - -class DictProperty(Generic[T]): - """An object property which delegates to the `_dict` within its parent object.""" - - __slots__ = ["key"] - - def __init__(self, key: str): - self.key = key - - @overload - def __get__( - self, - instance: Literal[None], - owner: type[_DictPropertyInstance] | None = None, - ) -> "DictProperty": ... - - @overload - def __get__( - self, - instance: _DictPropertyInstance, - owner: type[_DictPropertyInstance] | None = None, - ) -> T: ... - - def __get__( - self, - instance: _DictPropertyInstance | None, - owner: type[_DictPropertyInstance] | None = None, - ) -> T | "DictProperty": - # if the property is accessed as a class property rather than an instance - # property, return the property itself rather than the value - if instance is None: - return self - try: - assert isinstance(instance, EventBase) - return instance._dict[self.key] - except KeyError as e1: - # We want this to look like a regular attribute error (mostly so that - # hasattr() works correctly), so we convert the KeyError into an - # AttributeError. - # - # To exclude the KeyError from the traceback, we explicitly - # 'raise from e1.__context__' (which is better than 'raise from None', - # because that would omit any *earlier* exceptions). 
- # - raise AttributeError( - "'%s' has no '%s' property" % (type(instance), self.key) - ) from e1.__context__ - - def __set__(self, instance: _DictPropertyInstance, v: T) -> None: - assert isinstance(instance, EventBase) - instance._dict[self.key] = v - - def __delete__(self, instance: _DictPropertyInstance) -> None: - assert isinstance(instance, EventBase) - try: - del instance._dict[self.key] - except KeyError as e1: - raise AttributeError( - "'%s' has no '%s' property" % (type(instance), self.key) - ) from e1.__context__ - - -class DefaultDictProperty(DictProperty, Generic[T]): - """An extension of DictProperty which provides a default if the property is - not present in the parent's _dict. - - Note that this means that hasattr() on the property always returns True. - """ - - __slots__ = ["default"] - - def __init__(self, key: str, default: T): - super().__init__(key) - self.default = default - - @overload - def __get__( - self, - instance: Literal[None], - owner: type[_DictPropertyInstance] | None = None, - ) -> "DefaultDictProperty": ... - - @overload - def __get__( - self, - instance: _DictPropertyInstance, - owner: type[_DictPropertyInstance] | None = None, - ) -> T: ... - - def __get__( - self, - instance: _DictPropertyInstance | None, - owner: type[_DictPropertyInstance] | None = None, - ) -> T | "DefaultDictProperty": - if instance is None: - return self - assert isinstance(instance, EventBase) - return instance._dict.get(self.key, self.default) - - -class EventBase(metaclass=abc.ABCMeta): - @property - @abc.abstractmethod - def format_version(self) -> int: - """The EventFormatVersion implemented by this event""" - ... 
- - def __init__( - self, - event_dict: JsonDict, - room_version: RoomVersion, - signatures: dict[str, dict[str, str]], - unsigned: JsonDict, - internal_metadata_dict: JsonDict, - rejected_reason: str | None, - ): - assert room_version.event_format == self.format_version - - if "content" in event_dict: - event_dict["content"] = JsonObject(event_dict["content"]) - - # We intern these strings because they turn up a lot (especially when - # caching). - event_dict = intern_dict(event_dict) - - if USE_FROZEN_DICTS: - frozen_dict = freeze(event_dict) - else: - frozen_dict = event_dict - - self.room_version = room_version - self.signatures = Signatures(signatures) - self.unsigned = Unsigned(unsigned) - self.rejected_reason = rejected_reason - - self._dict = frozen_dict - - self.internal_metadata = EventInternalMetadata(internal_metadata_dict) - - depth: DictProperty[int] = DictProperty("depth") - content: DictProperty[JsonMapping] = DictProperty("content") - hashes: DictProperty[dict[str, str]] = DictProperty("hashes") - origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts") - sender: DictProperty[str] = DictProperty("sender") - # TODO state_key should be str | None. This is generally asserted in Synapse - # by calling is_state() first (which ensures it is not None), but it is hard (not possible?) - # to properly annotate that calling is_state() asserts that state_key exists - # and is non-None. It would be better to replace such direct references with - # get_state_key() (and a check for None). - state_key: DictProperty[str] = DictProperty("state_key") - type: DictProperty[str] = DictProperty("type") - - # This is a deprecated property, use `sender` instead. Only used by modules. 
- user_id: DictProperty[str] = DictProperty("sender") - - @property - def event_id(self) -> str: - raise NotImplementedError() - - @property - def room_id(self) -> str: - raise NotImplementedError() - - @property - def membership(self) -> str: - return self.content["membership"] - - @property - def redacts(self) -> str | None: - """MSC2176 moved the redacts field into the content.""" - if self.room_version.updated_redaction_rules: - return self.content.get("redacts") - return self.get("redacts") - - def is_state(self) -> bool: - return self.get_state_key() is not None - - def get_state_key(self) -> str | None: - """Get the state key of this event, or None if it's not a state event""" - return self._dict.get("state_key") - - def get_dict(self) -> JsonDict: - """Convert the event to a dictionary suitable for serialisation.""" - - d = dict(self._dict) - if "content" in d: - # Convert the content (which is a JsonObject) back to a dict. Json - # serialization should handle JsonObjects fine, but for sanities - # sake we want `get_dict()` and `get_pdu_json()` to return plain - # dicts. 
- d["content"] = dict(self.content) - d.update( - { - "signatures": self.signatures.as_dict(), - "unsigned": self.unsigned.for_event(), - } - ) - - return d - - def get_dict_for_persistence(self) -> JsonDict: - """Convert the event to a dictionary suitable for persistence.""" - d = dict(self._dict) - d.update( - { - "signatures": self.signatures.as_dict(), - "unsigned": self.unsigned.for_persistence(), - } - ) - - return d - - def get(self, key: str, default: Any | None = None) -> Any: - return self._dict.get(key, default) - - def get_internal_metadata_dict(self) -> JsonDict: - return self.internal_metadata.get_dict() - - def get_pdu_json(self, time_now: int | None = None) -> JsonDict: - pdu_json = self.get_dict() - - if time_now is not None and "age_ts" in pdu_json["unsigned"]: - age = time_now - pdu_json["unsigned"]["age_ts"] - pdu_json.setdefault("unsigned", {})["age"] = int(age) - del pdu_json["unsigned"]["age_ts"] - - # This may be a frozen event - pdu_json["unsigned"].pop("redacted_because", None) - - return pdu_json - - def get_templated_pdu_json(self) -> JsonDict: - """ - Return a JSON object suitable for a templated event, as used in the - make_{join,leave,knock} workflow. - """ - # By using _dict directly we don't pull in signatures/unsigned. - template_json = dict(self._dict) - # The hashes (similar to the signature) need to be recalculated by the - # joining/leaving/knocking server after (potentially) modifying the - # event. - template_json.pop("hashes") - - return template_json - - def __contains__(self, field: str) -> bool: - return field in self._dict - - def items(self) -> list[tuple[str, Any | None]]: - return list(self._dict.items()) - - def keys(self) -> Iterable[str]: - return self._dict.keys() - - def prev_event_ids(self) -> list[str]: - """Returns the list of prev event IDs. The order matches the order - specified in the event, though there is no meaning to it. 
- - Returns: - The list of event IDs of this event's prev_events - """ - return [e for e, _ in self._dict["prev_events"]] - - def auth_event_ids(self) -> StrCollection: - """Returns the list of auth event IDs. The order matches the order - specified in the event, though there is no meaning to it. - - Returns: - The list of event IDs of this event's auth_events - """ - return [e for e, _ in self._dict["auth_events"]] - - def freeze(self) -> None: - """'Freeze' the event dict, so it cannot be modified by accident""" - - # this will be a no-op if the event dict is already frozen. - self._dict = freeze(self._dict) - - def sticky_duration(self) -> Duration | None: - """ - Returns the effective sticky duration of this event, or None - if the event does not have a sticky duration. - (Sticky Events are a MSC4354 feature.) - - Clamps the sticky duration to the maximum allowed duration. - """ - sticky_obj = self.get_dict().get(StickyEvent.EVENT_FIELD_NAME, None) - if type(sticky_obj) is not dict: - return None - sticky_duration_ms = sticky_obj.get("duration_ms", None) - # MSC: Clamp to 0 and MAX_DURATION (1 hour) - # We use `type(...) 
is int` to avoid accepting bools as `isinstance(True, int)` - # (bool is a subclass of int) - if type(sticky_duration_ms) is int and sticky_duration_ms >= 0: - return min( - Duration(milliseconds=sticky_duration_ms), - StickyEvent.MAX_DURATION, - ) - return None - - def __str__(self) -> str: - return self.__repr__() - - def __repr__(self) -> str: - rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" - - conditional_membership_string = "" - if self.get("type") == EventTypes.Member: - conditional_membership_string = f"membership={self.membership}, " - - return ( - f"<{self.__class__.__name__} " - f"{rejection}" - f"event_id={self.event_id}, " - f"type={self.get('type')}, " - f"state_key={self.get('state_key')}, " - f"{conditional_membership_string}" - f"outlier={self.internal_metadata.is_outlier()}" - ">" - ) - - # Using `__getitem__` is deprecated. Only used by modules. - @deprecated("Use attribute access instead") - def __getitem__(self, field: str) -> Any | None: - return self._dict[field] - - -class FrozenEvent(EventBase): - format_version = EventFormatVersions.ROOM_V1_V2 # All events of this type are V1 - - def __init__( - self, - event_dict: JsonDict, - room_version: RoomVersion, - internal_metadata_dict: JsonDict | None = None, - rejected_reason: str | None = None, - ): - internal_metadata_dict = internal_metadata_dict or {} - - event_dict = dict(event_dict) - - # Signatures is a dict of dicts, and this is faster than doing a - # copy.deepcopy - signatures = { - name: dict(sigs.items()) - for name, sigs in event_dict.pop("signatures", {}).items() - } - - unsigned = event_dict.pop("unsigned", {}) - - self._event_id = event_dict["event_id"] - - super().__init__( - event_dict, - room_version=room_version, - signatures=signatures, - unsigned=unsigned, - internal_metadata_dict=internal_metadata_dict, - rejected_reason=rejected_reason, - ) - - @property - def event_id(self) -> str: - return self._event_id - - @property - def 
room_id(self) -> str: - return self._dict["room_id"] - - -class FrozenEventV2(EventBase): - format_version = EventFormatVersions.ROOM_V3 # All events of this type are V2 - - def __init__( - self, - event_dict: JsonDict, - room_version: RoomVersion, - internal_metadata_dict: JsonDict | None = None, - rejected_reason: str | None = None, - ): - internal_metadata_dict = internal_metadata_dict or {} - - event_dict = dict(event_dict) - - # Signatures is a dict of dicts, and this is faster than doing a - # copy.deepcopy - signatures = { - name: dict(sigs.items()) - for name, sigs in event_dict.pop("signatures", {}).items() - } - - assert "event_id" not in event_dict - - unsigned = event_dict.pop("unsigned", {}) - - self._event_id: str | None = None - - super().__init__( - event_dict, - room_version=room_version, - signatures=signatures, - unsigned=unsigned, - internal_metadata_dict=internal_metadata_dict, - rejected_reason=rejected_reason, - ) - - @property - def event_id(self) -> str: - # We have to import this here as otherwise we get an import loop which - # is hard to break. - from synapse.crypto.event_signing import compute_event_reference_hash - - if self._event_id: - return self._event_id - self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) - return self._event_id - - @property - def room_id(self) -> str: - return self._dict["room_id"] - - def prev_event_ids(self) -> list[str]: - """Returns the list of prev event IDs. The order matches the order - specified in the event, though there is no meaning to it. - - Returns: - The list of event IDs of this event's prev_events - """ - return self._dict["prev_events"] - - def auth_event_ids(self) -> StrCollection: - """Returns the list of auth event IDs. The order matches the order - specified in the event, though there is no meaning to it. 
- - Returns: - The list of event IDs of this event's auth_events - """ - return self._dict["auth_events"] - - -class FrozenEventV3(FrozenEventV2): - """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" - - format_version = EventFormatVersions.ROOM_V4_PLUS # All events of this type are V3 - - @property - def event_id(self) -> str: - # We have to import this here as otherwise we get an import loop which - # is hard to break. - from synapse.crypto.event_signing import compute_event_reference_hash - - if self._event_id: - return self._event_id - self._event_id = "$" + encode_base64( - compute_event_reference_hash(self)[1], urlsafe=True - ) - return self._event_id - - -class FrozenEventV4(FrozenEventV3): - """FrozenEventV4 for MSC4291 room IDs are hashes""" - - format_version = EventFormatVersions.ROOM_V11_HYDRA_PLUS - - """Override the room_id for m.room.create events""" - - def __init__( - self, - event_dict: JsonDict, - room_version: RoomVersion, - internal_metadata_dict: JsonDict | None = None, - rejected_reason: str | None = None, - ): - super().__init__( - event_dict=event_dict, - room_version=room_version, - internal_metadata_dict=internal_metadata_dict, - rejected_reason=rejected_reason, - ) - self._room_id: str | None = None - - @property - def room_id(self) -> str: - # if we have calculated the room ID already, don't do it again. - if self._room_id: - return self._room_id - - is_create_event = self.type == EventTypes.Create and self.get_state_key() == "" - - # for non-create events: use the supplied value from the JSON, as per FrozenEventV3 - if not is_create_event: - self._room_id = self._dict["room_id"] - assert self._room_id is not None - return self._room_id - - # for create events: calculate the room ID - from synapse.crypto.event_signing import compute_event_reference_hash - - self._room_id = "!" 
+ encode_base64( - compute_event_reference_hash(self)[1], urlsafe=True - ) - return self._room_id - - def auth_event_ids(self) -> StrCollection: - """Returns the list of auth event IDs. The order matches the order - specified in the event, though there is no meaning to it. - Returns: - The list of event IDs of this event's auth_events - Includes the creation event ID for convenience of all the codepaths - which expects the auth chain to include the creator ID, even though - it's explicitly not included on the wire. Excludes the create event - for the create event itself. - """ - create_event_id = "$" + self.room_id[1:] - assert create_event_id not in self._dict["auth_events"] - if self.type == EventTypes.Create and self.get_state_key() == "": - return self._dict["auth_events"] # should be [] - return [*self._dict["auth_events"], create_event_id] - - -class FrozenEventVMSC4242(FrozenEventV4): - """FrozenEventVMSC4242, which differs from FrozenEventV4 only in the addition of prev_state_events""" - - format_version = EventFormatVersions.ROOM_VMSC4242 - prev_state_events: DictProperty[StrCollection] = DictProperty("prev_state_events") - - def __init__( - self, - event_dict: JsonDict, - room_version: RoomVersion, - internal_metadata_dict: JsonDict | None = None, - rejected_reason: str | None = None, - ): - # Similar to how we assert event_id isn't in V2+ events, we do the same with auth_events. - # We don't expect `auth_events` in the wire format because we calculate it from prev_state_events. - assert "auth_events" not in event_dict - super().__init__( - event_dict=event_dict, - room_version=room_version, - internal_metadata_dict=internal_metadata_dict, - rejected_reason=rejected_reason, - ) - - def auth_event_ids(self) -> StrCollection: - """Returns the list of _calculated_ auth event IDs. - - Returns: - The list of event IDs of this event's auth events - """ - # Catches cases where we accidentally call auth_event_ids() prior to calculating what they - # actually are. 
The exception being the m.room.create event which has no auth events. - if self.type != EventTypes.Create: - assert len(self.internal_metadata.calculated_auth_event_ids) > 0 - return self.internal_metadata.calculated_auth_event_ids - - def __repr__(self) -> str: - rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" - - return ( - f"<{self.__class__.__name__} " - f"{rejection}" - f"event_id={self.event_id}, " - f"type={self.get('type')}, " - f"state_key={self.get('state_key')}, " - f"prev_events={self.get('prev_events')}, " - f"prev_state_events={self.get('prev_state_events')}, " - f"outlier={self.internal_metadata.is_outlier()}" - ">" - ) - - -def _event_type_from_format_version( - format_version: int, -) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3 | FrozenEventVMSC4242]: - """Returns the python type to use to construct an Event object for the - given event format version. - - Args: - format_version: The event format version - - Returns: - A type that can be initialized as per the initializer of `FrozenEvent` - """ - - if format_version == EventFormatVersions.ROOM_V1_V2: - return FrozenEvent - elif format_version == EventFormatVersions.ROOM_V3: - return FrozenEventV2 - elif format_version == EventFormatVersions.ROOM_V4_PLUS: - return FrozenEventV3 - elif format_version == EventFormatVersions.ROOM_VMSC4242: - return FrozenEventVMSC4242 - elif format_version == EventFormatVersions.ROOM_V11_HYDRA_PLUS: - return FrozenEventV4 - else: - raise Exception("No event format %r" % (format_version,)) - def make_event_from_dict( event_dict: JsonDict, room_version: RoomVersion = RoomVersions.V1, internal_metadata_dict: JsonDict | None = None, rejected_reason: str | None = None, -) -> EventBase: +) -> Event: """Construct an EventBase from the given event dict""" - event_type = _event_type_from_format_version(room_version.event_format) - return event_type( - event_dict, room_version, internal_metadata_dict or {}, rejected_reason - ) + + try: + 
return Event( + event_dict=event_dict, + room_version=room_version, + internal_metadata_dict=internal_metadata_dict or {}, + rejected_reason=rejected_reason, + ) + except ValueError: + raise SynapseError(400, "Invalid event dict", Codes.BAD_JSON) @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -716,7 +94,7 @@ class _EventRelation: aggregation_key: str | None -def relation_from_event(event: EventBase) -> _EventRelation | None: +def relation_from_event(event: Event) -> _EventRelation | None: """ Attempt to parse relation information an event. diff --git a/synapse/events/py_protocol.py b/synapse/events/py_protocol.py index d9ac8c066f..a34170cd93 100644 --- a/synapse/events/py_protocol.py +++ b/synapse/events/py_protocol.py @@ -33,7 +33,6 @@ predicates here when a new room-version feature gates access to additional attributes. """ -import abc from typing import ( TYPE_CHECKING, Sequence, @@ -42,12 +41,14 @@ from typing import ( from typing_extensions import TypeIs from synapse.events import EventBase +from synapse.synapse_rust.events import Event +from synapse.types import StrCollection if TYPE_CHECKING: from synapse.events.snapshot import EventContext, EventPersistencePair -class _DisableIsInstance(abc.ABCMeta): +class _DisableIsInstance(type): """Metaclass which disables isinstance checks on classes which use it, by making isinstance() raise NotImplementedError. @@ -61,15 +62,38 @@ class _DisableIsInstance(abc.ABCMeta): raise NotImplementedError("Instance cannot be used.") -class EventProtocol(EventBase, metaclass=_DisableIsInstance): - """Helper subclass that allows type narrowing for `EventBase` objects.""" +# We now define `EventProtocol` as a helper class for type narrowing. +# +# During type checking, we want the type narrowed event classes to still have +# all the fields as a normal `Event`, so we make `EventProtocol` a subclass of +# `Event`. 
+# +# However, at runtime we a) can't subclass `Event` because it's a Rust class, +# and b) don't want to allow `isinstance` checks against `EventProtocol` (as +# it's purely a type annotation helper, not a real class). So at runtime, we +# make `EventProtocol` a class with a metaclass that raises on `isinstance` +# checks. +if TYPE_CHECKING: + + class EventProtocol(Event): + """Helper subclass that allows type narrowing for `EventBase` objects.""" + +else: + + class EventProtocol(metaclass=_DisableIsInstance): + """Helper subclass that allows type narrowing for `EventBase` objects.""" + + def __new__(cls): + raise NotImplementedError( + f"{cls.__name__} cannot be instantiated as it is not a real class." + ) class MSC4242Event(EventProtocol): """Marker protocol for events in MSC4242 rooms. This allows us to narrow the type of events.""" - prev_state_events: list[str] + prev_state_events: StrCollection def supports_msc4242_state_dag(event: EventBase) -> TypeIs[MSC4242Event]: diff --git a/synapse/events/utils.py b/synapse/events/utils.py index adbede7f16..54f662796b 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -39,18 +39,16 @@ from synapse.api.constants import ( CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT, MAX_PDU_SIZE, - EventContentFields, EventTypes, EventUnsignedContentFields, RelationTypes, ) from synapse.api.errors import Codes, SynapseError -from synapse.api.room_versions import RoomVersion from synapse.logging.opentracing import SynapseTags, set_tag, trace -from synapse.synapse_rust.events import Unsigned +from synapse.synapse_rust.events import Unsigned, redact_event from synapse.types import JsonDict, Requester -from . import EventBase, FrozenEventV2, StrippedStateEvent, make_event_from_dict +from . 
import EventBase, StrippedStateEvent if TYPE_CHECKING: from synapse.handlers.relations import BundledAggregations @@ -78,177 +76,18 @@ def prune_event(event: EventBase) -> EventBase: the user has specified, but we do want to keep necessary information like type, state_key etc. """ - pruned_event_dict = prune_event_dict(event.room_version, event.get_dict()) - - pruned_event = make_event_from_dict( - pruned_event_dict, event.room_version, event.internal_metadata.get_dict() - ) - - # Copy the bits of `internal_metadata` that aren't returned by `get_dict` - pruned_event.internal_metadata.stream_ordering = ( - event.internal_metadata.stream_ordering - ) - pruned_event.internal_metadata.instance_name = event.internal_metadata.instance_name - pruned_event.internal_metadata.outlier = event.internal_metadata.outlier - pruned_event.internal_metadata.redacted_by = event.internal_metadata.redacted_by - - # Mark the event as redacted - pruned_event.internal_metadata.redacted = True - - return pruned_event + return redact_event(event) def clone_event(event: EventBase) -> EventBase: """Take a copy of the event. - This is mostly useful because it does a *shallow* copy of the `unsigned` data, - which means it can then be updated without corrupting the in-memory cache. Note that - other properties of the event, such as `content`, are *not* (currently) copied here. - """ - # XXX: We rely on at least one of `event.get_dict()` and `make_event_from_dict()` - # making a copy of `unsigned`. Currently, both do, though I don't really know why. - # Still, as long as they do, there's not much point doing yet another copy here. - new_event = make_event_from_dict( - event.get_dict(), event.room_version, event.internal_metadata.get_dict() - ) - - # Starting FrozenEventV2, the event ID is an (expensive) hash of the event. This is - # lazily computed when we get the FrozenEventV2.event_id property, then cached in - # _event_id field. 
Later FrozenEvent formats all inherit from FrozenEventV2, so we - # can use the same logic here. - if isinstance(event, FrozenEventV2) and isinstance(new_event, FrozenEventV2): - # If we already pre-computed the event ID, use it. - new_event._event_id = event._event_id - - # Copy the bits of `internal_metadata` that aren't returned by `get_dict`. - new_event.internal_metadata.stream_ordering = ( - event.internal_metadata.stream_ordering - ) - new_event.internal_metadata.instance_name = event.internal_metadata.instance_name - new_event.internal_metadata.outlier = event.internal_metadata.outlier - new_event.internal_metadata.redacted_by = event.internal_metadata.redacted_by - - return new_event - - -def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict: - """Redacts the event_dict in the same way as `prune_event`, except it - operates on dicts rather than event objects - - Returns: - A copy of the pruned event dict + Most fields of the event are immutable, however fields such as `unsigned`, + `signatures` and `internal_metadata` are mutable. Cloning the event allows + us to edit such fields without affecting the original event. """ - allowed_keys = [ - "event_id", - "sender", - "room_id", - "hashes", - "signatures", - "content", - "type", - "state_key", - "depth", - "prev_events", - "auth_events", - "origin_server_ts", - ] - - # Earlier room versions from had additional allowed keys. 
- if not room_version.updated_redaction_rules: - allowed_keys.extend(["prev_state", "membership", "origin"]) - # Custom room versions add new allowed keys and remove others - if room_version.msc4242_state_dags: - allowed_keys.extend(["prev_state_events"]) - allowed_keys.remove("auth_events") - - event_type = event_dict["type"] - - new_content = {} - - def add_fields(*fields: str) -> None: - for field in fields: - if field in event_dict["content"]: - new_content[field] = event_dict["content"][field] - - if event_type == EventTypes.Member: - add_fields("membership") - if room_version.restricted_join_rule_fix: - add_fields(EventContentFields.AUTHORISING_USER) - if room_version.updated_redaction_rules: - # Preserve the signed field under third_party_invite. - third_party_invite = event_dict["content"].get("third_party_invite") - if isinstance(third_party_invite, collections.abc.Mapping): - new_content["third_party_invite"] = {} - if "signed" in third_party_invite: - new_content["third_party_invite"]["signed"] = third_party_invite[ - "signed" - ] - - elif event_type == EventTypes.Create: - if room_version.updated_redaction_rules: - # MSC2176 rules state that create events cannot have their `content` redacted. 
- new_content = event_dict["content"] - if not room_version.implicit_room_creator: - # Some room versions give meaning to `creator` - add_fields("creator") - if room_version.msc4291_room_ids_as_hashes: - # room_id is not allowed on the create event as it's derived from the event ID - allowed_keys.remove("room_id") - - elif event_type == EventTypes.JoinRules: - add_fields("join_rule") - if room_version.restricted_join_rule: - add_fields("allow") - elif event_type == EventTypes.PowerLevels: - add_fields( - "users", - "users_default", - "events", - "events_default", - "state_default", - "ban", - "kick", - "redact", - ) - - if room_version.updated_redaction_rules: - add_fields("invite") - - elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: - add_fields("aliases") - elif event_type == EventTypes.RoomHistoryVisibility: - add_fields("history_visibility") - elif event_type == EventTypes.Redaction and room_version.updated_redaction_rules: - add_fields("redacts") - - # Protect the rel_type and event_id fields under the m.relates_to field. - if room_version.msc3389_relation_redactions: - relates_to = event_dict["content"].get("m.relates_to") - if isinstance(relates_to, collections.abc.Mapping): - new_relates_to = {} - for field in ("rel_type", "event_id"): - if field in relates_to: - new_relates_to[field] = relates_to[field] - # Only include a non-empty relates_to field. 
- if new_relates_to: - new_content["m.relates_to"] = new_relates_to - - allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys} - - allowed_fields["content"] = new_content - - unsigned: JsonDict = {} - allowed_fields["unsigned"] = unsigned - - event_unsigned = event_dict.get("unsigned", {}) - - if "age_ts" in event_unsigned: - unsigned["age_ts"] = event_unsigned["age_ts"] - if "replaces_state" in event_unsigned: - unsigned["replaces_state"] = event_unsigned["replaces_state"] - - return allowed_fields + return event.deep_copy() def _copy_field(src: JsonDict, dst: JsonDict, field: list[str]) -> None: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 2b5ef5fbac..5b33454325 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -21,7 +21,6 @@ # -import copy import itertools import logging from typing import ( @@ -1192,7 +1191,7 @@ class FederationClient(FederationBase): # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. signed_state = [ - copy.copy(valid_pdus_map[p.event_id]) + valid_pdus_map[p.event_id].deep_copy() for p in state if p.event_id in valid_pdus_map ] @@ -1203,11 +1202,6 @@ class FederationClient(FederationBase): if p.event_id in valid_pdus_map ] - # NB: We *need* to copy to ensure that we don't have multiple - # references being passed on, as that causes... issues. 
- for s in signed_state: - s.internal_metadata = s.internal_metadata.copy() - # double-check that the auth chain doesn't include a different create event auth_chain_create_events = [ e.event_id diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py index 65f5a6b183..68053d65a3 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py @@ -283,11 +283,6 @@ class ThirdPartyEventRulesModuleApiCallbacks: events = await self.store.get_events(prev_state_ids.values()) state_events = {(ev.type, ev.state_key): ev for ev in events.values()} - # Ensure that the event is frozen, to make sure that the module is not tempted - # to try to modify it. Any attempt to modify it at this point will invalidate - # the hashes and signatures. - event.freeze() - for callback in self._check_event_allowed_callbacks: try: res, replacement_data = await delay_cancellation( diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 3783211a92..f6693e0923 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -58,7 +58,14 @@ from synapse.rest.admin._base import ( from synapse.rest.client.room import SerializeMessagesDeps, encode_messages_response from synapse.storage.databases.main.room import RoomSortOrder from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester +from synapse.types import ( + JsonDict, + JsonMapping, + RoomID, + ScheduledTask, + UserID, + create_requester, +) from synapse.types.state import StateFilter if TYPE_CHECKING: @@ -682,6 +689,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): create_event = filtered_room_state[(EventTypes.Create, "")] power_levels = filtered_room_state.get((EventTypes.PowerLevels, "")) + pl_content: JsonMapping if power_levels is not 
None: # We pick the local user with the highest power. user_power = power_levels.content.get("users", {}) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 62e84f5ac5..13f958d998 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -36,7 +36,6 @@ from typing import ( Iterable, Sequence, TypeVar, - cast, ) import attr @@ -870,7 +869,7 @@ class EventsPersistenceStorageController: ) new_room_dag_fwd_extrems = await self._calculate_new_extremities( room_id, - cast(list[EventPersistencePair], event_contexts), + event_contexts, existing_room_dag_fwd_extrems, ) assert new_room_dag_fwd_extrems, ( @@ -889,7 +888,7 @@ class EventsPersistenceStorageController: ): (current_state, delta_ids, _) = await self._get_new_state_after_events( room_id, - cast(list[EventPersistencePair], event_contexts), + event_contexts, existing_state_dag_fwd_extrems, new_state_dag_fwd_extrems, # do not prune forward extremities in the state DAG @@ -923,7 +922,7 @@ class EventsPersistenceStorageController: # extremities. 
is_still_joined = await self._is_server_still_joined( room_id, - cast(list[EventPersistencePair], event_contexts), + event_contexts, delta, ) if not is_still_joined: @@ -1053,7 +1052,7 @@ class EventsPersistenceStorageController: async def _calculate_new_extremities( self, room_id: str, - event_contexts: list[EventPersistencePair], + event_contexts: Sequence[EventPersistencePair], latest_event_ids: AbstractSet[str], ) -> set[str]: """Calculates the new forward extremities for a room given events to @@ -1113,7 +1112,7 @@ class EventsPersistenceStorageController: async def _get_new_state_after_events( self, room_id: str, - events_context: list[EventPersistencePair], + events_context: Sequence[EventPersistencePair], old_latest_event_ids: AbstractSet[str], new_latest_event_ids: set[str], should_prune: bool = True, @@ -1297,7 +1296,7 @@ class EventsPersistenceStorageController: new_latest_event_ids: set[str], resolved_state_group: int, event_id_to_state_group: dict[str, int], - events_context: list[EventPersistencePair], + events_context: Sequence[EventPersistencePair], ) -> set[str]: """See if we can prune any of the extremities after calculating the resolved state. 
@@ -1434,7 +1433,7 @@ class EventsPersistenceStorageController: async def _is_server_still_joined( self, room_id: str, - ev_ctx_rm: list[EventPersistencePair], + ev_ctx_rm: Sequence[EventPersistencePair], delta: DeltaState, ) -> bool: """Check if the server will still be joined after the given events have diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index a5ae4bf506..92770b6590 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -22,7 +22,6 @@ import logging from typing import TYPE_CHECKING -from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -32,6 +31,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.synapse_rust.events import redact_event from synapse.util.duration import Duration from synapse.util.json import json_encoder @@ -123,9 +123,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase ): # Redaction was allowed pruned_json: str | None = json_encoder.encode( - prune_event_dict( - original_event.room_version, original_event.get_dict() - ) + redact_event(original_event).get_pdu_json() ) else: # Redaction wasn't allowed @@ -190,9 +188,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. - pruned_json = json_encoder.encode( - prune_event_dict(event.room_version, event.get_dict()) - ) + pruned_json = json_encoder.encode(redact_event(event).get_pdu_json()) # Update the event_json table to replace the event's JSON with the pruned # JSON. 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5aab0067fc..84b38f4bf2 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -915,7 +915,9 @@ class PersistEventsStore: # instances as we'll potentially be pulling more events from the DB and # we don't need the overhead of fetching/parsing the full event JSON. event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events} - event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events} + event_to_auth_chain: dict[str, StrCollection] = { + e.event_id: e.auth_event_ids() for e in state_events + } event_to_room_id = {e.event_id: e.room_id for e in state_events} return self._calculate_chain_cover_index( diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index c0d218398d..9a11f9b9bb 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -38,7 +38,6 @@ from synapse.crypto.event_signing import ( resign_event, ) from synapse.events import EventBase, make_event_from_dict -from synapse.events.utils import prune_event_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -63,6 +62,7 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor +from synapse.synapse_rust.events import redact_event from synapse.types import JsonDict, RoomStreamToken, StateMap, StrCollection from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES from synapse.types.state import StateFilter @@ -2831,7 +2831,7 @@ class EventsBackgroundUpdatesStore( # Verify the signature is genuinely from this key. 
We prune # first since signatures are computed over the redacted form. - pruned = prune_event_dict(event.room_version, event.get_pdu_json()) + pruned = redact_event(event).get_pdu_json() try: verify_signed_json(pruned, self.hs.hostname, old_verify_key) except SignatureVerifyException: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6f26bd17ce..4b372de141 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -37,6 +37,7 @@ from typing import ( import attr from prometheus_client import Gauge +from typing_extensions import assert_never from twisted.internet import defer @@ -775,6 +776,14 @@ class EventsWorkerStore(SQLBaseStore): continue elif redact_behaviour == EventRedactBehaviour.redact: event = entry.redacted_event + elif redact_behaviour == EventRedactBehaviour.as_is: + # Allow event through as is + pass + else: + # We (should) have covered all possible values of + # redact_behaviour, so this is unreachable. 
+ assert_never(redact_behaviour) + raise ValueError(f"Unknown redact_behaviour {redact_behaviour}") events.append(event) @@ -1507,12 +1516,17 @@ class EventsWorkerStore(SQLBaseStore): ) continue - original_ev = make_event_from_dict( - event_dict=d, - room_version=room_version, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) + try: + original_ev = make_event_from_dict( + event_dict=d, + room_version=room_version, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + except SynapseError: + logger.error("Unable to parse event from database: %s", event_id) + continue + original_ev.internal_metadata.stream_ordering = row.stream_ordering original_ev.internal_metadata.instance_name = row.instance_name original_ev.internal_metadata.outlier = row.outlier diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 5b55d47f0d..9a69438c1d 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -12,7 +12,9 @@ from typing import Any, Iterator, Mapping -from synapse.types import JsonDict, JsonMapping +from synapse.synapse_rust.room_versions import RoomVersion +from synapse.types import JsonDict, JsonMapping, StrSequence +from synapse.util.duration import Duration class EventInternalMetadata: def __init__(self, internal_metadata_dict: JsonDict): ... @@ -159,60 +161,60 @@ class Signatures: """A class representing the signatures on an event.""" def __init__(self, signatures: Mapping[str, Mapping[str, str]] | None = None): ... - def get_signature(self, server_name: str, key_id: str) -> str | None: ... - """Get the signature for the given server name and key ID, if it exists.""" + def get_signature(self, server_name: str, key_id: str) -> str | None: + """Get the signature for the given server name and key ID, if it exists.""" - def __getitem__(self, server_name: str) -> Mapping[str, str]: ... - """Get the signatures for the given server name. 
Raises KeyError if there - are no signatures for that server.""" + def __getitem__(self, server_name: str) -> Mapping[str, str]: + """Get the signatures for the given server name. Raises KeyError if there + are no signatures for that server.""" - def __contains__(self, server_name: Any) -> bool: ... - """Check if there are signatures for the given server name.""" + def __contains__(self, server_name: Any) -> bool: + """Check if there are signatures for the given server name.""" - def __len__(self) -> int: ... - """Return the number of servers that have signatures.""" + def __len__(self) -> int: + """Return the number of servers that have signatures.""" - def add_signature(self, server_name: str, key_id: str, signature: str) -> None: ... - """Add a signature for the given server name and key ID.""" + def add_signature(self, server_name: str, key_id: str, signature: str) -> None: + """Add a signature for the given server name and key ID.""" - def update(self, signatures: Mapping[str, Mapping[str, str]]) -> None: ... - """Update the signatures with the given signatures. + def update(self, signatures: Mapping[str, Mapping[str, str]]) -> None: + """Update the signatures with the given signatures. - Will overwrite all existing signatures for the server names provided. - """ + Will overwrite all existing signatures for the server names provided. + """ - def as_dict(self) -> dict[str, dict[str, str]]: ... - """Return a copy of the signatures as a dictionary.""" + def as_dict(self) -> dict[str, dict[str, str]]: + """Return a copy of the signatures as a dictionary.""" class Unsigned: """A class representing the unsigned data of an event.""" def __init__(self, unsigned_dict: JsonMapping): ... - def __getitem__(self, key: str) -> Any: ... - """Get the value for the given key. + def __getitem__(self, key: str) -> Any: + """Get the value for the given key. 
- Raises KeyError if the key is unset or not recognised.""" + Raises KeyError if the key is unset or not recognised.""" - def __setitem__(self, key: str, value: Any) -> None: ... - """Set the value for the given key. + def __setitem__(self, key: str, value: Any) -> None: + """Set the value for the given key. - Raises KeyError if the key is not recognised.""" + Raises KeyError if the key is not recognised.""" - def __delitem__(self, key: str) -> None: ... - """Delete the value for the given key. + def __delitem__(self, key: str) -> None: + """Delete the value for the given key. - Raises KeyError if the key is unset or not recognised.""" + Raises KeyError if the key is unset or not recognised.""" def __contains__(self, key: Any) -> bool: ... - def get(self, key: str, default: Any = None) -> Any: ... - """Get the value for the given key, or ``default`` if the key is unset.""" + def get(self, key: str, default: Any = None) -> Any: + """Get the value for the given key, or ``default`` if the key is unset.""" - def for_persistence(self) -> JsonDict: ... - """Return a dict of the fields that should be persisted to the database.""" + def for_persistence(self) -> JsonDict: + """Return a dict of the fields that should be persisted to the database.""" - def for_event(self) -> JsonDict: ... - """Return a dict of all unsigned fields, including those only kept in - memory, suitable for inclusion in an event.""" + def for_event(self) -> JsonDict: + """Return a dict of all unsigned fields, including those only kept in + memory, suitable for inclusion in an event.""" class JsonObject(Mapping[str, Any]): """Immutable JSON object mapping.""" @@ -222,3 +224,99 @@ class JsonObject(Mapping[str, Any]): def __getitem__(self, key: str) -> Any: ... def __iter__(self) -> Iterator[str]: ... def __eq__(self, other: object) -> bool: ... 
+ +class Event: + """Represents a Matrix event.""" + + def __init__( + self, + event_dict: JsonDict, + room_version: RoomVersion, + internal_metadata_dict: JsonDict, + rejected_reason: str | None, + ) -> None: ... + def get_dict(self) -> JsonDict: + """Convert the event to a dictionary suitable for serialisation.""" + + def get_dict_for_persistence(self) -> JsonDict: + """Like ``get_dict``, but serializes ``unsigned`` in a form suitable for + persistence.""" + + def get_pdu_json(self, time_now: int | None = None) -> JsonDict: + """Like ``get_dict``, but serializes ``unsigned`` in a form suitable + for sending over federation.""" + + def get_templated_pdu_json(self) -> JsonDict: + """Like ``get_dict``, except strips fields like ``signatures``, + ``hashes`` and ``unsigned`` so that the result is suitable as a template for + creating new events. Used in make_{join,leave,knock} flows.""" + + @property + def event_id(self) -> str: ... + @property + def room_id(self) -> str: ... + @property + def signatures(self) -> Signatures: ... + @property + def content(self) -> JsonMapping: ... + @property + def depth(self) -> int: ... + @property + def hashes(self) -> dict[str, str]: ... + @property + def origin_server_ts(self) -> int: ... + @property + def sender(self) -> str: ... + @property + def state_key(self) -> str: ... + @property + def type(self) -> str: ... + @property + def unsigned(self) -> Unsigned: ... + @property + def internal_metadata(self) -> EventInternalMetadata: ... + @property + def rejected_reason(self) -> str | None: ... + @property + def room_version(self) -> RoomVersion: ... + @property + def format_version(self) -> int: + """The EventFormatVersion implemented by this event.""" + + @property + def membership(self) -> Any: ... + @property + def redacts(self) -> Any | None: ... 
+ def prev_event_ids(self) -> StrSequence: + """Returns the list of prev event IDs.""" + + def auth_event_ids(self) -> StrSequence: + """Returns the list of auth event IDs.""" + + def is_state(self) -> bool: ... + def get_state_key(self) -> str | None: + """Get the state key of this event, or None if it's not a state event.""" + def __contains__(self, key: str) -> bool: ... + def get(self, key: str, default: Any = None) -> Any: ... + def items(self) -> list[tuple[str, Any]]: ... + def keys(self) -> list[str]: ... + def deep_copy(self) -> "Event": + """Returns a deep copy of this object, such that modifying the copy will + not affect the original.""" + + def sticky_duration(self) -> Duration | None: + """If this event has the ``msc4354_sticky`` top-level field, returns a + ``Duration`` representing the sticky duration. Otherwise returns + ``None``.""" + +def redact_event(event: Event) -> Event: + """Returns a pruned version of the given event, which removes all keys we + don't know about or think could potentially be dodgy. + """ + +def redact_event_dict(room_version: RoomVersion, event_dict: JsonMapping) -> JsonDict: + """Returns a pruned version of the given event dict, which removes all keys + we don't know about or think could potentially be dodgy. + + Returns the redacted event as a dict.
+ """ diff --git a/tests/events/test_py_protocol.py b/tests/events/test_py_protocol.py index 306e3c1704..cfcc3648ca 100644 --- a/tests/events/test_py_protocol.py +++ b/tests/events/test_py_protocol.py @@ -16,7 +16,7 @@ from unittest.mock import Mock from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.events import EventBase, FrozenEvent, make_event_from_dict +from synapse.events import EventBase from synapse.events.py_protocol import ( EventProtocol, MSC4242Event, @@ -24,20 +24,20 @@ from synapse.events.py_protocol import ( supports_msc4242_state_dag, ) +from tests.test_utils.event_builders import make_test_event from tests.unittest import TestCase def _make_event(room_version: RoomVersion) -> EventBase: """Helper to make an EventBase with the given room version.""" - event_dict = { - "content": {}, - "sender": "@user:example.com", - "type": "m.room.message", - "room_id": "!room:example.com", - } - if room_version.msc4242_state_dags: - event_dict["prev_state_events"] = [] - return make_event_from_dict(event_dict, room_version=room_version) + return make_test_event( + { + "sender": "@user:example.com", + "type": "m.room.message", + "room_id": "!room:example.com", + }, + room_version=room_version, + ) class TestMetaClass(TestCase): @@ -46,16 +46,15 @@ class TestMetaClass(TestCase): NotImplementedError, but that isinstance checks on EventBase and FrozenEvent still work as normal. 
""" - # EventBase and FrozenEvent should work as normal + # EventBase should work as normal self.assertFalse(isinstance(object(), EventBase)) - self.assertFalse(isinstance(object(), FrozenEvent)) - - with self.assertRaises(NotImplementedError): - isinstance(object(), EventProtocol) with self.assertRaises(NotImplementedError): isinstance(object(), MSC4242Event) + with self.assertRaises(NotImplementedError): + isinstance(object(), EventProtocol) + class SupportsMSC4242StateDagTestCase(TestCase): def test_single_event_msc4242(self) -> None: diff --git a/tests/events/test_validator.py b/tests/events/test_validator.py index 3810fdb3da..082ae04a4c 100644 --- a/tests/events/test_validator.py +++ b/tests/events/test_validator.py @@ -19,10 +19,10 @@ from tests.unittest import HomeserverTestCase class EventValidatorTestCase(HomeserverTestCase): - def test_validate_new_with_mentions_succeeds_even_when_frozen(self) -> None: + def test_validate_new_with_mentions_succeed(self) -> None: """ Test that `EventValidator.validate_new` accepts an event with valid `m.mentions` - content even when the event is frozen. + content. 
""" event = make_event_from_dict( { @@ -43,8 +43,5 @@ class EventValidatorTestCase(HomeserverTestCase): }, room_version=RoomVersions.V9, ) - # Sanity check that the event is valid before freezing - EventValidator().validate_new(event, self.hs.config) - event.freeze() - # Event should still be valid after freezing + EventValidator().validate_new(event, self.hs.config) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 9e44b1dc1e..736f251c27 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -36,7 +36,7 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import EventBase, FrozenEventV3 +from synapse.events import EventBase, make_event_from_dict from synapse.federation.federation_client import SendJoinResult from synapse.federation.transport.client import ( StateRequestResponse, @@ -677,7 +677,7 @@ class DeviceUnPartialStateTestCase(unittest.HomeserverTestCase): self.REMOTE1_SERVER_SIGNATURE_KEY, ) - create_event = FrozenEventV3(create_event_dict, room_version, {}, None) + create_event = make_event_from_dict(create_event_dict, room_version) events.append(create_event) room_version = self.hs.config.server.default_room_version @@ -700,7 +700,7 @@ class DeviceUnPartialStateTestCase(unittest.HomeserverTestCase): self.hs.hostname, self.hs.signing_key, ) - join_event = FrozenEventV3(join_event_dict, room_version, {}, None) + join_event = make_event_from_dict(join_event_dict, room_version) events.append(join_event) # Then set the join rules to public @@ -722,7 +722,7 @@ class DeviceUnPartialStateTestCase(unittest.HomeserverTestCase): self.REMOTE1_SERVER_NAME, self.REMOTE1_SERVER_SIGNATURE_KEY, ) - join_rules_event = FrozenEventV3(join_rules_event_dict, room_version, {}, None) + join_rules_event = 
make_event_from_dict(join_rules_event_dict, room_version) events.append(join_rules_event) return {(event.type, event.state_key): event for event in events} @@ -733,7 +733,7 @@ class DeviceUnPartialStateTestCase(unittest.HomeserverTestCase): user: str, signing_key: SigningKey, state: StateMap[EventBase], - ) -> FrozenEventV3: + ) -> EventBase: """Build a join event for the local user, signed by the local server.""" latest_event = max(state.values(), key=lambda e: e.depth) @@ -759,7 +759,7 @@ class DeviceUnPartialStateTestCase(unittest.HomeserverTestCase): get_domain_from_id(user), signing_key, ) - return FrozenEventV3(join_event_dict, room_version, {}, None) + return make_event_from_dict(join_event_dict, room_version) @parameterized.expand([("not_pruned", False), ("pruned", True)]) @patch( diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index b550c2420b..9cc7ebfa80 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -8,7 +8,7 @@ import synapse.rest.client.room from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.errors import Codes, LimitExceededError, SynapseError from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import FrozenEventV3 +from synapse.events import Event from synapse.federation.federation_client import SendJoinResult from synapse.server import HomeServer from synapse.types import UserID, create_requester @@ -124,7 +124,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): create_event_source, self.hs.config.server.default_room_version, ) - create_event = FrozenEventV3( + create_event = Event( create_event_source, self.hs.config.server.default_room_version, {}, @@ -148,7 +148,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): self.hs.hostname, self.hs.signing_key, ) - join_event = FrozenEventV3( + join_event = Event( join_event_source, 
self.hs.config.server.default_room_version, {}, diff --git a/tests/handlers/test_room_policy.py b/tests/handlers/test_room_policy.py index 4f2188b8e7..c67ea9b0e0 100644 --- a/tests/handlers/test_room_policy.py +++ b/tests/handlers/test_room_policy.py @@ -27,7 +27,6 @@ from synapse.handlers.room_policy import POLICY_SERVER_KEY_ID from synapse.rest import admin from synapse.rest.client import filter, login, room, sync from synapse.server import HomeServer -from synapse.synapse_rust.events import Signatures from synapse.types import JsonDict, UserID from synapse.util.clock import Clock @@ -182,7 +181,7 @@ class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): non_policyserver_key = signedjson.key.generate_signing_key( "non_policyserver_key" ) - event.signatures = Signatures( + event.signatures.update( compute_event_signature( event.room_version, event.get_dict(), diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 2ba5da3b95..3114675052 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -841,10 +841,10 @@ class ModuleApiTestCase(BaseModuleApiTestCase): create_event = state[(EventTypes.Create, "")] # `.user_id` is a deprecated alias for `.sender`. 
- self.assertEqual(create_event.user_id, user_id) + self.assertEqual(create_event.user_id, user_id) # type: ignore[attr-defined] # The event supports looking up keys via `__getitem__` although deprecated - self.assertEqual(create_event["room_id"], room_id) + self.assertEqual(create_event["room_id"], room_id) # type: ignore[index] class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase): diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index b7b94482ef..d7e6dfca83 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -272,6 +272,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): "origin_server_ts": self.event_id, "prev_events": prev_events, "auth_events": auth_events, + "hashes": {}, } if key is not None: event_dict["state_key"] = key diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index b4fa71ece8..5eaa6f9fb2 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -236,7 +236,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def check( ev: EventBase, state: StateMap[EventBase] ) -> tuple[bool, JsonDict | None]: - ev.content = {"x": "y"} + # Try and modify the content, this will fail because the event is + # immutable. 
(We therefore need the `type: ignore` comment, as the + # linter will pick this bug up) + ev.content = {"x": "y"} # type: ignore[misc] return True, None self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [ diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8b3b919f44..a9924e28b1 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -97,9 +97,6 @@ class FakeEvent: Args: auth_events: list of event_ids prev_events: list of event_ids - - Returns: - FrozenEvent """ global ORIGIN_SERVER_TS diff --git a/tests/storage/test_msc4242_state_dag.py b/tests/storage/test_msc4242_state_dag.py index 2150bc0996..52a165b9b2 100644 --- a/tests/storage/test_msc4242_state_dag.py +++ b/tests/storage/test_msc4242_state_dag.py @@ -19,13 +19,13 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.api.room_versions import RoomVersions -from synapse.events.py_protocol import MSC4242Event, supports_msc4242_state_dag +from synapse.events import EventBase +from synapse.events.py_protocol import MSC4242Event from synapse.events.snapshot import EventContext from synapse.rest.client import room from synapse.server import HomeServer from synapse.util.clock import Clock -from tests.test_utils.event_builders import make_test_event from tests.unittest import HomeserverTestCase, override_config @@ -154,21 +154,16 @@ class MSC4242EventPersistenceStateDagsStoreTestCase(HomeserverTestCase): prev_state_events: list[str], rejected: bool = False, ) -> tuple[MSC4242Event, EventContext]: - ev = make_test_event( - { - "prev_state_events": prev_state_events, - "content": { - "membership": "join", - }, - "sender": "@unimportant:info", - "state_key": "@unimportant:info", - "type": "m.room.member", - "room_id": self.room_id, - }, - room_version=RoomVersions.MSC4242v12, - ) - ev._event_id = id # type: ignore[attr-defined] - assert 
supports_msc4242_state_dag(ev) + # We use a mock here to allow us to set the `event_id`. + # + # FIXME: Having consistent human-readable event IDs in these tests is + # nice but the `Mock` is less than ideal. It would be better to use a + # real event but that is more complex to set up. + ev = Mock(spec=EventBase) + ev.event_id = id + ev.prev_state_events = prev_state_events + ev.state_key = "@unimportant:info" + ev.is_state.return_value = True ctx = Mock() ctx.rejected = rejected return ev, ctx diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index c346245706..93e0b4a2b1 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -26,7 +26,7 @@ from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.events import EventBase +from synapse.events import EventBase, make_event_from_dict from synapse.events.builder import EventBuilder from synapse.server import HomeServer from synapse.synapse_rust.events import EventInternalMetadata @@ -238,11 +238,16 @@ class RedactionTestCase(unittest.HomeserverTestCase): prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids ) - built_event._event_id = self._event_id # type: ignore[attr-defined] - built_event._dict["event_id"] = self._event_id - assert built_event.event_id == self._event_id + event_dict = built_event.get_dict() + event_dict["event_id"] = self._event_id + rebuilt_event = make_event_from_dict( + event_dict, + room_version=built_event.room_version, + internal_metadata_dict=built_event.internal_metadata.get_dict(), + ) + assert rebuilt_event.event_id == self._event_id - return built_event + return rebuilt_event @property def room_id(self) -> str: diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index d51fa1f8ba..ba3b3802cf 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -35,7 
+35,7 @@ from synapse.api.constants import ( ) from synapse.api.filtering import Filter from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import FrozenEventV3 +from synapse.events import Event from synapse.federation.federation_client import SendJoinResult from synapse.rest import admin from synapse.rest.client import login, room @@ -1385,7 +1385,7 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase( create_event_source, self.hs.config.server.default_room_version, ) - create_event = FrozenEventV3( + create_event = Event( create_event_source, self.hs.config.server.default_room_version, {}, @@ -1408,7 +1408,7 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase( creator_join_event_source, self.hs.config.server.default_room_version, ) - creator_join_event = FrozenEventV3( + creator_join_event = Event( creator_join_event_source, self.hs.config.server.default_room_version, {}, @@ -1433,7 +1433,7 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase( self.hs.hostname, self.hs.signing_key, ) - join_event = FrozenEventV3( + join_event = Event( join_event_source, self.hs.config.server.default_room_version, {}, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 4537186ee6..eeba22728e 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -93,8 +93,8 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V9, creator, "public", + rejected_reason="stinky", ) - rejected_join_rules.rejected_reason = "stinky" auth_events.append(rejected_join_rules) event_store.add_event(rejected_join_rules) @@ -1180,7 +1180,10 @@ def _random_state_event( def _join_rules_event( - room_version: RoomVersion, sender: str, join_rule: str + room_version: RoomVersion, + sender: str, + join_rule: str, + rejected_reason: str | None = None, ) -> EventBase: return make_test_event( { @@ -1194,6 +1197,7 @@ def _join_rules_event( }, }, room_version=room_version, + 
rejected_reason=rejected_reason, ) diff --git a/tests/test_utils/event_builders.py b/tests/test_utils/event_builders.py index a8eb586c1f..a5d686801d 100644 --- a/tests/test_utils/event_builders.py +++ b/tests/test_utils/event_builders.py @@ -76,6 +76,20 @@ def make_test_event( **(event_dict or {}), **fields, } + + # For room versions where the create event's room_id is derived from its + # event ID (v11+ format), omit the default room_id on create events so each + # create event ends up with a distinct room_id. + # + # We can't do this in the `default_event_fields` as we don't know the event + # type at that point. + if ( + room_version.msc4291_room_ids_as_hashes + and merged["type"] == "m.room.create" + and merged["state_key"] == "" + ): + merged.pop("room_id", None) + return make_event_from_dict( merged, room_version=room_version,