This commit is contained in:
Erik Johnston
2026-04-08 16:35:27 +01:00
parent 05ca7a877e
commit 6d2e0885cf
2 changed files with 113 additions and 80 deletions
+109 -76
View File
@@ -23,7 +23,6 @@
use std::{
borrow::Cow,
collections::HashMap,
str::FromStr,
sync::{Arc, RwLock},
};
@@ -632,33 +631,34 @@ impl Event {
let rejected_reason = rejected_reason.map(String::into_boxed_str);
// Check we're the right room version
if ![
EventFormatVersions::ROOM_V3,
EventFormatVersions::ROOM_V4_PLUS,
]
.contains(&room_version.event_format)
{
return Err(PyValueError::new_err(format!(
"Unsupported room version: {}",
room_version
)));
}
let event_format_v3_v4_plus: EventFormatV3V4Container = depythonize(event_dict)?;
event_format_v3_v4_plus.validate(room_version)?;
let event_format_enum = match room_version.event_format {
EventFormatVersions::ROOM_V3 => {
let event_format = depythonize(event_dict)?;
EventFormatEnum::V3(event_format)
}
EventFormatVersions::ROOM_V4_PLUS => {
let event_format: EventFormatV4Container = depythonize(event_dict)?;
event_format.validate(room_version)?;
EventFormatEnum::V4(event_format)
}
_ => {
return Err(PyValueError::new_err(format!(
"Unsupported room version: {}",
room_version
)))
}
};
let internal_metadata = Py::new(py, EventInternalMetadata::new(internal_metadata_dict)?)?;
let event_value = serde_json::to_value(&event_format_v3_v4_plus)
let event_value = serde_json::to_value(&event_format_enum)
.map_err(|err| PyException::new_err(format!("Failed to serialize event: {}", err)))?;
let event_id = calculate_event_id(&event_value, &room_version).map_err(|err| {
PyException::new_err(format!("Failed to calculate event_id: {}", err))
})?;
Ok(Self {
inner: EventFormatEnum::V3V4Plus(event_format_v3_v4_plus),
inner: event_format_enum,
event_id,
room_version,
rejected_reason,
@@ -668,7 +668,8 @@ impl Event {
fn get_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(pythonize(py, format)?),
EventFormatEnum::V3(format) => Ok(pythonize(py, format)?),
EventFormatEnum::V4(format) => Ok(pythonize(py, format)?),
// ...
}
}
@@ -681,7 +682,8 @@ impl Event {
) -> PyResult<Bound<'py, PyAny>> {
// TODO: We need to do a bunch of changes here.
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(pythonize(py, format)?),
EventFormatEnum::V3(format) => Ok(pythonize(py, format)?),
EventFormatEnum::V4(format) => Ok(pythonize(py, format)?),
// ...
}
}
@@ -702,14 +704,16 @@ impl Event {
#[getter]
fn room_id(&self) -> PyResult<Cow<'_, str>> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format.room_id(&self.event_id)?),
EventFormatEnum::V3(format) => Ok(format.specific_fields.room_id.as_ref().into()),
EventFormatEnum::V4(format) => Ok(format.room_id(&self.event_id)?),
}
}
#[getter]
fn signatures(&self) -> PyResult<Signatures> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.signatures.clone()),
EventFormatEnum::V3(format) => Ok(format.common_fields.signatures.clone()),
EventFormatEnum::V4(format) => Ok(format.common_fields.signatures.clone()),
// ...
}
}
@@ -717,7 +721,8 @@ impl Event {
#[getter]
fn content(&self) -> PyResult<JsonObject> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.content.clone()),
EventFormatEnum::V3(format) => Ok(format.common_fields.content.clone()),
EventFormatEnum::V4(format) => Ok(format.common_fields.content.clone()),
// ...
}
}
@@ -725,7 +730,8 @@ impl Event {
#[getter]
fn depth(&self) -> PyResult<i64> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.depth),
EventFormatEnum::V3(format) => Ok(format.common_fields.depth),
EventFormatEnum::V4(format) => Ok(format.common_fields.depth),
// ...
}
}
@@ -733,7 +739,8 @@ impl Event {
#[getter]
fn hashes(&self) -> PyResult<&HashMap<String, String>> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(&format.common_fields.hashes),
EventFormatEnum::V3(format) => Ok(&format.common_fields.hashes),
EventFormatEnum::V4(format) => Ok(&format.common_fields.hashes),
// ...
}
}
@@ -741,7 +748,8 @@ impl Event {
#[getter]
fn origin_server_ts(&self) -> PyResult<i64> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.origin_server_ts),
EventFormatEnum::V3(format) => Ok(format.common_fields.origin_server_ts),
EventFormatEnum::V4(format) => Ok(format.common_fields.origin_server_ts),
// ...
}
}
@@ -749,7 +757,8 @@ impl Event {
#[getter]
fn sender(&self) -> PyResult<&str> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(&format.common_fields.sender),
EventFormatEnum::V3(format) => Ok(&format.common_fields.sender),
EventFormatEnum::V4(format) => Ok(&format.common_fields.sender),
// ...
}
}
@@ -757,7 +766,8 @@ impl Event {
#[getter(state_key)]
fn state_key_attr(&self) -> PyResult<&str> {
let state_key = match &self.inner {
EventFormatEnum::V3V4Plus(format) => &format.common_fields.state_key,
EventFormatEnum::V3(format) => &format.common_fields.state_key,
EventFormatEnum::V4(format) => &format.common_fields.state_key,
// ...
};
@@ -770,7 +780,8 @@ impl Event {
#[getter]
fn r#type(&self) -> PyResult<&str> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(&format.common_fields.type_),
EventFormatEnum::V3(format) => Ok(&format.common_fields.type_),
EventFormatEnum::V4(format) => Ok(&format.common_fields.type_),
// ...
}
}
@@ -778,7 +789,8 @@ impl Event {
#[getter]
fn unsigned(&self) -> PyResult<JsonObjectMutable> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format.common_fields.unsigned.clone()),
EventFormatEnum::V3(format) => Ok(format.common_fields.unsigned.clone()),
EventFormatEnum::V4(format) => Ok(format.common_fields.unsigned.clone()),
// ...
}
}
@@ -800,24 +812,16 @@ impl Event {
fn prev_event_ids(&self) -> PyResult<Vec<String>> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format
.specific_fields
.prev_events
.iter()
.map(|s| s.to_string())
.collect()),
EventFormatEnum::V3(format) => Ok(format.specific_fields.prev_events.clone()),
EventFormatEnum::V4(format) => Ok(format.specific_fields.prev_events.clone()),
// ...
}
}
fn auth_event_ids(&self) -> PyResult<Vec<String>> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => Ok(format
.specific_fields
.auth_events
.iter()
.map(|s| s.to_string())
.collect()),
EventFormatEnum::V3(format) => Ok(format.auth_event_ids()),
EventFormatEnum::V4(format) => Ok(format.auth_event_ids()),
// ...
}
}
@@ -832,7 +836,8 @@ impl Event {
fn redacts<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
let value = if !self.room_version.updated_redaction_rules {
let other_fields = match &self.inner {
EventFormatEnum::V3V4Plus(format) => &format.common_fields.other_fields,
EventFormatEnum::V3(format) => &format.common_fields.other_fields,
EventFormatEnum::V4(format) => &format.common_fields.other_fields,
// ...
};
@@ -843,7 +848,8 @@ impl Event {
value
} else {
let content = match &self.inner {
EventFormatEnum::V3V4Plus(format) => &format.common_fields.content,
EventFormatEnum::V3(format) => &format.common_fields.content,
EventFormatEnum::V4(format) => &format.common_fields.content,
// ...
};
@@ -859,14 +865,16 @@ impl Event {
fn is_state(&self) -> bool {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => format.common_fields.state_key.is_some(),
EventFormatEnum::V3(format) => format.common_fields.state_key.is_some(),
EventFormatEnum::V4(format) => format.common_fields.state_key.is_some(),
// ...
}
}
fn get_state_key(&self) -> Option<&str> {
match &self.inner {
EventFormatEnum::V3V4Plus(format) => format.common_fields.state_key.as_deref(),
EventFormatEnum::V3(format) => format.common_fields.state_key.as_deref(),
EventFormatEnum::V4(format) => format.common_fields.state_key.as_deref(),
// ...
}
}
@@ -892,11 +900,8 @@ impl Event {
}
#[getter]
fn format_version(&self) -> u8 {
match &self.inner {
EventFormatEnum::V3V4Plus(_) => 3,
// ...
}
fn format_version(&self) -> i32 {
self.room_version.event_format
}
fn items<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
@@ -912,27 +917,67 @@ impl Event {
}
}
#[derive(Serialize)]
#[serde(untagged)]
enum EventFormatEnum {
V3V4Plus(EventFormatV3V4Container),
V3(EventFormatV3Container),
V4(EventFormatV4Container),
// ...
}
#[derive(Serialize, Deserialize)]
struct EventFormatV3V4Plus {
struct EventFormatV3 {
auth_events: Vec<String>,
prev_events: Vec<String>,
room_id: Box<str>,
}
#[derive(Serialize, Deserialize)]
struct EventFormatV3Container {
#[serde(flatten)]
specific_fields: EventFormatV3,
#[serde(flatten)]
common_fields: EventCommonFields,
}
impl EventFormatV3Container {
fn auth_event_ids(&self) -> Vec<String> {
self.specific_fields.auth_events.clone()
}
}
#[derive(Serialize, Deserialize)]
struct EventFormatV4 {
auth_events: Vec<String>,
prev_events: Vec<String>,
room_id: Option<Box<str>>,
}
#[derive(Serialize, Deserialize)]
struct EventFormatV3V4Container {
struct EventFormatV4Container {
#[serde(flatten)]
specific_fields: EventFormatV3V4Plus,
specific_fields: EventFormatV4,
#[serde(flatten)]
common_fields: EventCommonFields,
}
impl EventFormatV3V4Container {
impl EventFormatV4Container {
fn validate(&self, room_version: &RoomVersion) -> Result<(), Error> {
if self.specific_fields.room_id.is_none() {
// Only create events in event formats v4 plus can have a missing room_id.
if room_version.event_format != EventFormatVersions::ROOM_V4_PLUS {
bail!("room_id is required for event formats v3 and below");
}
if &*self.common_fields.type_ != "m.room.create"
&& self.common_fields.state_key.as_deref() != Some("")
{
bail!("room_id is required for non-create events");
}
}
Ok(())
}
fn room_id(&self, event_id: &EventID) -> Result<Cow<'_, str>, Error> {
if let Some(room_id) = self.specific_fields.room_id.as_deref() {
return Ok(room_id.into());
@@ -953,26 +998,14 @@ impl EventFormatV3V4Container {
Ok(room_id.into())
}
fn validate(&self, room_version: &RoomVersion) -> Result<(), Error> {
if self.specific_fields.room_id.is_none() {
// Only create events in event formats v4 plus can have a missing room_id.
if room_version.event_format != EventFormatVersions::ROOM_V4_PLUS {
bail!("room_id is required for event formats v3 and below");
}
if &*self.common_fields.type_ != "m.room.create"
&& self.common_fields.state_key.as_deref() != Some("")
{
bail!("room_id is required for non-create events");
}
}
Ok(())
fn auth_event_ids(&self) -> Vec<String> {
// TODO: Add create event
self.specific_fields.auth_events.clone()
}
}
#[cfg(test)]
mod tests {
use signed_json::json;
use super::*;
@@ -981,14 +1014,14 @@ mod tests {
let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@anon-20260225_142731-20:localhost:8800","content":{"room_version":"10","creator":"@anon-20260225_142731-20:localhost:8800"},"depth":1,"room_id":"!qVoJSympOqdUQRUfiC:localhost:8800","state_key":"","origin_server_ts":1772029657149,"hashes":{"sha256":"RIDkn4CrExGMOfRZlHl//1weAro5QC/q2D76YcyAUqk"},"signatures":{"localhost:8800":{"ed25519:a_GMSl":"GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg"}},"unsigned":{"age_ts":1772029657149}}"#;
let event_value: serde_json::Value = serde_json::from_str(json).unwrap();
let event: EventFormatV3V4Container = serde_json::from_str(json).unwrap();
let event: EventFormatV3Container = serde_json::from_str(json).unwrap();
let parsed_value = serde_json::to_value(&event).unwrap();
assert_eq!(&*event.common_fields.type_, "m.room.create");
assert_eq!(
event.specific_fields.room_id.as_deref(),
Some("!qVoJSympOqdUQRUfiC:localhost:8800")
&*event.specific_fields.room_id,
"!qVoJSympOqdUQRUfiC:localhost:8800"
);
assert_eq!(event_value, parsed_value);
@@ -1023,7 +1056,7 @@ mod tests {
let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@erikj:jki.re","content":{"room_version":"12","predecessor":{"room_id":"!VuNGkDTdbMOOxSmuDa:jki.re"}},"depth":1,"state_key":"","origin_server_ts":1775568141481,"hashes":{"sha256":"qBX+glsKvogXFrvsEN0eh13pO2kpuE6o/b4yREPtOqw"},"signatures":{"jki.re":{"ed25519:auto":"n/4gHQRagk3r1r24L/7a+oaMMf9cysVfQRYdjpDZcf4ppkVym33rhTW18Vy4zMa1L5nsWLkxsBvbrRRDYUOhBQ"}},"unsigned":{"age_ts":1775568141481}}"#;
let event_value: serde_json::Value = serde_json::from_str(json).unwrap();
let event: EventFormatV3V4Container = serde_json::from_str(json).unwrap();
let event: EventFormatV4Container = serde_json::from_str(json).unwrap();
let event_id = calculate_event_id(&event_value, &RoomVersion::V12).unwrap();
+4 -4
View File
@@ -99,14 +99,14 @@ impl PushRuleRoomFlag {
pub struct RoomVersion {
/// The identifier for this version.
pub identifier: &'static str,
/// One of the RoomDisposition constants.
/// One of the [`RoomDisposition`] constants.
pub disposition: &'static str,
/// One of the EventFormatVersions constants.
/// One of the [`EventFormatVersions`] constants.
pub event_format: i32,
/// One of the StateResolutionVersions constants.
/// One of the [`StateResolutionVersions`] constants.
pub state_res: i32,
pub enforce_key_validity: bool,
/// Before MSC2432, m.room.aliases had special auth rules and redaction rules.
/// Before MSC2432, `m.room.aliases` had special auth rules and redaction rules.
pub special_case_aliases_auth: bool,
/// Strictly enforce canonicaljson, do not allow:
/// * Integers outside the range of [-2^53 + 1, 2^53 - 1]