From c430c16df47229f8ecef6783739accc042fcafbe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 8 May 2026 14:19:03 +0100 Subject: [PATCH] Port event content to Rust (#19725) Based on #19708. This is on the path to porting the entire event class to Rust, as `event.content` will then return the new Rust class `JsonObject`. This PR adds a pure Rust `JsonObject` class that is a `Mapping` representing a json-style object. It uses `serde_json::Value` as its in-memory representation and `pythonize` for conversion when a field is looked up on the object. I'm not thrilled with the name, but couldn't think of a better one. This also adds `JsonObject` handling to the JSON serialisation functions we use, as well as to the `freeze(..)` function. Reviewable commit-by-commit. --- changelog.d/19725.misc | 1 + rust/src/events/json_object.rs | 488 +++++++++++++++++++++++ rust/src/events/mod.rs | 12 +- synapse/__init__.py | 9 +- synapse/events/__init__.py | 53 +-- synapse/events/utils.py | 2 +- synapse/events/validator.py | 4 +- synapse/handlers/room.py | 9 +- synapse/handlers/stats.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/synapse_rust/events.pyi | 11 +- synapse/util/events.py | 4 +- synapse/util/frozenutils.py | 5 + synapse/util/json.py | 30 +- tests/crypto/test_event_signing.py | 7 +- tests/module_api/test_api.py | 4 +- tests/rest/client/test_rooms.py | 4 +- tests/server.py | 3 +- tests/synapse_rust/test_json_object.py | 149 +++++++ tests/test_state.py | 5 +- tests/test_utils/__init__.py | 4 +- 21 files changed, 749 insertions(+), 63 deletions(-) create mode 100644 changelog.d/19725.misc create mode 100644 rust/src/events/json_object.rs create mode 100644 tests/synapse_rust/test_json_object.py diff --git a/changelog.d/19725.misc b/changelog.d/19725.misc new file mode 100644 index 0000000000..b320f42b9c --- /dev/null +++ b/changelog.d/19725.misc @@ -0,0 +1 @@ +Port `Event.content` field to Rust. 
diff --git a/rust/src/events/json_object.rs b/rust/src/events/json_object.rs new file mode 100644 index 0000000000..2c4be1c87b --- /dev/null +++ b/rust/src/events/json_object.rs @@ -0,0 +1,488 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +use std::{collections::BTreeMap, sync::Arc}; + +use pyo3::{ + exceptions::{PyKeyError, PyTypeError}, + pyclass, pymethods, + types::{ + PyAnyMethods, PyIterator, PyList, PyListMethods, PyMapping, PySet, PySetMethods, PyTuple, + }, + Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyResult, Python, +}; +use pythonize::{depythonize, pythonize}; +use serde::{Deserialize, Serialize}; + +/// A generic class for representing immutable JSON objects. +/// +/// This is used for representing the `content` field of an event. +/// +/// The basic architecture here is to optimize for two things: +/// 1. Fast access of top-level keys (e.g. `event.content["key"]`) +/// 2. Pure Rust implementation. +#[derive(Serialize, Deserialize, Clone, Default)] +#[pyclass(mapping, frozen)] +#[serde(transparent)] +pub struct JsonObject { + object: Arc, serde_json::Value>>, +} + +#[pymethods] +impl JsonObject { + #[new] + #[pyo3(signature = (content = None))] + fn new<'a, 'py>(content: Option<&'a Bound<'py, PyAny>>) -> PyResult { + let Some(content) = content else { + // If no content is provided, default to an empty object. 
+ return Ok(Self::default()); + }; + + if let Ok(content) = content.cast::() { + // If the content is already a JsonObject, we can just clone the + // underlying map (this is safe as the object is immutable). + return Ok(JsonObject { + object: content.get().object.clone(), + }); + } + + let Ok(content) = content.cast::() else { + return Err(PyTypeError::new_err("'content' must be a mapping")); + }; + + // Use pythonize to try and convert from a mapping. + let content = depythonize(content)?; + Ok(Self { + object: Arc::new(content), + }) + } + + fn __len__(&self) -> usize { + self.object.len() + } + + fn __contains__(&self, key: &Bound<'_, PyAny>) -> bool { + // Match dict semantics: a non-string key is simply "not in" the + // mapping, rather than raising TypeError. + let Ok(key_str) = key.extract::<&str>() else { + return false; + }; + self.object.contains_key(key_str) + } + + fn __getitem__<'py>( + &self, + py: Python<'py>, + key: Bound<'_, PyAny>, + ) -> PyResult> { + // We only ever store string keys, so any non-string lookup is a miss. + // Raise KeyError (not TypeError) to match dict's behaviour. + let Ok(key_str) = key.extract::<&str>() else { + return Err(PyKeyError::new_err(key.unbind())); + }; + let Some(value) = self.object.get(key_str) else { + return Err(PyKeyError::new_err(key.unbind())); + }; + Ok(pythonize(py, value)?) + } + + fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { + // The easiest way to get an iterator over the keys is to create a + // temporary list and call `iter()` on it. This is not the most + // efficient approach, but is much less boilerplate than implementing a + // custom iterator type. Since the keys are typically small in number + // this should be fine in practice. + let list = PyList::new(py, self.object.keys().map(Box::as_ref))?; + PyIterator::from_object(&list) + } + + // The view classes below each hold a `JsonObject` clone. 
This is cheap + // because the underlying map is behind an `Arc`, and lets the view outlive + // the originating object (matching dict_keys/values/items semantics in + // Python, which also keep the dict alive). + + fn keys(&self) -> JsonObjectKeysView { + JsonObjectKeysView { + object: self.clone(), + } + } + + fn values(&self) -> JsonObjectValuesView { + JsonObjectValuesView { + object: self.clone(), + } + } + + fn items(&self) -> JsonObjectItemsView { + JsonObjectItemsView { + object: self.clone(), + } + } + + #[pyo3(signature = (key, default=None))] + fn get<'py>( + &self, + py: Python<'py>, + key: Bound<'_, PyAny>, + default: Option>, + ) -> PyResult> { + // Non-string keys can never match, so treat them as a miss and return + // the caller-supplied default rather than raising. + let Ok(key_str) = key.extract::<&str>() else { + return Ok(default.into_pyobject(py)?); + }; + match self.object.get(key_str) { + Some(value) => Ok(pythonize(py, value)?), + None => Ok(default.into_pyobject(py)?), + } + } + + fn __eq__(&self, other: Bound<'_, PyAny>) -> bool { + // We support equality against any Python mapping (e.g. plain dicts), + // so callers can swap a JsonObject in without rewriting comparisons. + let Ok(mapping) = other.cast::() else { + return false; + }; + + let Ok(other_len) = mapping.len() else { + return false; + }; + + if other_len != self.object.len() { + return false; + } + + // We know the "other" is a mapping with the same number of fields as + // us. So we can convert it into a JsonObject and compare the underlying + // maps. + let Ok(other_dict) = depythonize(&other) else { + return false; + }; + + *self.object == other_dict + } + + // Since we implement comparisons with other types, we need to disable + // hashing to avoid violating the invariant that equal objects must have the + // same hash. 
+ // + // Alternatively, we could only allow comparisons with other JsonObjects and + // allow hashing, but a) it's nicer to be able to compare with arbitrary + // mappings and b) we don't really need hashing for these objects. + #[classattr] + const __hash__: Option> = None; + + fn __str__(&self) -> String { + serde_json::to_string(&self.object).expect("Value should be serializable") + } + + fn __repr__(&self) -> String { + format!("JsonObject({})", self.__str__()) + } +} + +/// Helper class returned by `JsonObject.keys()` to act as a view into the keys +/// of the object. +/// +/// This needs to both be iterable *and* operate like a set. +#[pyclass(frozen)] +#[derive(Clone)] +pub struct JsonObjectKeysView { + object: JsonObject, +} + +#[pymethods] +impl JsonObjectKeysView { + fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { + // Create the iterator by making a temporary python list of the keys and + // calling `iter()` on it. + let list = PyList::new(py, self.object.object.keys().map(Box::as_ref))?; + PyIterator::from_object(&list) + } + + fn __len__(&self) -> usize { + self.object.__len__() + } + + fn __contains__(&self, key: &Bound<'_, PyAny>) -> bool { + self.object.__contains__(key) + } + + fn __eq__(&self, other: Bound<'_, PyAny>) -> bool { + let other_len = match other.len() { + Ok(len) => len, + Err(_) => return false, + }; + + if self.object.__len__() != other_len { + return false; + } + + for key in self.object.object.keys() { + if !matches!(other.contains(key.as_ref()), Ok(true)) { + return false; + } + } + + true + } + + // The set operators below match the behaviour of `dict.keys()` in Python: + // they accept any object that supports `__contains__` (for `&`) or is + // iterable (for `|`, `-`, `^`), not just sets. Each returns a fresh + // `PySet` so the caller gets a normal mutable Python set back. + // + // The `__r*__` variants are reflected operators, called by Python when + // the left-hand operand doesn't know how to combine with us. 
Since these + // operations are commutative for sets (or symmetric in the case of `^`), + // they just delegate. The asymmetric ops (`-`) need a separate impl. + + fn __and__<'py>( + &self, + py: Python<'py>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + // Iterate our (typically small) key set and probe `other`, which may + // be any container supporting `__contains__`. + let mut result = Vec::new(); + + for key in self.object.object.keys() { + if matches!(other.contains(key.as_ref()), Ok(true)) { + result.push(key.as_ref()); + } + } + + PySet::new(py, &result) + } + + fn __rand__<'py>( + &self, + py: Python<'py>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + self.__and__(py, other) + } + + fn __or__<'py>(&self, py: Python<'py>, other: Bound<'_, PyAny>) -> PyResult> { + // Union needs to enumerate both sides, so the right operand must be + // iterable (a bare `__contains__` is not enough). + let Ok(other_iter) = other.try_iter() else { + return Err(PyTypeError::new_err("Right operand must be iterable")); + }; + + let result = PySet::new(py, self.object.object.keys().map(Box::as_ref))?; + + // PySet handles dedup, so we can blindly add every element from the + // other iterable. + for item in other_iter { + let item = item?; + result.add(item)?; + } + + Ok(result) + } + + fn __ror__<'py>( + &self, + py: Python<'py>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + self.__or__(py, other) + } + + fn __sub__<'py>( + &self, + py: Python<'py>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + // `self - other`: keep our keys that are not in `other`. Only `other` + // needs to support `__contains__` here. 
+ let mut result = Vec::new(); + + for key in self.object.object.keys() { + if matches!(other.contains(key.as_ref()), Ok(true)) { + continue; + } + result.push(key.as_ref()); + } + + PySet::new(py, &result) + } + + fn __rsub__<'py>( + &self, + py: Python<'py>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + // `other - self`: we need to enumerate `other`, so it must be + // iterable. Not symmetric with `__sub__`, hence a separate impl. + let Ok(other_iter) = other.try_iter() else { + return Err(PyTypeError::new_err("Left operand must be iterable")); + }; + + let result = PySet::empty(py)?; + + for item in other_iter { + let item = item?; + if self.object.__contains__(&item) { + continue; + } + result.add(item)?; + } + + Ok(result) + } + + fn __xor__<'py>( + &self, + py: Python<'py>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + // Symmetric difference: elements in exactly one side. Implemented as + // two filtered passes — one over our keys, one over `other`. + let Ok(other_iter) = other.try_iter() else { + return Err(PyTypeError::new_err("Right operand must be iterable")); + }; + + let result = PySet::empty(py)?; + + for key in self.object.object.keys() { + if matches!(other.contains(key.as_ref()), Ok(true)) { + continue; + } + result.add(key.as_ref())?; + } + + for item in other_iter { + let item = item?; + if self.object.__contains__(&item) { + continue; + } + result.add(item)?; + } + + Ok(result) + } + + fn __rxor__<'py>( + &self, + py: Python<'py>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + self.__xor__(py, other) + } + + fn isdisjoint(&self, other: Bound<'_, PyAny>) -> bool { + for key in self.object.object.keys() { + if matches!(other.contains(key.as_ref()), Ok(true)) { + return false; + } + } + + true + } +} + +/// Helper class returned by `JsonObject.values()` to act as a view into the +/// values of the object. 
+#[pyclass(frozen)] +#[derive(Clone)] +pub struct JsonObjectValuesView { + object: JsonObject, +} + +#[pymethods] +impl JsonObjectValuesView { + fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { + // Create the iterator by making a temporary python list of the values and + // calling `iter()` on it. + let list = PyList::empty(py); + for v in self.object.object.values() { + let py_value = pythonize(py, v)?.into_bound_py_any(py)?; + list.append(py_value)?; + } + + PyIterator::from_object(&list) + } + + fn __len__(&self) -> usize { + self.object.__len__() + } + + fn __contains__(&self, other: Bound<'_, PyAny>) -> bool { + // We compare by JSON equality rather than Python identity: convert + // the candidate into a `serde_json::Value` once and scan our values. + // Anything that fails to depythonize cannot match by definition. + let other_value: serde_json::Value = match depythonize(&other) { + Ok(v) => v, + Err(_) => return false, + }; + self.object.object.values().any(|v| *v == other_value) + } +} + +/// Helper class returned by `JsonObject.items()` to act as a view into the +/// items of the object. +/// +/// Technically this should be a set-like view according to Python semantics, +/// unless the values are unhashable. Since the values are immutable we could +/// support it, but it's more work and nobody seems to actually use the set +/// operations on `dict_items` in practice. +#[pyclass(frozen)] +#[derive(Clone)] +pub struct JsonObjectItemsView { + object: JsonObject, +} + +#[pymethods] +impl JsonObjectItemsView { + fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { + // Create the iterator by making a temporary python list of the items and + // calling `iter()` on it. 
+ let list = PyList::empty(py); + for (k, v) in self.object.object.iter() { + let py_key = k.as_ref().into_bound_py_any(py)?; + let py_value = pythonize(py, v)?.into_bound_py_any(py)?; + let item = PyTuple::new(py, [py_key, py_value])?; + list.append(item)?; + } + + PyIterator::from_object(&list) + } + + fn __len__(&self) -> usize { + self.object.__len__() + } + + fn __contains__(&self, other: Bound<'_, PyAny>) -> bool { + // `(key, value) in items` — only a 2-tuple can possibly match. We + // look the key up directly (avoiding a full scan) and then compare + // the stored value against `value` using JSON equality. + let Ok((key, value)) = other.extract::<(Bound<'_, PyAny>, Bound<'_, PyAny>)>() else { + return false; + }; + let Ok(key_str) = key.extract::<&str>() else { + return false; + }; + let Some(stored) = self.object.object.get(key_str) else { + return false; + }; + let other_value: serde_json::Value = match depythonize(&value) { + Ok(v) => v, + Err(_) => return false, + }; + *stored == other_value + } +} diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index 5f505abb91..e60cdb7078 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -21,21 +21,31 @@ //! Classes for representing Events. use pyo3::{ - types::{PyAnyMethods, PyModule, PyModuleMethods}, + types::{PyAnyMethods, PyMapping, PyModule, PyModuleMethods}, wrap_pyfunction, Bound, PyResult, Python, }; pub mod filter; mod internal_metadata; +mod json_object; pub mod signatures; pub mod unsigned; +use json_object::JsonObject; + /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + // Register the `JsonObject` class as a `Mapping` so that `isinstance` works. 
+ PyMapping::register::(py)?; + let child_module = PyModule::new(py, "events")?; child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?; m.add_submodule(&child_module)?; diff --git a/synapse/__init__.py b/synapse/__init__.py index 2bed060878..3acfc1a0d7 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -65,7 +65,8 @@ try: except ImportError: pass -# Teach canonicaljson how to serialise immutabledicts. + +# Teach canonicaljson how to serialise immutabledicts and JsonObjects. try: from canonicaljson import register_preserialisation_callback from immutabledict import immutabledict @@ -79,6 +80,12 @@ try: return dict(d) register_preserialisation_callback(immutabledict, _immutabledict_cb) + + # Teach canonicaljson how to serialise JsonObjects, which is just to + # convert them to dicts. 
+ from synapse.synapse_rust.events import JsonObject # noqa: E402 + + register_preserialisation_callback(JsonObject, lambda obj: dict(obj)) except ImportError: pass diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 0f850d19b1..5be0298c30 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -44,9 +44,15 @@ from synapse.api.constants import ( StickyEvent, ) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.synapse_rust.events import EventInternalMetadata, Signatures, Unsigned +from synapse.synapse_rust.events import ( + EventInternalMetadata, + JsonObject, + Signatures, + Unsigned, +) from synapse.types import ( JsonDict, + JsonMapping, StateKey, StrCollection, ) @@ -206,17 +212,29 @@ class EventBase(metaclass=abc.ABCMeta): ): assert room_version.event_format == self.format_version + if "content" in event_dict: + event_dict["content"] = JsonObject(event_dict["content"]) + + # We intern these strings because they turn up a lot (especially when + # caching). 
+ event_dict = intern_dict(event_dict) + + if USE_FROZEN_DICTS: + frozen_dict = freeze(event_dict) + else: + frozen_dict = event_dict + self.room_version = room_version self.signatures = Signatures(signatures) self.unsigned = Unsigned(unsigned) self.rejected_reason = rejected_reason - self._dict = event_dict + self._dict = frozen_dict self.internal_metadata = EventInternalMetadata(internal_metadata_dict) depth: DictProperty[int] = DictProperty("depth") - content: DictProperty[JsonDict] = DictProperty("content") + content: DictProperty[JsonMapping] = DictProperty("content") hashes: DictProperty[dict[str, str]] = DictProperty("hashes") origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts") sender: DictProperty[str] = DictProperty("sender") @@ -259,7 +277,14 @@ class EventBase(metaclass=abc.ABCMeta): def get_dict(self) -> JsonDict: """Convert the event to a dictionary suitable for serialisation.""" + d = dict(self._dict) + if "content" in d: + # Convert the content (which is a JsonObject) back to a dict. Json + # serialization should handle JsonObjects fine, but for sanity's + # sake we want `get_dict()` and `get_pdu_json()` to return plain + # dicts. + d["content"] = dict(self.content) d.update( { "signatures": self.signatures.as_dict(), @@ -419,19 +444,10 @@ class FrozenEvent(EventBase): unsigned = event_dict.pop("unsigned", {}) - # We intern these strings because they turn up a lot (especially when - # caching). - event_dict = intern_dict(event_dict) - - if USE_FROZEN_DICTS: - frozen_dict = freeze(event_dict) - else: - frozen_dict = event_dict - self._event_id = event_dict["event_id"] super().__init__( - frozen_dict, + event_dict, room_version=room_version, signatures=signatures, unsigned=unsigned, @@ -473,19 +489,10 @@ class FrozenEventV2(EventBase): unsigned = event_dict.pop("unsigned", {}) - # We intern these strings because they turn up a lot (especially when - # caching). 
- event_dict = intern_dict(event_dict) - - if USE_FROZEN_DICTS: - frozen_dict = freeze(event_dict) - else: - frozen_dict = event_dict - self._event_id: str | None = None super().__init__( - frozen_dict, + event_dict, room_version=room_version, signatures=signatures, unsigned=unsigned, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 926c81b83d..adbede7f16 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -1032,7 +1032,7 @@ def strip_event(event: EventBase) -> JsonDict: return { "type": event.type, "state_key": event.state_key, - "content": event.content, + "content": dict(event.content), "sender": event.sender, } diff --git a/synapse/events/validator.py b/synapse/events/validator.py index ff22b2287f..d1b5152d77 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -42,7 +42,7 @@ from synapse.events.utils import ( ) from synapse.http.servlet import validate_json_object from synapse.storage.controllers.state import server_acl_evaluator_from_event -from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID +from synapse.types import EventID, JsonDict, JsonMapping, RoomID, StrCollection, UserID from synapse.types.rest import RequestBodyModel @@ -245,7 +245,7 @@ class EventValidator: self._ensure_state_event(event) - def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None: + def _ensure_strings(self, d: JsonMapping, keys: StrCollection) -> None: for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6a5c76c667..13647caa2a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -706,12 +706,12 @@ class RoomCreationHandler: spam_check = await self._spam_checker_module_callbacks.user_may_create_room( user_id, { - "creation_content": creation_content, + "creation_content": dict(creation_content), "initial_state": [ { "type": state_key[0], "state_key": state_key[1], - 
"content": event_content, + "content": dict(event_content), } for state_key, event_content in initial_state.items() ], @@ -1437,7 +1437,7 @@ class RoomCreationHandler: room_config: JsonDict, invite_list: list[str], initial_state: MutableStateMap, - creation_content: JsonDict, + creation_content: JsonMapping, room_alias: RoomAlias | None = None, power_level_content_override: JsonDict | None = None, creator_join_profile: JsonDict | None = None, @@ -1508,7 +1508,7 @@ class RoomCreationHandler: async def create_event( etype: str, - content: JsonDict, + content: JsonMapping, for_batch: bool, **kwargs: Any, ) -> tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: @@ -1561,6 +1561,7 @@ class RoomCreationHandler: if creation_event_with_context is None: # MSC2175 removes the creator field from the create event. if not room_version.implicit_room_creator: + creation_content = dict(creation_content) creation_content["creator"] = creator_id creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index c87b5f854a..4fd69262b2 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -31,7 +31,7 @@ from typing import ( from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions from synapse.storage.databases.main.state_deltas import StateDelta -from synapse.types import JsonDict +from synapse.types import JsonMapping from synapse.util.duration import Duration from synapse.util.events import get_plain_text_topic_from_event_content @@ -195,7 +195,7 @@ class StatsHandler: ) continue - event_content: JsonDict = {} + event_content: JsonMapping = {} if delta.event_id is not None: event = await self.store.get_event(delta.event_id, allow_none=True) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py 
index 7cf89200a8..03dd341744 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -51,7 +51,7 @@ from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.invite_rule import InviteRule from synapse.storage.roommember import ProfileInfo from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator -from synapse.types import JsonValue +from synapse.types import JsonMapping, JsonValue from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import gather_results @@ -231,7 +231,7 @@ class BulkPushRuleEvaluator: event: EventBase, context: EventContext, event_id_to_event: Mapping[str, EventBase], - ) -> tuple[dict, int | None]: + ) -> tuple[JsonMapping, int | None]: """ Given an event and an event context, get the power level event relevant to the event and the power level of the sender of the event. diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 5ae2bb880a..5b55d47f0d 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -10,7 +10,7 @@ # See the GNU Affero General Public License for more details: # . -from typing import Any, Mapping +from typing import Any, Iterator, Mapping from synapse.types import JsonDict, JsonMapping @@ -213,3 +213,12 @@ class Unsigned: def for_event(self) -> JsonDict: ... """Return a dict of all unsigned fields, including those only kept in memory, suitable for inclusion in an event.""" + +class JsonObject(Mapping[str, Any]): + """Immutable JSON object mapping.""" + + def __init__(self, content_dict: JsonMapping | None = None): ... + def __len__(self) -> int: ... + def __getitem__(self, key: str) -> Any: ... + def __iter__(self) -> Iterator[str]: ... + def __eq__(self, other: object) -> bool: ... 
diff --git a/synapse/util/events.py b/synapse/util/events.py index 19eca1c1ae..e7c1c83a37 100644 --- a/synapse/util/events.py +++ b/synapse/util/events.py @@ -17,7 +17,7 @@ from typing import Any from pydantic import Field, StrictStr, ValidationError, field_validator -from synapse.types import JsonDict +from synapse.types import JsonMapping from synapse.util.pydantic_models import ParseModel from synapse.util.stringutils import random_string @@ -103,7 +103,7 @@ class TopicContent(ParseModel): return None -def get_plain_text_topic_from_event_content(content: JsonDict) -> str | None: +def get_plain_text_topic_from_event_content(content: JsonMapping) -> str | None: """ Given the `content` of an `m.room.topic` event, returns the plain-text topic representation. Prefers pulling plain-text from the newer `m.topic` field if diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 0bc27410c6..92c03690f2 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -23,6 +23,8 @@ from typing import Any from immutabledict import immutabledict +from synapse.synapse_rust.events import JsonObject + def freeze(o: Any) -> Any: if isinstance(o, dict): @@ -31,6 +33,9 @@ def freeze(o: Any) -> Any: if isinstance(o, immutabledict): return o + if isinstance(o, JsonObject): + return o + if isinstance(o, (bytes, str)): return o diff --git a/synapse/util/json.py b/synapse/util/json.py index b1091704a8..8f8d731c6d 100644 --- a/synapse/util/json.py +++ b/synapse/util/json.py @@ -20,24 +20,30 @@ from typing import ( from immutabledict import immutabledict +from synapse.synapse_rust.events import JsonObject + def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) -def _handle_immutabledict(obj: Any) -> dict[Any, Any]: - """Helper for json_encoder. 
Makes immutabledicts serializable by returning - the underlying dict +def _handle_extra_mappings(obj: Any) -> dict[Any, Any]: + """Helper for json_encoder. Makes immutabledicts and JsonObjects + serializable """ - if type(obj) is immutabledict: - # fishing the protected dict out of the object is a bit nasty, - # but we don't really want the overhead of copying the dict. - try: - # Safety: we catch the AttributeError immediately below. - return obj._dict - except AttributeError: - # If all else fails, resort to making a copy of the immutabledict + match obj: + case immutabledict(): + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + try: + # Safety: we catch the AttributeError immediately below. + return obj._dict + except AttributeError: + # If all else fails, resort to making a copy of the immutabledict + return dict(obj) + case JsonObject(): + # Just convert to a dict. return dict(obj) raise TypeError( "Object of type %s is not JSON serializable" % obj.__class__.__name__ @@ -49,7 +55,7 @@ def _handle_immutabledict(obj: Any) -> dict[Any, Any]: # * produces valid JSON (no NaNs etc) # * reduces redundant whitespace json_encoder = json.JSONEncoder( - allow_nan=False, separators=(",", ":"), default=_handle_immutabledict + allow_nan=False, separators=(",", ":"), default=_handle_extra_mappings ) # Create a custom decoder to reject Python extensions to JSON. 
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 334ff64bc2..6ec0f64ffc 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -62,6 +62,7 @@ class EventSigningTestCase(unittest.TestCase): "origin_server_ts": 1000000, "signatures": {}, "type": "X", + "content": {}, "unsigned": {"age_ts": 1000000}, } @@ -74,7 +75,7 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) self.assertEqual( - event.hashes["sha256"], "A6Nco6sqoy18PPfPDVdYvoowfc0PVBk9g9OiyT3ncRM" + event.hashes["sha256"], "mq4QfPPpC+QsBd6eqfVsmJIEz8uvMSVK0+AU67PLESk" ) self.assertTrue(hasattr(event, "signatures")) @@ -82,8 +83,8 @@ class EventSigningTestCase(unittest.TestCase): self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], - "PBc48yDVszWB9TRaB/+CZC1B+pDAC10F8zll006j+NN" - "fe4PEMWcVuLaG63LFTK9e4rwJE8iLZMPtCKhDTXhpAQ", + "18rGIkd4JJXxw9m+1j3BtN+TmqmLip4VHvFbyXLngpB" + "LXOqbxlQViQABRzep2cODQ2aa5FnFgz+Llt2P03WiAw", ) def test_sign_message(self) -> None: diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index f1b20a12ec..2ba5da3b95 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -265,7 +265,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): self.assertEqual(event.type, "m.room.message") self.assertEqual(event.room_id, room_id) self.assertFalse(hasattr(event, "state_key")) - self.assertDictEqual(event.content, content) + self.assertDictEqual(dict(event.content), content) expected_requester = create_requester( user_id, authenticated_entity=self.hs.hostname @@ -301,7 +301,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): self.assertEqual(event.type, "m.room.power_levels") self.assertEqual(event.room_id, room_id) self.assertEqual(event.state_key, "") - self.assertDictEqual(event.content, content) + self.assertDictEqual(dict(event.content), 
content) # Check that the event was sent self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 61e7e87f62..28872fa06c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -59,7 +59,7 @@ from synapse.rest.client import ( sync, ) from synapse.server import HomeServer -from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.types import JsonDict, JsonMapping, RoomAlias, UserID, create_requester from synapse.util.clock import Clock from synapse.util.stringutils import random_string @@ -1859,7 +1859,7 @@ class RoomMessagesTestCase(RoomBase): mock_return_value: str | bool | Codes | tuple[Codes, JsonDict] | bool = ( "NOT_SPAM" ) - mock_content: JsonDict | None = None + mock_content: JsonMapping | None = None async def check_event_for_spam( self, diff --git a/tests/server.py b/tests/server.py index 55860701da..20fcc42081 100644 --- a/tests/server.py +++ b/tests/server.py @@ -101,6 +101,7 @@ from synapse.storage.engines import BaseDatabaseEngine, create_engine from synapse.storage.prepare_database import prepare_database from synapse.types import ISynapseReactor, JsonDict from synapse.util.clock import Clock +from synapse.util.json import json_encoder from tests.utils import ( LEAVE_DB, @@ -422,7 +423,7 @@ def make_request( path = b"/" + path if isinstance(content, dict): - content = json.dumps(content).encode("utf8") + content = json_encoder.encode(content).encode("utf8") if isinstance(content, str): content = content.encode("utf8") diff --git a/tests/synapse_rust/test_json_object.py b/tests/synapse_rust/test_json_object.py new file mode 100644 index 0000000000..77b188eee0 --- /dev/null +++ b/tests/synapse_rust/test_json_object.py @@ -0,0 +1,149 @@ +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations Ltd. 
+# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +from synapse.synapse_rust.events import JsonObject +from synapse.util import MutableOverlayMapping + +from tests import unittest + + +class JsonObjectMappingTestCase(unittest.TestCase): + def test_new_and_basic_mapping_behavior(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + self.assertEqual(len(obj), 2) + self.assertTrue("a" in obj) + self.assertTrue("b" in obj) + self.assertFalse("c" in obj) + self.assertFalse(123 in obj) # type: ignore[comparison-overlap] + + def test_getitem_and_key_errors(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + self.assertEqual(obj["a"], 1) + + with self.assertRaises(KeyError): + _ = obj["missing"] + + with self.assertRaises(KeyError): + _ = obj[10] # type: ignore[index] + + def test_iter_keys_values_items(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + iterator = iter(obj) + first = next(iterator) + second = next(iterator) + self.assertCountEqual((first, second), ("a", "b")) + with self.assertRaises(StopIteration): + next(iterator) + + self.assertCountEqual(list(obj.keys()), ["a", "b"]) + self.assertCountEqual(list(obj.values()), [1, 2]) + self.assertCountEqual(list(obj.items()), [("a", 1), ("b", 2)]) + + def test_keys_set_like_behavior(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + # Test 'and' operator. + self.assertEqual(obj.keys() & {"a"}, {"a"}) + self.assertEqual({"a"} & obj.keys(), {"a"}) + self.assertEqual(obj.keys() & {"c"}, set()) + self.assertEqual({"c"} & obj.keys(), set()) + + # Test 'or' operator. 
+ self.assertEqual(obj.keys() | {"a"}, {"a", "b"}) + self.assertEqual({"a"} | obj.keys(), {"a", "b"}) + self.assertEqual(obj.keys() | {"c"}, {"a", "b", "c"}) + self.assertEqual({"c"} | obj.keys(), {"a", "b", "c"}) + + # Test 'xor' operator. + self.assertEqual(obj.keys() ^ {"a"}, {"b"}) + self.assertEqual({"a"} ^ obj.keys(), {"b"}) + self.assertEqual(obj.keys() ^ {"c"}, {"a", "b", "c"}) + self.assertEqual({"c"} ^ obj.keys(), {"a", "b", "c"}) + + # Test 'sub' operator. + self.assertEqual(obj.keys() - {"a"}, {"b"}) + self.assertEqual({"a"} - obj.keys(), set()) + self.assertEqual(obj.keys() - {"c"}, {"a", "b"}) + self.assertEqual({"c"} - obj.keys(), {"c"}) + + def test_values_view(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + values = obj.values() + + self.assertEqual(len(values), 2) + self.assertCountEqual(list(values), [1, 2]) + + self.assertIn(1, values) + self.assertIn(2, values) + self.assertNotIn(3, values) + self.assertNotIn("a", values) + self.assertNotIn(object(), values) + + # Iterating twice should yield the same values. + self.assertCountEqual(list(values), [1, 2]) + + def test_items_view(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + items = obj.items() + + self.assertEqual(len(items), 2) + self.assertCountEqual(list(items), [("a", 1), ("b", 2)]) + + self.assertIn(("a", 1), items) + self.assertIn(("b", 2), items) + self.assertNotIn(("a", 2), items) + self.assertNotIn(("c", 1), items) + self.assertNotIn("a", items) + self.assertNotIn(("a", 1, "extra"), items) + + # Iterating twice should yield the same items. 
+ self.assertCountEqual(list(items), [("a", 1), ("b", 2)]) + + def test_get(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + self.assertEqual(obj.get("a"), 1) + self.assertEqual(obj.get("missing", "fallback"), "fallback") + self.assertEqual(obj.get(5, "fallback"), "fallback") # type: ignore[call-overload] + + def test_eq(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + self.assertEqual(obj, {"a": 1, "b": 2}) + self.assertNotEqual(obj, {"a": 1}) + self.assertNotEqual(obj, ["a", "b"]) + + def test_str_and_repr(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + self.assertEqual(str(obj), r'{"a":1,"b":2}') + self.assertEqual(repr(obj), r'JsonObject({"a":1,"b":2})') + + def test_json_object_constructor(self) -> None: + obj = JsonObject({"a": 1, "b": 2}) + + # Passing in an existing JsonObject should work. + obj2 = JsonObject(obj) + self.assertEqual(obj2, {"a": 1, "b": 2}) + + # Other mapping types should also work. + obj3 = JsonObject(MutableOverlayMapping({"a": 1, "b": 2})) + self.assertEqual(obj3, {"a": 1, "b": 2}) + + # Test that passing a non-mapping raises a TypeError. 
+ with self.assertRaises(TypeError): + JsonObject(123) # type: ignore[arg-type] diff --git a/tests/test_state.py b/tests/test_state.py index 7df95ebf8b..0ca88aef74 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -35,7 +35,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry -from synapse.types import MutableStateMap, StateMap +from synapse.types import JsonDict, MutableStateMap, StateMap from synapse.types.state import StateFilter from synapse.util.macaroons import MacaroonGenerator @@ -67,13 +67,14 @@ def create_event( else: name = "<%s, %s>" % (type, event_id) - d = { + d: JsonDict = { "event_id": event_id, "type": type, "sender": "@user_id:example.com", "room_id": "!room_id:example.com", "depth": depth, "prev_events": prev_events or [], + "content": {}, } if state_key is not None: diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0df5a4e6c3..4170768208 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -24,7 +24,6 @@ Utilities for running the unit tests """ import base64 -import json import sys import warnings from binascii import unhexlify @@ -41,6 +40,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IResponse from synapse.types import JsonSerializable +from synapse.util.json import json_encoder if TYPE_CHECKING: from sys import UnraisableHookArgs @@ -127,7 +127,7 @@ class FakeResponse: # type: ignore[misc] @classmethod def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse": headers = Headers({"Content-Type": ["application/json"]}) - body = json.dumps(payload).encode("utf-8") + body = json_encoder.encode(payload).encode("utf-8") return cls(code=code, body=body, headers=headers)