diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index b1916bcf99..be14fd72f9 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -41,7 +41,7 @@ use serde::{Deserialize, Serialize}; use crate::{ events::{internal_metadata::EventInternalMetadata, utils::calculate_event_id}, - identifier::EventID, + identifier::{EventID, IdentifierError}, room_versions::{EventFormatVersions, RoomVersion}, }; @@ -105,23 +105,22 @@ impl JsonObject { } fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { - Ok(PyIterator::from_object( + PyIterator::from_object( &self .object .keys() .map(|k| &**k) .collect::>() .into_pyobject(py)?, - )?) + ) } fn keys<'py>(&self, py: Python<'py>) -> PyResult> { - Ok(self - .object + self.object .keys() .map(|k| &**k) .collect::>() - .into_pyobject(py)?) + .into_pyobject(py) } fn values<'py>(&self, py: Python<'py>) -> PyResult> { @@ -130,7 +129,7 @@ impl JsonObject { .values() .map(|v| pythonize(py, v)) .collect::, _>>()?; - Ok(values.into_pyobject(py)?) + values.into_pyobject(py) } fn items<'py>(&self, py: Python<'py>) -> PyResult> { @@ -139,7 +138,7 @@ impl JsonObject { .iter() .map(|(k, v)| PyResult::Ok((k.as_ref(), pythonize(py, v)?))) .collect::>>()?; - Ok(items.into_pyobject(py)?) + items.into_pyobject(py) } #[pyo3(signature = (key, default=None))] @@ -218,21 +217,20 @@ impl JsonObjectMutable { fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { let obj = self.object.read().unwrap(); - Ok(PyIterator::from_object( + PyIterator::from_object( &obj.keys() .map(|k| &**k) .collect::>() .into_pyobject(py)?, - )?) + ) } fn keys<'py>(&self, py: Python<'py>) -> PyResult> { let obj = self.object.read().unwrap(); - Ok(obj - .keys() + obj.keys() .map(|k| &**k) .collect::>() - .into_pyobject(py)?) + .into_pyobject(py) } fn values<'py>(&self, py: Python<'py>) -> PyResult> { @@ -241,7 +239,7 @@ impl JsonObjectMutable { .values() .map(|v| pythonize(py, v)) .collect::, _>>()?; - Ok(values.into_pyobject(py)?) + values.into_pyobject(py) } fn items<'py>(&self, py: Python<'py>) -> PyResult> { @@ -250,7 +248,7 @@ impl JsonObjectMutable { .iter() .map(|(k, v)| PyResult::Ok((k.as_ref(), pythonize(py, v)?))) .collect::>>()?; - Ok(items.into_pyobject(py)?) + items.into_pyobject(py) } #[pyo3(signature = (key, default=None))] @@ -632,6 +630,10 @@ impl Event { let rejected_reason = rejected_reason.map(String::into_boxed_str); let event_format_enum = match room_version.event_format { + EventFormatVersions::ROOM_V1_V2 => { + let event_format = depythonize(event_dict)?; + EventFormatEnum::V1(event_format) + } EventFormatVersions::ROOM_V3 | EventFormatVersions::ROOM_V4_PLUS => { let event_format = depythonize(event_dict)?; EventFormatEnum::V2V3(event_format) @@ -651,15 +653,19 @@ impl Event { let internal_metadata = Py::new(py, EventInternalMetadata::new(internal_metadata_dict)?)?; - let event_id = { - if room_version.event_format == EventFormatVersions::ROOM_V1_V2 { - // Read the event ID From the event - todo!() - } else { + let event_id = match &event_format_enum { + EventFormatEnum::V1(format) => { + // V1/V2 events have the event_id in the event dict. + let id_str: &str = &format.specific_fields.event_id; + id_str.try_into().map_err(|err: IdentifierError| { + PyValueError::new_err(format!("Invalid event_id: {}", err)) + })? + } + _ => { let event_value = serde_json::to_value(&event_format_enum).map_err(|err| { PyException::new_err(format!("Failed to serialize event: {}", err)) })?; - calculate_event_id(&event_value, &room_version).map_err(|err| { + calculate_event_id(&event_value, room_version).map_err(|err| { PyException::new_err(format!("Failed to calculate event_id: {}", err)) })? } @@ -676,9 +682,9 @@ impl Event { fn get_dict<'py>(&self, py: Python<'py>) -> PyResult> { match &self.inner { + EventFormatEnum::V1(format) => Ok(pythonize(py, format)?), EventFormatEnum::V2V3(format) => Ok(pythonize(py, format)?), EventFormatEnum::V4(format) => Ok(pythonize(py, format)?), - // ... } } @@ -732,95 +738,58 @@ impl Event { #[getter] fn room_id(&self) -> PyResult> { match &self.inner { + EventFormatEnum::V1(format) => Ok(format.specific_fields.room_id.as_ref().into()), EventFormatEnum::V2V3(format) => Ok(format.specific_fields.room_id.as_ref().into()), EventFormatEnum::V4(format) => Ok(format.room_id(&self.event_id)?), } } #[getter] - fn signatures(&self) -> PyResult { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(format.common_fields.signatures.clone()), - EventFormatEnum::V4(format) => Ok(format.common_fields.signatures.clone()), - // ... - } + fn signatures(&self) -> Signatures { + self.inner.common_fields().signatures.clone() } #[getter] - fn content(&self) -> PyResult { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(format.common_fields.content.clone()), - EventFormatEnum::V4(format) => Ok(format.common_fields.content.clone()), - // ... - } + fn content(&self) -> JsonObject { + self.inner.common_fields().content.clone() } #[getter] - fn depth(&self) -> PyResult { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(format.common_fields.depth), - EventFormatEnum::V4(format) => Ok(format.common_fields.depth), - // ... - } + fn depth(&self) -> i64 { + self.inner.common_fields().depth } #[getter] - fn hashes(&self) -> PyResult<&HashMap> { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(&format.common_fields.hashes), - EventFormatEnum::V4(format) => Ok(&format.common_fields.hashes), - // ... - } + fn hashes(&self) -> &HashMap { + &self.inner.common_fields().hashes } #[getter] - fn origin_server_ts(&self) -> PyResult { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(format.common_fields.origin_server_ts), - EventFormatEnum::V4(format) => Ok(format.common_fields.origin_server_ts), - // ... - } + fn origin_server_ts(&self) -> i64 { + self.inner.common_fields().origin_server_ts } #[getter] - fn sender(&self) -> PyResult<&str> { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(&format.common_fields.sender), - EventFormatEnum::V4(format) => Ok(&format.common_fields.sender), - // ... - } + fn sender(&self) -> &str { + &self.inner.common_fields().sender } #[getter(state_key)] fn state_key_attr(&self) -> PyResult<&str> { - let state_key = match &self.inner { - EventFormatEnum::V2V3(format) => &format.common_fields.state_key, - EventFormatEnum::V4(format) => &format.common_fields.state_key, - // ... - }; - - let Some(state_key) = state_key.as_deref() else { + let Some(state_key) = self.inner.common_fields().state_key.as_deref() else { return Err(PyAttributeError::new_err("state_key")); }; Ok(state_key) } #[getter] - fn r#type(&self) -> PyResult<&str> { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(&format.common_fields.type_), - EventFormatEnum::V4(format) => Ok(&format.common_fields.type_), - // ... - } + fn r#type(&self) -> &str { + &self.inner.common_fields().type_ } #[getter] - fn unsigned(&self) -> PyResult { - match &self.inner { - EventFormatEnum::V2V3(format) => Ok(format.common_fields.unsigned.clone()), - EventFormatEnum::V4(format) => Ok(format.common_fields.unsigned.clone()), - // ... - } + fn unsigned(&self) -> JsonObjectMutable { + self.inner.common_fields().unsigned.clone() } #[getter] @@ -838,53 +807,40 @@ impl Event { self.room_version } - fn prev_event_ids(&self) -> PyResult> { + fn prev_event_ids(&self) -> Vec { match &self.inner { - EventFormatEnum::V2V3(format) => Ok(format.specific_fields.prev_events.clone()), - EventFormatEnum::V4(format) => Ok(format.specific_fields.prev_events.clone()), - // ... + EventFormatEnum::V1(format) => format.prev_event_ids(), + EventFormatEnum::V2V3(format) => format.specific_fields.prev_events.clone(), + EventFormatEnum::V4(format) => format.specific_fields.prev_events.clone(), } } fn auth_event_ids(&self) -> PyResult> { match &self.inner { + EventFormatEnum::V1(format) => Ok(format.auth_event_ids()), EventFormatEnum::V2V3(format) => Ok(format.auth_event_ids()), EventFormatEnum::V4(format) => Ok(format.auth_event_ids()?), - // ... } } #[getter] fn membership<'py>(&self, py: Python<'py>) -> PyResult> { - let content = self.content()?; + let content = self.content(); content.__getitem__(py, "membership") } #[getter] fn redacts<'py>(&self, py: Python<'py>) -> PyResult>> { + let common = self.inner.common_fields(); let value = if !self.room_version.updated_redaction_rules { - let other_fields = match &self.inner { - EventFormatEnum::V2V3(format) => &format.common_fields.other_fields, - EventFormatEnum::V4(format) => &format.common_fields.other_fields, - // ... - }; - - let Some(value) = other_fields.get("redacts") else { + let Some(value) = common.other_fields.get("redacts") else { return Ok(None); }; - value } else { - let content = match &self.inner { - EventFormatEnum::V2V3(format) => &format.common_fields.content, - EventFormatEnum::V4(format) => &format.common_fields.content, - // ... - }; - - let Some(value) = content.object.get("redacts") else { + let Some(value) = common.content.object.get("redacts") else { return Ok(None); }; - value }; @@ -892,19 +848,11 @@ impl Event { } fn is_state(&self) -> bool { - match &self.inner { - EventFormatEnum::V2V3(format) => format.common_fields.state_key.is_some(), - EventFormatEnum::V4(format) => format.common_fields.state_key.is_some(), - // ... - } + self.inner.common_fields().state_key.is_some() } fn get_state_key(&self) -> Option<&str> { - match &self.inner { - EventFormatEnum::V2V3(format) => format.common_fields.state_key.as_deref(), - EventFormatEnum::V4(format) => format.common_fields.state_key.as_deref(), - // ... - } + self.inner.common_fields().state_key.as_deref() } fn __contains__<'py>(&self, py: Python<'py>, key: &str) -> PyResult { @@ -948,9 +896,53 @@ impl Event { #[derive(Serialize)] #[serde(untagged)] enum EventFormatEnum { + V1(EventFormatV1Container), V2V3(EventFormatV2V3Container), V4(EventFormatV4Container), - // ... +} + +impl EventFormatEnum { + fn common_fields(&self) -> &EventCommonFields { + match self { + EventFormatEnum::V1(f) => &f.common_fields, + EventFormatEnum::V2V3(f) => &f.common_fields, + EventFormatEnum::V4(f) => &f.common_fields, + } + } +} + +#[derive(Serialize, Deserialize)] +struct EventFormatV1 { + auth_events: Vec<(String, HashMap)>, + prev_events: Vec<(String, HashMap)>, + room_id: Box, + event_id: Box, +} + +#[derive(Serialize, Deserialize)] +struct EventFormatV1Container { + #[serde(flatten)] + specific_fields: EventFormatV1, + #[serde(flatten)] + common_fields: EventCommonFields, +} + +impl EventFormatV1Container { + fn auth_event_ids(&self) -> Vec { + self.specific_fields + .auth_events + .iter() + .map(|(id, _)| id.clone()) + .collect() + } + + fn prev_event_ids(&self) -> Vec { + self.specific_fields + .prev_events + .iter() + .map(|(id, _)| id.clone()) + .collect() + } } #[derive(Serialize, Deserialize)] @@ -1122,4 +1114,26 @@ mod tests { "!BeXKh925K_M46DwsuJFR0EyBpE1P7CFUDGuWW4xw55Y" ); } + + #[test] + fn test_basic_v1_roundtrip() { + let json = r#"{"auth_events":[["$auth1:localhost",{"sha256":"abc"}],["$auth2:localhost",{"sha256":"def"}]],"prev_events":[["$prev1:localhost",{"sha256":"ghi"}]],"type":"m.room.message","sender":"@user:localhost","content":{"body":"hello","msgtype":"m.text"},"depth":5,"room_id":"!room:localhost","event_id":"$event1:localhost","origin_server_ts":1234567890,"hashes":{"sha256":"base64hash"},"signatures":{"localhost":{"ed25519:key":"sig"}},"unsigned":{}}"#; + let event_value: serde_json::Value = serde_json::from_str(json).unwrap(); + + let event: EventFormatV1Container = serde_json::from_str(json).unwrap(); + let parsed_value = serde_json::to_value(&event).unwrap(); + + assert_eq!(&*event.common_fields.type_, "m.room.message"); + assert_eq!(&*event.specific_fields.room_id, "!room:localhost"); + assert_eq!(&*event.specific_fields.event_id, "$event1:localhost"); + + // Check auth/prev event extraction + let auth_ids = event.auth_event_ids(); + assert_eq!(auth_ids, vec!["$auth1:localhost", "$auth2:localhost"]); + + let prev_ids = event.prev_event_ids(); + assert_eq!(prev_ids, vec!["$prev1:localhost"]); + + assert_eq!(event_value, parsed_value); + } } diff --git a/synapse/__init__.py b/synapse/__init__.py index 4918fca099..9bfc52a868 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -109,4 +109,5 @@ from synapse.synapse_rust.events import ( # noqa: E402 Mapping.register(JsonObject) Mapping.register(Signatures) Mapping.register(DomainSignatures) +Mapping.register(JsonObjectMutable) MutableMapping.register(JsonObjectMutable) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 9b47c20437..e5699dd5e9 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -38,6 +38,7 @@ from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.events import EventBase, relation_from_event +from synapse.synapse_rust.events import Event from synapse.types import JsonDict, JsonMapping, RoomID, UserID if TYPE_CHECKING: @@ -410,7 +411,7 @@ class Filter: # Check if the event has a relation. rel_type = None - if isinstance(event, EventBase): + if isinstance(event, EventBase) or isinstance(event, Event): relation = relation_from_event(event) if relation: rel_type = relation.rel_type @@ -493,7 +494,11 @@ class Filter: self, events: Collection[FilterEvent] ) -> list[FilterEvent]: # The event IDs to check, mypy doesn't understand the isinstance check. - event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] + event_ids = [ + event.event_id + for event in events + if isinstance(event, EventBase) or isinstance(event, Event) + ] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( event_ids, self.related_by_senders, self.related_by_rel_types @@ -503,7 +508,8 @@ class Filter: return [ event for event in events - if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep + if not (isinstance(event, EventBase) or isinstance(event, Event)) + or event.event_id in event_ids_to_keep ] async def filter(self, events: Iterable[FilterEvent]) -> list[FilterEvent]: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index a1ef3121c7..27413d3d56 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -586,7 +586,7 @@ def _event_type_from_format_version( """ if format_version == EventFormatVersions.ROOM_V1_V2: - return FrozenEvent + return Event elif format_version == EventFormatVersions.ROOM_V3: return Event elif format_version == EventFormatVersions.ROOM_V4_PLUS: diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 12ef42866d..b980c9de7f 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -49,18 +49,24 @@ if TYPE_CHECKING: def MockEvent(**kwargs: Any) -> EventBase: - if "event_id" not in kwargs: - kwargs["event_id"] = "fake_event_id" - if "type" not in kwargs: - kwargs["type"] = "fake_type" - if "content" not in kwargs: - kwargs["content"] = {} - # Move internal metadata out so we can call make_event properly internal_metadata = kwargs.get("internal_metadata") if internal_metadata is not None: kwargs.pop("internal_metadata") + kwargs.setdefault("event_id", "$fake_event_id") + kwargs.setdefault("type", "fake_type") + kwargs.setdefault("auth_events", []) + kwargs.setdefault("prev_events", []) + kwargs.setdefault("content", {}) + kwargs.setdefault("hashes", {}) + kwargs.setdefault("signatures", {}) + kwargs.setdefault("unsigned", {}) + kwargs.setdefault("sender", "@fake_sender:domain") + kwargs.setdefault("room_id", "!fake_room_id") + kwargs.setdefault("depth", 0) + kwargs.setdefault("origin_server_ts", 0) + return make_event_from_dict(kwargs, internal_metadata_dict=internal_metadata)