This commit is contained in:
Erik Johnston
2026-04-09 10:22:25 +01:00
parent 3fc8480135
commit 61b71dffa0
5 changed files with 148 additions and 121 deletions

View File

@@ -41,7 +41,7 @@ use serde::{Deserialize, Serialize};
use crate::{
events::{internal_metadata::EventInternalMetadata, utils::calculate_event_id},
identifier::EventID,
identifier::{EventID, IdentifierError},
room_versions::{EventFormatVersions, RoomVersion},
};
@@ -105,23 +105,22 @@ impl JsonObject {
}
fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
Ok(PyIterator::from_object(
PyIterator::from_object(
&self
.object
.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?,
)?)
)
}
fn keys<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
Ok(self
.object
self.object
.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?)
.into_pyobject(py)
}
fn values<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
@@ -130,7 +129,7 @@ impl JsonObject {
.values()
.map(|v| pythonize(py, v))
.collect::<Result<Vec<_>, _>>()?;
Ok(values.into_pyobject(py)?)
values.into_pyobject(py)
}
fn items<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
@@ -139,7 +138,7 @@ impl JsonObject {
.iter()
.map(|(k, v)| PyResult::Ok((k.as_ref(), pythonize(py, v)?)))
.collect::<PyResult<Vec<_>>>()?;
Ok(items.into_pyobject(py)?)
items.into_pyobject(py)
}
#[pyo3(signature = (key, default=None))]
@@ -218,21 +217,20 @@ impl JsonObjectMutable {
fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
let obj = self.object.read().unwrap();
Ok(PyIterator::from_object(
PyIterator::from_object(
&obj.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?,
)?)
)
}
fn keys<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let obj = self.object.read().unwrap();
Ok(obj
.keys()
obj.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?)
.into_pyobject(py)
}
fn values<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
@@ -241,7 +239,7 @@ impl JsonObjectMutable {
.values()
.map(|v| pythonize(py, v))
.collect::<Result<Vec<_>, _>>()?;
Ok(values.into_pyobject(py)?)
values.into_pyobject(py)
}
fn items<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
@@ -250,7 +248,7 @@ impl JsonObjectMutable {
.iter()
.map(|(k, v)| PyResult::Ok((k.as_ref(), pythonize(py, v)?)))
.collect::<PyResult<Vec<_>>>()?;
Ok(items.into_pyobject(py)?)
items.into_pyobject(py)
}
#[pyo3(signature = (key, default=None))]
@@ -632,6 +630,10 @@ impl Event {
let rejected_reason = rejected_reason.map(String::into_boxed_str);
let event_format_enum = match room_version.event_format {
EventFormatVersions::ROOM_V1_V2 => {
let event_format = depythonize(event_dict)?;
EventFormatEnum::V1(event_format)
}
EventFormatVersions::ROOM_V3 | EventFormatVersions::ROOM_V4_PLUS => {
let event_format = depythonize(event_dict)?;
EventFormatEnum::V2V3(event_format)
@@ -651,15 +653,19 @@ impl Event {
let internal_metadata = Py::new(py, EventInternalMetadata::new(internal_metadata_dict)?)?;
let event_id = {
if room_version.event_format == EventFormatVersions::ROOM_V1_V2 {
// Read the event ID From the event
todo!()
} else {
let event_id = match &event_format_enum {
EventFormatEnum::V1(format) => {
// V1/V2 events have the event_id in the event dict.
let id_str: &str = &format.specific_fields.event_id;
id_str.try_into().map_err(|err: IdentifierError| {
PyValueError::new_err(format!("Invalid event_id: {}", err))
})?
}
_ => {
let event_value = serde_json::to_value(&event_format_enum).map_err(|err| {
PyException::new_err(format!("Failed to serialize event: {}", err))
})?;
calculate_event_id(&event_value, &room_version).map_err(|err| {
calculate_event_id(&event_value, room_version).map_err(|err| {
PyException::new_err(format!("Failed to calculate event_id: {}", err))
})?
}
@@ -676,9 +682,9 @@ impl Event {
fn get_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
match &self.inner {
EventFormatEnum::V1(format) => Ok(pythonize(py, format)?),
EventFormatEnum::V2V3(format) => Ok(pythonize(py, format)?),
EventFormatEnum::V4(format) => Ok(pythonize(py, format)?),
// ...
}
}
@@ -732,95 +738,58 @@ impl Event {
#[getter]
fn room_id(&self) -> PyResult<Cow<'_, str>> {
match &self.inner {
EventFormatEnum::V1(format) => Ok(format.specific_fields.room_id.as_ref().into()),
EventFormatEnum::V2V3(format) => Ok(format.specific_fields.room_id.as_ref().into()),
EventFormatEnum::V4(format) => Ok(format.room_id(&self.event_id)?),
}
}
#[getter]
fn signatures(&self) -> PyResult<Signatures> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(format.common_fields.signatures.clone()),
EventFormatEnum::V4(format) => Ok(format.common_fields.signatures.clone()),
// ...
}
fn signatures(&self) -> Signatures {
self.inner.common_fields().signatures.clone()
}
#[getter]
fn content(&self) -> PyResult<JsonObject> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(format.common_fields.content.clone()),
EventFormatEnum::V4(format) => Ok(format.common_fields.content.clone()),
// ...
}
fn content(&self) -> JsonObject {
self.inner.common_fields().content.clone()
}
#[getter]
fn depth(&self) -> PyResult<i64> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(format.common_fields.depth),
EventFormatEnum::V4(format) => Ok(format.common_fields.depth),
// ...
}
fn depth(&self) -> i64 {
self.inner.common_fields().depth
}
#[getter]
fn hashes(&self) -> PyResult<&HashMap<String, String>> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(&format.common_fields.hashes),
EventFormatEnum::V4(format) => Ok(&format.common_fields.hashes),
// ...
}
fn hashes(&self) -> &HashMap<String, String> {
&self.inner.common_fields().hashes
}
#[getter]
fn origin_server_ts(&self) -> PyResult<i64> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(format.common_fields.origin_server_ts),
EventFormatEnum::V4(format) => Ok(format.common_fields.origin_server_ts),
// ...
}
fn origin_server_ts(&self) -> i64 {
self.inner.common_fields().origin_server_ts
}
#[getter]
fn sender(&self) -> PyResult<&str> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(&format.common_fields.sender),
EventFormatEnum::V4(format) => Ok(&format.common_fields.sender),
// ...
}
fn sender(&self) -> &str {
&self.inner.common_fields().sender
}
#[getter(state_key)]
fn state_key_attr(&self) -> PyResult<&str> {
let state_key = match &self.inner {
EventFormatEnum::V2V3(format) => &format.common_fields.state_key,
EventFormatEnum::V4(format) => &format.common_fields.state_key,
// ...
};
let Some(state_key) = state_key.as_deref() else {
let Some(state_key) = self.inner.common_fields().state_key.as_deref() else {
return Err(PyAttributeError::new_err("state_key"));
};
Ok(state_key)
}
#[getter]
fn r#type(&self) -> PyResult<&str> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(&format.common_fields.type_),
EventFormatEnum::V4(format) => Ok(&format.common_fields.type_),
// ...
}
fn r#type(&self) -> &str {
&self.inner.common_fields().type_
}
#[getter]
fn unsigned(&self) -> PyResult<JsonObjectMutable> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(format.common_fields.unsigned.clone()),
EventFormatEnum::V4(format) => Ok(format.common_fields.unsigned.clone()),
// ...
}
fn unsigned(&self) -> JsonObjectMutable {
self.inner.common_fields().unsigned.clone()
}
#[getter]
@@ -838,53 +807,40 @@ impl Event {
self.room_version
}
fn prev_event_ids(&self) -> PyResult<Vec<String>> {
fn prev_event_ids(&self) -> Vec<String> {
match &self.inner {
EventFormatEnum::V2V3(format) => Ok(format.specific_fields.prev_events.clone()),
EventFormatEnum::V4(format) => Ok(format.specific_fields.prev_events.clone()),
// ...
EventFormatEnum::V1(format) => format.prev_event_ids(),
EventFormatEnum::V2V3(format) => format.specific_fields.prev_events.clone(),
EventFormatEnum::V4(format) => format.specific_fields.prev_events.clone(),
}
}
fn auth_event_ids(&self) -> PyResult<Vec<String>> {
match &self.inner {
EventFormatEnum::V1(format) => Ok(format.auth_event_ids()),
EventFormatEnum::V2V3(format) => Ok(format.auth_event_ids()),
EventFormatEnum::V4(format) => Ok(format.auth_event_ids()?),
// ...
}
}
#[getter]
fn membership<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let content = self.content()?;
let content = self.content();
content.__getitem__(py, "membership")
}
#[getter]
fn redacts<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
let common = self.inner.common_fields();
let value = if !self.room_version.updated_redaction_rules {
let other_fields = match &self.inner {
EventFormatEnum::V2V3(format) => &format.common_fields.other_fields,
EventFormatEnum::V4(format) => &format.common_fields.other_fields,
// ...
};
let Some(value) = other_fields.get("redacts") else {
let Some(value) = common.other_fields.get("redacts") else {
return Ok(None);
};
value
} else {
let content = match &self.inner {
EventFormatEnum::V2V3(format) => &format.common_fields.content,
EventFormatEnum::V4(format) => &format.common_fields.content,
// ...
};
let Some(value) = content.object.get("redacts") else {
let Some(value) = common.content.object.get("redacts") else {
return Ok(None);
};
value
};
@@ -892,19 +848,11 @@ impl Event {
}
fn is_state(&self) -> bool {
match &self.inner {
EventFormatEnum::V2V3(format) => format.common_fields.state_key.is_some(),
EventFormatEnum::V4(format) => format.common_fields.state_key.is_some(),
// ...
}
self.inner.common_fields().state_key.is_some()
}
fn get_state_key(&self) -> Option<&str> {
match &self.inner {
EventFormatEnum::V2V3(format) => format.common_fields.state_key.as_deref(),
EventFormatEnum::V4(format) => format.common_fields.state_key.as_deref(),
// ...
}
self.inner.common_fields().state_key.as_deref()
}
fn __contains__<'py>(&self, py: Python<'py>, key: &str) -> PyResult<bool> {
@@ -948,9 +896,53 @@ impl Event {
#[derive(Serialize)]
#[serde(untagged)]
enum EventFormatEnum {
V1(EventFormatV1Container),
V2V3(EventFormatV2V3Container),
V4(EventFormatV4Container),
// ...
}
impl EventFormatEnum {
fn common_fields(&self) -> &EventCommonFields {
match self {
EventFormatEnum::V1(f) => &f.common_fields,
EventFormatEnum::V2V3(f) => &f.common_fields,
EventFormatEnum::V4(f) => &f.common_fields,
}
}
}
#[derive(Serialize, Deserialize)]
struct EventFormatV1 {
auth_events: Vec<(String, HashMap<String, String>)>,
prev_events: Vec<(String, HashMap<String, String>)>,
room_id: Box<str>,
event_id: Box<str>,
}
#[derive(Serialize, Deserialize)]
struct EventFormatV1Container {
#[serde(flatten)]
specific_fields: EventFormatV1,
#[serde(flatten)]
common_fields: EventCommonFields,
}
impl EventFormatV1Container {
fn auth_event_ids(&self) -> Vec<String> {
self.specific_fields
.auth_events
.iter()
.map(|(id, _)| id.clone())
.collect()
}
fn prev_event_ids(&self) -> Vec<String> {
self.specific_fields
.prev_events
.iter()
.map(|(id, _)| id.clone())
.collect()
}
}
#[derive(Serialize, Deserialize)]
@@ -1122,4 +1114,26 @@ mod tests {
"!BeXKh925K_M46DwsuJFR0EyBpE1P7CFUDGuWW4xw55Y"
);
}
#[test]
fn test_basic_v1_roundtrip() {
let json = r#"{"auth_events":[["$auth1:localhost",{"sha256":"abc"}],["$auth2:localhost",{"sha256":"def"}]],"prev_events":[["$prev1:localhost",{"sha256":"ghi"}]],"type":"m.room.message","sender":"@user:localhost","content":{"body":"hello","msgtype":"m.text"},"depth":5,"room_id":"!room:localhost","event_id":"$event1:localhost","origin_server_ts":1234567890,"hashes":{"sha256":"base64hash"},"signatures":{"localhost":{"ed25519:key":"sig"}},"unsigned":{}}"#;
let event_value: serde_json::Value = serde_json::from_str(json).unwrap();
let event: EventFormatV1Container = serde_json::from_str(json).unwrap();
let parsed_value = serde_json::to_value(&event).unwrap();
assert_eq!(&*event.common_fields.type_, "m.room.message");
assert_eq!(&*event.specific_fields.room_id, "!room:localhost");
assert_eq!(&*event.specific_fields.event_id, "$event1:localhost");
// Check auth/prev event extraction
let auth_ids = event.auth_event_ids();
assert_eq!(auth_ids, vec!["$auth1:localhost", "$auth2:localhost"]);
let prev_ids = event.prev_event_ids();
assert_eq!(prev_ids, vec!["$prev1:localhost"]);
assert_eq!(event_value, parsed_value);
}
}

View File

@@ -109,4 +109,5 @@ from synapse.synapse_rust.events import ( # noqa: E402
Mapping.register(JsonObject)
Mapping.register(Signatures)
Mapping.register(DomainSignatures)
Mapping.register(JsonObjectMutable)
MutableMapping.register(JsonObjectMutable)

View File

@@ -38,6 +38,7 @@ from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase, relation_from_event
from synapse.synapse_rust.events import Event
from synapse.types import JsonDict, JsonMapping, RoomID, UserID
if TYPE_CHECKING:
@@ -410,7 +411,7 @@ class Filter:
# Check if the event has a relation.
rel_type = None
if isinstance(event, EventBase):
if isinstance(event, EventBase) or isinstance(event, Event):
relation = relation_from_event(event)
if relation:
rel_type = relation.rel_type
@@ -493,7 +494,11 @@ class Filter:
self, events: Collection[FilterEvent]
) -> list[FilterEvent]:
# The event IDs to check, mypy doesn't understand the isinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids = [
event.event_id
for event in events
if isinstance(event, EventBase) or isinstance(event, Event)
] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
event_ids, self.related_by_senders, self.related_by_rel_types
@@ -503,7 +508,8 @@ class Filter:
return [
event
for event in events
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
if not (isinstance(event, EventBase) or isinstance(event, Event))
or event.event_id in event_ids_to_keep
]
async def filter(self, events: Iterable[FilterEvent]) -> list[FilterEvent]:

View File

@@ -586,7 +586,7 @@ def _event_type_from_format_version(
"""
if format_version == EventFormatVersions.ROOM_V1_V2:
return FrozenEvent
return Event
elif format_version == EventFormatVersions.ROOM_V3:
return Event
elif format_version == EventFormatVersions.ROOM_V4_PLUS:

View File

@@ -49,18 +49,24 @@ if TYPE_CHECKING:
def MockEvent(**kwargs: Any) -> EventBase:
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
if "content" not in kwargs:
kwargs["content"] = {}
# Move internal metadata out so we can call make_event properly
internal_metadata = kwargs.get("internal_metadata")
if internal_metadata is not None:
kwargs.pop("internal_metadata")
kwargs.setdefault("event_id", "$fake_event_id")
kwargs.setdefault("type", "fake_type")
kwargs.setdefault("auth_events", [])
kwargs.setdefault("prev_events", [])
kwargs.setdefault("content", {})
kwargs.setdefault("hashes", {})
kwargs.setdefault("signatures", {})
kwargs.setdefault("unsigned", {})
kwargs.setdefault("sender", "@fake_sender:domain")
kwargs.setdefault("room_id", "!fake_room_id")
kwargs.setdefault("depth", 0)
kwargs.setdefault("origin_server_ts", 0)
return make_event_from_dict(kwargs, internal_metadata_dict=internal_metadata)