diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index 62e79771b0..91aff38e4a 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -23,7 +23,6 @@ use std::{ borrow::Cow, collections::HashMap, - str::FromStr, sync::{Arc, RwLock}, }; @@ -632,33 +631,34 @@ impl Event { let rejected_reason = rejected_reason.map(String::into_boxed_str); - // Check we're the right room version - if ![ - EventFormatVersions::ROOM_V3, - EventFormatVersions::ROOM_V4_PLUS, - ] - .contains(&room_version.event_format) - { - return Err(PyValueError::new_err(format!( - "Unsupported room version: {}", - room_version - ))); - } - - let event_format_v3_v4_plus: EventFormatV3V4Container = depythonize(event_dict)?; - - event_format_v3_v4_plus.validate(room_version)?; + let event_format_enum = match room_version.event_format { + EventFormatVersions::ROOM_V3 => { + let event_format = depythonize(event_dict)?; + EventFormatEnum::V3(event_format) + } + EventFormatVersions::ROOM_V4_PLUS => { + let event_format: EventFormatV4Container = depythonize(event_dict)?; + event_format.validate(room_version)?; + EventFormatEnum::V4(event_format) + } + _ => { + return Err(PyValueError::new_err(format!( + "Unsupported room version: {}", + room_version + ))) + } + }; let internal_metadata = Py::new(py, EventInternalMetadata::new(internal_metadata_dict)?)?; - let event_value = serde_json::to_value(&event_format_v3_v4_plus) + let event_value = serde_json::to_value(&event_format_enum) .map_err(|err| PyException::new_err(format!("Failed to serialize event: {}", err)))?; let event_id = calculate_event_id(&event_value, &room_version).map_err(|err| { PyException::new_err(format!("Failed to calculate event_id: {}", err)) })?; Ok(Self { - inner: EventFormatEnum::V3V4Plus(event_format_v3_v4_plus), + inner: event_format_enum, event_id, room_version, rejected_reason, @@ -668,7 +668,8 @@ impl Event { fn get_dict<'py>(&self, py: Python<'py>) -> PyResult> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(pythonize(py, format)?), + EventFormatEnum::V3(format) => Ok(pythonize(py, format)?), + EventFormatEnum::V4(format) => Ok(pythonize(py, format)?), // ... } } @@ -681,7 +682,8 @@ impl Event { ) -> PyResult> { // TODO: We need to do a bunch of changes here. match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(pythonize(py, format)?), + EventFormatEnum::V3(format) => Ok(pythonize(py, format)?), + EventFormatEnum::V4(format) => Ok(pythonize(py, format)?), // ... } } @@ -702,14 +704,16 @@ impl Event { #[getter] fn room_id(&self) -> PyResult> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format.room_id(&self.event_id)?), + EventFormatEnum::V3(format) => Ok(format.specific_fields.room_id.as_ref().into()), + EventFormatEnum::V4(format) => Ok(format.room_id(&self.event_id)?), } } #[getter] fn signatures(&self) -> PyResult { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.signatures.clone()), + EventFormatEnum::V3(format) => Ok(format.common_fields.signatures.clone()), + EventFormatEnum::V4(format) => Ok(format.common_fields.signatures.clone()), // ... } } @@ -717,7 +721,8 @@ impl Event { #[getter] fn content(&self) -> PyResult { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.content.clone()), + EventFormatEnum::V3(format) => Ok(format.common_fields.content.clone()), + EventFormatEnum::V4(format) => Ok(format.common_fields.content.clone()), // ... } } @@ -725,7 +730,8 @@ impl Event { #[getter] fn depth(&self) -> PyResult { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.depth), + EventFormatEnum::V3(format) => Ok(format.common_fields.depth), + EventFormatEnum::V4(format) => Ok(format.common_fields.depth), // ... } } @@ -733,7 +739,8 @@ impl Event { #[getter] fn hashes(&self) -> PyResult<&HashMap> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(&format.common_fields.hashes), + EventFormatEnum::V3(format) => Ok(&format.common_fields.hashes), + EventFormatEnum::V4(format) => Ok(&format.common_fields.hashes), // ... } } @@ -741,7 +748,8 @@ impl Event { #[getter] fn origin_server_ts(&self) -> PyResult { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.origin_server_ts), + EventFormatEnum::V3(format) => Ok(format.common_fields.origin_server_ts), + EventFormatEnum::V4(format) => Ok(format.common_fields.origin_server_ts), // ... } } @@ -749,7 +757,8 @@ impl Event { #[getter] fn sender(&self) -> PyResult<&str> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(&format.common_fields.sender), + EventFormatEnum::V3(format) => Ok(&format.common_fields.sender), + EventFormatEnum::V4(format) => Ok(&format.common_fields.sender), // ... } } @@ -757,7 +766,8 @@ impl Event { #[getter(state_key)] fn state_key_attr(&self) -> PyResult<&str> { let state_key = match &self.inner { - EventFormatEnum::V3V4Plus(format) => &format.common_fields.state_key, + EventFormatEnum::V3(format) => &format.common_fields.state_key, + EventFormatEnum::V4(format) => &format.common_fields.state_key, // ... }; @@ -770,7 +780,8 @@ impl Event { #[getter] fn r#type(&self) -> PyResult<&str> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(&format.common_fields.type_), + EventFormatEnum::V3(format) => Ok(&format.common_fields.type_), + EventFormatEnum::V4(format) => Ok(&format.common_fields.type_), // ... } } @@ -778,7 +789,8 @@ impl Event { #[getter] fn unsigned(&self) -> PyResult { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.unsigned.clone()), + EventFormatEnum::V3(format) => Ok(format.common_fields.unsigned.clone()), + EventFormatEnum::V4(format) => Ok(format.common_fields.unsigned.clone()), // ... } } @@ -800,24 +812,16 @@ impl Event { fn prev_event_ids(&self) -> PyResult> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format - .specific_fields - .prev_events - .iter() - .map(|s| s.to_string()) - .collect()), + EventFormatEnum::V3(format) => Ok(format.specific_fields.prev_events.clone()), + EventFormatEnum::V4(format) => Ok(format.specific_fields.prev_events.clone()), // ... } } fn auth_event_ids(&self) -> PyResult> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => Ok(format - .specific_fields - .auth_events - .iter() - .map(|s| s.to_string()) - .collect()), + EventFormatEnum::V3(format) => Ok(format.auth_event_ids()), + EventFormatEnum::V4(format) => Ok(format.auth_event_ids()), // ... } } @@ -832,7 +836,8 @@ impl Event { fn redacts<'py>(&self, py: Python<'py>) -> PyResult>> { let value = if !self.room_version.updated_redaction_rules { let other_fields = match &self.inner { - EventFormatEnum::V3V4Plus(format) => &format.common_fields.other_fields, + EventFormatEnum::V3(format) => &format.common_fields.other_fields, + EventFormatEnum::V4(format) => &format.common_fields.other_fields, // ... }; @@ -843,7 +848,8 @@ impl Event { value } else { let content = match &self.inner { - EventFormatEnum::V3V4Plus(format) => &format.common_fields.content, + EventFormatEnum::V3(format) => &format.common_fields.content, + EventFormatEnum::V4(format) => &format.common_fields.content, // ... }; @@ -859,14 +865,16 @@ impl Event { fn is_state(&self) -> bool { match &self.inner { - EventFormatEnum::V3V4Plus(format) => format.common_fields.state_key.is_some(), + EventFormatEnum::V3(format) => format.common_fields.state_key.is_some(), + EventFormatEnum::V4(format) => format.common_fields.state_key.is_some(), // ... } } fn get_state_key(&self) -> Option<&str> { match &self.inner { - EventFormatEnum::V3V4Plus(format) => format.common_fields.state_key.as_deref(), + EventFormatEnum::V3(format) => format.common_fields.state_key.as_deref(), + EventFormatEnum::V4(format) => format.common_fields.state_key.as_deref(), // ... } } @@ -892,11 +900,8 @@ impl Event { } #[getter] - fn format_version(&self) -> u8 { - match &self.inner { - EventFormatEnum::V3V4Plus(_) => 3, - // ... - } + fn format_version(&self) -> i32 { + self.room_version.event_format } fn items<'py>(&self, py: Python<'py>) -> PyResult> { @@ -912,27 +917,67 @@ impl Event { } } +#[derive(Serialize)] +#[serde(untagged)] enum EventFormatEnum { - V3V4Plus(EventFormatV3V4Container), + V3(EventFormatV3Container), + V4(EventFormatV4Container), // ... } #[derive(Serialize, Deserialize)] -struct EventFormatV3V4Plus { +struct EventFormatV3 { + auth_events: Vec, + prev_events: Vec, + room_id: Box, +} + +#[derive(Serialize, Deserialize)] +struct EventFormatV3Container { + #[serde(flatten)] + specific_fields: EventFormatV3, + #[serde(flatten)] + common_fields: EventCommonFields, +} + +impl EventFormatV3Container { + fn auth_event_ids(&self) -> Vec { + self.specific_fields.auth_events.clone() + } +} + +#[derive(Serialize, Deserialize)] +struct EventFormatV4 { auth_events: Vec, prev_events: Vec, room_id: Option>, } #[derive(Serialize, Deserialize)] -struct EventFormatV3V4Container { +struct EventFormatV4Container { #[serde(flatten)] - specific_fields: EventFormatV3V4Plus, + specific_fields: EventFormatV4, #[serde(flatten)] common_fields: EventCommonFields, } -impl EventFormatV3V4Container { +impl EventFormatV4Container { + fn validate(&self, room_version: &RoomVersion) -> Result<(), Error> { + if self.specific_fields.room_id.is_none() { + // Only create events in event formats v4 plus can have a missing room_id. + if room_version.event_format != EventFormatVersions::ROOM_V4_PLUS { + bail!("room_id is required for event formats v3 and below"); + } + if &*self.common_fields.type_ != "m.room.create" + && self.common_fields.state_key.as_deref() != Some("") + { + bail!("room_id is required for non-create events"); + } + } + + Ok(()) + } + fn room_id(&self, event_id: &EventID) -> Result, Error> { if let Some(room_id) = self.specific_fields.room_id.as_deref() { return Ok(room_id.into()); @@ -953,26 +998,14 @@ impl EventFormatV3V4Container { Ok(room_id.into()) } - fn validate(&self, room_version: &RoomVersion) -> Result<(), Error> { - if self.specific_fields.room_id.is_none() { - // Only create events in event formats v4 plus can have a missing room_id. - if room_version.event_format != EventFormatVersions::ROOM_V4_PLUS { - bail!("room_id is required for event formats v3 and below"); - } - if &*self.common_fields.type_ != "m.room.create" - && self.common_fields.state_key.as_deref() != Some("") - { - bail!("room_id is required for non-create events"); - } - } - - Ok(()) + fn auth_event_ids(&self) -> Vec { + // TODO: Add create event + self.specific_fields.auth_events.clone() } } #[cfg(test)] mod tests { - use signed_json::json; use super::*; @@ -981,14 +1014,14 @@ mod tests { let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@anon-20260225_142731-20:localhost:8800","content":{"room_version":"10","creator":"@anon-20260225_142731-20:localhost:8800"},"depth":1,"room_id":"!qVoJSympOqdUQRUfiC:localhost:8800","state_key":"","origin_server_ts":1772029657149,"hashes":{"sha256":"RIDkn4CrExGMOfRZlHl//1weAro5QC/q2D76YcyAUqk"},"signatures":{"localhost:8800":{"ed25519:a_GMSl":"GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg"}},"unsigned":{"age_ts":1772029657149}}"#; let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); - let event: EventFormatV3V4Container = serde_json::from_str(json).unwrap(); + let event: EventFormatV3Container = serde_json::from_str(json).unwrap(); let parsed_value = serde_json::to_value(&event).unwrap(); assert_eq!(&*event.common_fields.type_, "m.room.create"); assert_eq!( - event.specific_fields.room_id.as_deref(), - Some("!qVoJSympOqdUQRUfiC:localhost:8800") + &*event.specific_fields.room_id, + "!qVoJSympOqdUQRUfiC:localhost:8800" ); assert_eq!(event_value, parsed_value); @@ -1023,7 +1056,7 @@ mod tests { let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@erikj:jki.re","content":{"room_version":"12","predecessor":{"room_id":"!VuNGkDTdbMOOxSmuDa:jki.re"}},"depth":1,"state_key":"","origin_server_ts":1775568141481,"hashes":{"sha256":"qBX+glsKvogXFrvsEN0eh13pO2kpuE6o/b4yREPtOqw"},"signatures":{"jki.re":{"ed25519:auto":"n/4gHQRagk3r1r24L/7a+oaMMf9cysVfQRYdjpDZcf4ppkVym33rhTW18Vy4zMa1L5nsWLkxsBvbrRRDYUOhBQ"}},"unsigned":{"age_ts":1775568141481}}"#; let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); - let event: EventFormatV3V4Container = serde_json::from_str(json).unwrap(); + let event: EventFormatV4Container = serde_json::from_str(json).unwrap(); let event_id = calculate_event_id(&event_value, &RoomVersion::V12).unwrap(); diff --git a/rust/src/room_versions.rs b/rust/src/room_versions.rs index c3803d477c..b4cff0a571 100644 --- a/rust/src/room_versions.rs +++ b/rust/src/room_versions.rs @@ -99,14 +99,14 @@ impl PushRuleRoomFlag { pub struct RoomVersion { /// The identifier for this version. pub identifier: &'static str, - /// One of the RoomDisposition constants. + /// One of the [`RoomDisposition`] constants. pub disposition: &'static str, - /// One of the EventFormatVersions constants. + /// One of the [`EventFormatVersions`] constants. pub event_format: i32, - /// One of the StateResolutionVersions constants. + /// One of the [`StateResolutionVersions`] constants. pub state_res: i32, pub enforce_key_validity: bool, - /// Before MSC2432, m.room.aliases had special auth rules and redaction rules. + /// Before MSC2432, `m.room.aliases` had special auth rules and redaction rules. pub special_case_aliases_auth: bool, /// Strictly enforce canonicaljson, do not allow: /// * Integers outside the range of [-2^53 + 1, 2^53 - 1]