Port event content to Rust (#19725)

Based on #19708.

This is on the path to porting the entire event class to Rust, as
`event.content` will then return the new Rust class `JsonObject`.

This PR adds a pure Rust `JsonObject` class that is a `Mapping`
representing a json-style object. It uses `serde_json::Value` as its
in-memory representation and `pythonize` for conversion when a field is
looked up on the object.

I'm not thrilled with the name, but couldn't think of a better one.

This also adds `JsonObject` handling to the JSON serialisation functions
we use, as well as to the `freeze(..)` function.

Reviewable commit-by-commit.
This commit is contained in:
Erik Johnston
2026-05-08 14:19:03 +01:00
committed by GitHub
parent 8dbbc4000b
commit c430c16df4
21 changed files with 749 additions and 63 deletions
+1
View File
@@ -0,0 +1 @@
Port `Event.content` field to Rust.
+488
View File
@@ -0,0 +1,488 @@
/*
* This file is licensed under the Affero General Public License (AGPL) version 3.
*
* Copyright (C) 2026 Element Creations Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* See the GNU Affero General Public License for more details:
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*
*/
use std::{collections::BTreeMap, sync::Arc};
use pyo3::{
exceptions::{PyKeyError, PyTypeError},
pyclass, pymethods,
types::{
PyAnyMethods, PyIterator, PyList, PyListMethods, PyMapping, PySet, PySetMethods, PyTuple,
},
Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyResult, Python,
};
use pythonize::{depythonize, pythonize};
use serde::{Deserialize, Serialize};
/// A generic class for representing immutable JSON objects.
///
/// This is used for representing the `content` field of an event.
///
/// The basic architecture here is to optimize for two things:
/// 1. Fast access of top-level keys (e.g. `event.content["key"]`)
/// 2. Pure Rust implementation.
///
/// The type is exposed to Python as a frozen `pyclass` with mapping
/// semantics. `#[serde(transparent)]` makes it (de)serialize as the bare
/// underlying map rather than as a wrapper struct with an `object` field.
#[derive(Serialize, Deserialize, Clone, Default)]
#[pyclass(mapping, frozen)]
#[serde(transparent)]
pub struct JsonObject {
// Top-level keys mapped to their JSON values. The map lives behind an `Arc`
// so that cloning a `JsonObject` shares the map instead of copying it.
object: Arc<BTreeMap<Box<str>, serde_json::Value>>,
}
#[pymethods]
impl JsonObject {
#[new]
#[pyo3(signature = (content = None))]
fn new<'a, 'py>(content: Option<&'a Bound<'py, PyAny>>) -> PyResult<Self> {
let Some(content) = content else {
// If no content is provided, default to an empty object.
return Ok(Self::default());
};
if let Ok(content) = content.cast::<JsonObject>() {
// If the content is already a JsonObject, we can just clone the
// underlying map (this is safe as the object is immutable).
return Ok(JsonObject {
object: content.get().object.clone(),
});
}
let Ok(content) = content.cast::<PyMapping>() else {
return Err(PyTypeError::new_err("'content' must be a mapping"));
};
// Use pythonize to try and convert from a mapping.
let content = depythonize(content)?;
Ok(Self {
object: Arc::new(content),
})
}
fn __len__(&self) -> usize {
self.object.len()
}
fn __contains__(&self, key: &Bound<'_, PyAny>) -> bool {
// Match dict semantics: a non-string key is simply "not in" the
// mapping, rather than raising TypeError.
let Ok(key_str) = key.extract::<&str>() else {
return false;
};
self.object.contains_key(key_str)
}
fn __getitem__<'py>(
&self,
py: Python<'py>,
key: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
// We only ever store string keys, so any non-string lookup is a miss.
// Raise KeyError (not TypeError) to match dict's behaviour.
let Ok(key_str) = key.extract::<&str>() else {
return Err(PyKeyError::new_err(key.unbind()));
};
let Some(value) = self.object.get(key_str) else {
return Err(PyKeyError::new_err(key.unbind()));
};
Ok(pythonize(py, value)?)
}
fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
// The easiest way to get an iterator over the keys is to create a
// temporary list and call `iter()` on it. This is not the most
// efficient approach, but is much less boilerplate than implementing a
// custom iterator type. Since the keys are typically small in number
// this should be fine in practice.
let list = PyList::new(py, self.object.keys().map(Box::as_ref))?;
PyIterator::from_object(&list)
}
// The view classes below each hold a `JsonObject` clone. This is cheap
// because the underlying map is behind an `Arc`, and lets the view outlive
// the originating object (matching dict_keys/values/items semantics in
// Python, which also keep the dict alive).
fn keys(&self) -> JsonObjectKeysView {
JsonObjectKeysView {
object: self.clone(),
}
}
fn values(&self) -> JsonObjectValuesView {
JsonObjectValuesView {
object: self.clone(),
}
}
fn items(&self) -> JsonObjectItemsView {
JsonObjectItemsView {
object: self.clone(),
}
}
#[pyo3(signature = (key, default=None))]
fn get<'py>(
&self,
py: Python<'py>,
key: Bound<'_, PyAny>,
default: Option<Bound<'py, PyAny>>,
) -> PyResult<Bound<'py, PyAny>> {
// Non-string keys can never match, so treat them as a miss and return
// the caller-supplied default rather than raising.
let Ok(key_str) = key.extract::<&str>() else {
return Ok(default.into_pyobject(py)?);
};
match self.object.get(key_str) {
Some(value) => Ok(pythonize(py, value)?),
None => Ok(default.into_pyobject(py)?),
}
}
fn __eq__(&self, other: Bound<'_, PyAny>) -> bool {
// We support equality against any Python mapping (e.g. plain dicts),
// so callers can swap a JsonObject in without rewriting comparisons.
let Ok(mapping) = other.cast::<PyMapping>() else {
return false;
};
let Ok(other_len) = mapping.len() else {
return false;
};
if other_len != self.object.len() {
return false;
}
// We know the "other" is a mapping with the same number of fields as
// us. So we can convert it into a JsonObject and compare the underlying
// maps.
let Ok(other_dict) = depythonize(&other) else {
return false;
};
*self.object == other_dict
}
// Since we implement comparisons with other types, we need to disable
// hashing to avoid violating the invariant that equal objects must have the
// same hash.
//
// Alternatively, we could only allow comparisons with other JsonObjects and
// allow hashing, but a) its nicer to be able to compare with arbitrary
// mappings and b) we don't really need hashing for these objects.
#[classattr]
const __hash__: Option<Py<PyAny>> = None;
fn __str__(&self) -> String {
serde_json::to_string(&self.object).expect("Value should be serializable")
}
fn __repr__(&self) -> String {
format!("JsonObject({})", self.__str__())
}
}
/// Helper class returned by `JsonObject.keys()` to act as a view into the keys
/// of the object.
///
/// This needs to both be iterable *and* operate like a set.
#[pyclass(frozen)]
#[derive(Clone)]
pub struct JsonObjectKeysView {
// The object whose keys are exposed. `JsonObject` keeps its map behind an
// `Arc`, so holding it here shares the map with the originating object.
object: JsonObject,
}
#[pymethods]
impl JsonObjectKeysView {
    /// Iterates over the keys, in the underlying map's (sorted) order.
    fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
        // Create the iterator by making a temporary python list of the keys and
        // calling `iter()` on it.
        let list = PyList::new(py, self.object.object.keys().map(Box::as_ref))?;
        PyIterator::from_object(&list)
    }

    /// Number of keys.
    fn __len__(&self) -> usize {
        self.object.__len__()
    }

    /// `key in view`. Non-string keys are reported as absent.
    fn __contains__(&self, key: &Bound<'_, PyAny>) -> bool {
        self.object.__contains__(key)
    }

    /// `view == other`.
    ///
    /// Like `dict.keys()`, a keys view only compares equal to set-like
    /// objects (another keys view, `set`, `frozenset`, `dict.keys()`, or
    /// anything else registered as a `collections.abc.Set`). Lists, strings
    /// and other sized containers never compare equal, even if they happen
    /// to contain exactly our keys. Never raises.
    fn __eq__(&self, other: Bound<'_, PyAny>) -> bool {
        // Fast path: both sides are views over sorted maps, so the key
        // sequences can be compared directly.
        if let Ok(other_view) = other.cast::<JsonObjectKeysView>() {
            return self
                .object
                .object
                .keys()
                .eq(other_view.get().object.object.keys());
        }

        let is_set_like = other
            .py()
            .import("collections.abc")
            .and_then(|abc| abc.getattr("Set"))
            .and_then(|set_abc| other.is_instance(&set_abc))
            .unwrap_or(false);
        if !is_set_like {
            return false;
        }

        let Ok(other_len) = other.len() else {
            return false;
        };
        if self.object.__len__() != other_len {
            return false;
        }

        // Same size and every one of our keys is in `other` => equal sets.
        // Equality must not raise, so a failing membership test counts as a
        // mismatch.
        self.object
            .object
            .keys()
            .all(|key| matches!(other.contains(key.as_ref()), Ok(true)))
    }

    // The set operators below match the behaviour of `dict.keys()` in Python:
    // they accept any object that supports `__contains__` (for `&`) or is
    // iterable (for `|`, `-`, `^`), not just sets. Each returns a fresh
    // `PySet` so the caller gets a normal mutable Python set back.
    //
    // The `__r*__` variants are reflected operators, called by Python when
    // the left-hand operand doesn't know how to combine with us. Since these
    // operations are commutative for sets (or symmetric in the case of `^`),
    // they just delegate. The asymmetric ops (`-`) need a separate impl.
    //
    // Errors raised by the other operand (e.g. `TypeError` from a membership
    // test on something that is not a container) are propagated to the caller
    // rather than being treated as "not contained".

    /// `view & other`: our keys that are also in `other`.
    fn __and__<'py>(
        &self,
        py: Python<'py>,
        other: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PySet>> {
        // Iterate our (typically small) key set and probe `other`, which may
        // be any container supporting `__contains__`.
        let result = PySet::empty(py)?;
        for key in self.object.object.keys() {
            if other.contains(key.as_ref())? {
                result.add(key.as_ref())?;
            }
        }
        Ok(result)
    }

    /// `other & view`.
    fn __rand__<'py>(
        &self,
        py: Python<'py>,
        other: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PySet>> {
        self.__and__(py, other)
    }

    /// `view | other`: our keys plus every element of `other`.
    fn __or__<'py>(&self, py: Python<'py>, other: Bound<'_, PyAny>) -> PyResult<Bound<'py, PySet>> {
        // Union needs to enumerate both sides, so the right operand must be
        // iterable (a bare `__contains__` is not enough).
        let Ok(other_iter) = other.try_iter() else {
            return Err(PyTypeError::new_err("Right operand must be iterable"));
        };

        let result = PySet::new(py, self.object.object.keys().map(Box::as_ref))?;

        // PySet handles dedup, so we can blindly add every element from the
        // other iterable.
        for item in other_iter {
            result.add(item?)?;
        }

        Ok(result)
    }

    /// `other | view`.
    fn __ror__<'py>(
        &self,
        py: Python<'py>,
        other: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PySet>> {
        self.__or__(py, other)
    }

    /// `view - other`: our keys that are not in `other`.
    fn __sub__<'py>(
        &self,
        py: Python<'py>,
        other: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PySet>> {
        // Only `other` needs to support `__contains__` here.
        let result = PySet::empty(py)?;
        for key in self.object.object.keys() {
            if !other.contains(key.as_ref())? {
                result.add(key.as_ref())?;
            }
        }
        Ok(result)
    }

    /// `other - view`: elements of `other` that are not one of our keys.
    fn __rsub__<'py>(
        &self,
        py: Python<'py>,
        other: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PySet>> {
        // We need to enumerate `other`, so it must be iterable. Not symmetric
        // with `__sub__`, hence a separate impl.
        let Ok(other_iter) = other.try_iter() else {
            return Err(PyTypeError::new_err("Left operand must be iterable"));
        };

        let result = PySet::empty(py)?;
        for item in other_iter {
            let item = item?;
            if !self.object.__contains__(&item) {
                result.add(item)?;
            }
        }

        Ok(result)
    }

    /// `view ^ other`: elements that are in exactly one of the two sides.
    fn __xor__<'py>(
        &self,
        py: Python<'py>,
        other: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PySet>> {
        let Ok(other_iter) = other.try_iter() else {
            return Err(PyTypeError::new_err("Right operand must be iterable"));
        };

        // Start from all of our keys and walk `other` exactly once: elements
        // that are one of our keys are removed, everything else is added.
        // `other` is only ever iterated (never probed with `in`), so one-shot
        // iterators such as generators give the right answer. Removing (rather
        // than toggling) keeps duplicates in `other` from re-adding a key.
        let result = PySet::new(py, self.object.object.keys().map(Box::as_ref))?;
        for item in other_iter {
            let item = item?;
            if self.object.__contains__(&item) {
                result.discard(item)?;
            } else {
                result.add(item)?;
            }
        }

        Ok(result)
    }

    /// `other ^ view`.
    fn __rxor__<'py>(
        &self,
        py: Python<'py>,
        other: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PySet>> {
        self.__xor__(py, other)
    }

    /// Returns `True` if none of our keys is in `other`.
    ///
    /// Errors from the membership test on `other` are propagated.
    fn isdisjoint(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
        for key in self.object.object.keys() {
            if other.contains(key.as_ref())? {
                return Ok(false);
            }
        }
        Ok(true)
    }
}
/// Helper class returned by `JsonObject.values()` to act as a view into the
/// values of the object.
#[pyclass(frozen)]
#[derive(Clone)]
pub struct JsonObjectValuesView {
// The object whose values are exposed. `JsonObject` keeps its map behind an
// `Arc`, so holding it here shares the map with the originating object.
object: JsonObject,
}
#[pymethods]
impl JsonObjectValuesView {
    /// Iterates over the values, in the order of their (sorted) keys.
    fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
        // Convert every stored value to a Python object up front, collect
        // them into a temporary python list and hand back `iter()` of it.
        let mut converted = Vec::with_capacity(self.object.object.len());
        for json_value in self.object.object.values() {
            converted.push(pythonize(py, json_value)?);
        }

        let list = PyList::new(py, converted)?;
        PyIterator::from_object(&list)
    }

    /// Number of values (same as the number of keys).
    fn __len__(&self) -> usize {
        self.object.__len__()
    }

    /// `candidate in view`, using JSON equality rather than Python identity.
    fn __contains__(&self, other: Bound<'_, PyAny>) -> bool {
        // Turn the candidate into a `serde_json::Value` once; something that
        // cannot be represented as JSON cannot be equal to any stored value.
        let Ok(needle) = depythonize::<serde_json::Value>(&other) else {
            return false;
        };

        // Linear scan over the stored values.
        for stored in self.object.object.values() {
            if *stored == needle {
                return true;
            }
        }
        false
    }
}
/// Helper class returned by `JsonObject.items()` to act as a view into the
/// items of the object.
///
/// Technically this should be a set-like view according to Python semantics,
/// unless the values are unhashable. Since the values are immutable we could
/// support it, but it's more work and nobody seems to actually use the set
/// operations on `dict_items` in practice.
#[pyclass(frozen)]
#[derive(Clone)]
pub struct JsonObjectItemsView {
// The object whose items are exposed. `JsonObject` keeps its map behind an
// `Arc`, so holding it here shares the map with the originating object.
object: JsonObject,
}
#[pymethods]
impl JsonObjectItemsView {
fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
// Create the iterator by making a temporary python list of the keys and
// calling `iter()` on it.
let list = PyList::empty(py);
for (k, v) in self.object.object.iter() {
let py_key = k.as_ref().into_bound_py_any(py)?;
let py_value = pythonize(py, v)?.into_bound_py_any(py)?;
let item = PyTuple::new(py, [py_key, py_value])?;
list.append(item)?;
}
PyIterator::from_object(&list)
}
fn __len__(&self) -> usize {
self.object.__len__()
}
fn __contains__(&self, other: Bound<'_, PyAny>) -> bool {
// `(key, value) in items` — only a 2-tuple can possibly match. We
// look the key up directly (avoiding a full scan) and then compare
// the stored value against `value` using JSON equality.
let Ok((key, value)) = other.extract::<(Bound<'_, PyAny>, Bound<'_, PyAny>)>() else {
return false;
};
let Ok(key_str) = key.extract::<&str>() else {
return false;
};
let Some(stored) = self.object.object.get(key_str) else {
return false;
};
let other_value: serde_json::Value = match depythonize(&value) {
Ok(v) => v,
Err(_) => return false,
};
*stored == other_value
}
}
+11 -1
View File
@@ -21,21 +21,31 @@
//! Classes for representing Events.
use pyo3::{
types::{PyAnyMethods, PyModule, PyModuleMethods},
types::{PyAnyMethods, PyMapping, PyModule, PyModuleMethods},
wrap_pyfunction, Bound, PyResult, Python,
};
pub mod filter;
mod internal_metadata;
mod json_object;
pub mod signatures;
pub mod unsigned;
use json_object::JsonObject;
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
// Register the `JsonObject` class as a `Mapping` so that `isinstance` works.
PyMapping::register::<JsonObject>(py)?;
let child_module = PyModule::new(py, "events")?;
child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
child_module.add_class::<signatures::Signatures>()?;
child_module.add_class::<unsigned::Unsigned>()?;
child_module.add_class::<JsonObject>()?;
child_module.add_class::<json_object::JsonObjectKeysView>()?;
child_module.add_class::<json_object::JsonObjectValuesView>()?;
child_module.add_class::<json_object::JsonObjectItemsView>()?;
child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;
m.add_submodule(&child_module)?;
+8 -1
View File
@@ -65,7 +65,8 @@ try:
except ImportError:
pass
# Teach canonicaljson how to serialise immutabledicts.
# Teach canonicaljson how to serialise immutabledicts and JsonObjects.
try:
from canonicaljson import register_preserialisation_callback
from immutabledict import immutabledict
@@ -79,6 +80,12 @@ try:
return dict(d)
register_preserialisation_callback(immutabledict, _immutabledict_cb)
# Teach canonicaljson how to serialise JsonObjects, which is just to
# convert them to dicts.
from synapse.synapse_rust.events import JsonObject # noqa: E402
register_preserialisation_callback(JsonObject, lambda obj: dict(obj))
except ImportError:
pass
+30 -23
View File
@@ -44,9 +44,15 @@ from synapse.api.constants import (
StickyEvent,
)
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.synapse_rust.events import EventInternalMetadata, Signatures, Unsigned
from synapse.synapse_rust.events import (
EventInternalMetadata,
JsonObject,
Signatures,
Unsigned,
)
from synapse.types import (
JsonDict,
JsonMapping,
StateKey,
StrCollection,
)
@@ -206,17 +212,29 @@ class EventBase(metaclass=abc.ABCMeta):
):
assert room_version.event_format == self.format_version
if "content" in event_dict:
event_dict["content"] = JsonObject(event_dict["content"])
# We intern these strings because they turn up a lot (especially when
# caching).
event_dict = intern_dict(event_dict)
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
self.room_version = room_version
self.signatures = Signatures(signatures)
self.unsigned = Unsigned(unsigned)
self.rejected_reason = rejected_reason
self._dict = event_dict
self._dict = frozen_dict
self.internal_metadata = EventInternalMetadata(internal_metadata_dict)
depth: DictProperty[int] = DictProperty("depth")
content: DictProperty[JsonDict] = DictProperty("content")
content: DictProperty[JsonMapping] = DictProperty("content")
hashes: DictProperty[dict[str, str]] = DictProperty("hashes")
origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
sender: DictProperty[str] = DictProperty("sender")
@@ -259,7 +277,14 @@ class EventBase(metaclass=abc.ABCMeta):
def get_dict(self) -> JsonDict:
"""Convert the event to a dictionary suitable for serialisation."""
d = dict(self._dict)
if "content" in d:
# Convert the content (which is a JsonObject) back to a dict. JSON
# serialization should handle JsonObjects fine, but for sanity's
# sake we want `get_dict()` and `get_pdu_json()` to return plain
# dicts.
d["content"] = dict(self.content)
d.update(
{
"signatures": self.signatures.as_dict(),
@@ -419,19 +444,10 @@ class FrozenEvent(EventBase):
unsigned = event_dict.pop("unsigned", {})
# We intern these strings because they turn up a lot (especially when
# caching).
event_dict = intern_dict(event_dict)
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
self._event_id = event_dict["event_id"]
super().__init__(
frozen_dict,
event_dict,
room_version=room_version,
signatures=signatures,
unsigned=unsigned,
@@ -473,19 +489,10 @@ class FrozenEventV2(EventBase):
unsigned = event_dict.pop("unsigned", {})
# We intern these strings because they turn up a lot (especially when
# caching).
event_dict = intern_dict(event_dict)
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
self._event_id: str | None = None
super().__init__(
frozen_dict,
event_dict,
room_version=room_version,
signatures=signatures,
unsigned=unsigned,
+1 -1
View File
@@ -1032,7 +1032,7 @@ def strip_event(event: EventBase) -> JsonDict:
return {
"type": event.type,
"state_key": event.state_key,
"content": event.content,
"content": dict(event.content),
"sender": event.sender,
}
+2 -2
View File
@@ -42,7 +42,7 @@ from synapse.events.utils import (
)
from synapse.http.servlet import validate_json_object
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
from synapse.types import EventID, JsonDict, JsonMapping, RoomID, StrCollection, UserID
from synapse.types.rest import RequestBodyModel
@@ -245,7 +245,7 @@ class EventValidator:
self._ensure_state_event(event)
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
def _ensure_strings(self, d: JsonMapping, keys: StrCollection) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
+5 -4
View File
@@ -706,12 +706,12 @@ class RoomCreationHandler:
spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
user_id,
{
"creation_content": creation_content,
"creation_content": dict(creation_content),
"initial_state": [
{
"type": state_key[0],
"state_key": state_key[1],
"content": event_content,
"content": dict(event_content),
}
for state_key, event_content in initial_state.items()
],
@@ -1437,7 +1437,7 @@ class RoomCreationHandler:
room_config: JsonDict,
invite_list: list[str],
initial_state: MutableStateMap,
creation_content: JsonDict,
creation_content: JsonMapping,
room_alias: RoomAlias | None = None,
power_level_content_override: JsonDict | None = None,
creator_join_profile: JsonDict | None = None,
@@ -1508,7 +1508,7 @@ class RoomCreationHandler:
async def create_event(
etype: str,
content: JsonDict,
content: JsonMapping,
for_batch: bool,
**kwargs: Any,
) -> tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
@@ -1561,6 +1561,7 @@ class RoomCreationHandler:
if creation_event_with_context is None:
# MSC2175 removes the creator field from the create event.
if not room_version.implicit_room_creator:
creation_content = dict(creation_content)
creation_content["creator"] = creator_id
creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
+2 -2
View File
@@ -31,7 +31,7 @@ from typing import (
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import JsonDict
from synapse.types import JsonMapping
from synapse.util.duration import Duration
from synapse.util.events import get_plain_text_topic_from_event_content
@@ -195,7 +195,7 @@ class StatsHandler:
)
continue
event_content: JsonDict = {}
event_content: JsonMapping = {}
if delta.event_id is not None:
event = await self.store.get_event(delta.event_id, allow_none=True)
+2 -2
View File
@@ -51,7 +51,7 @@ from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.invite_rule import InviteRule
from synapse.storage.roommember import ProfileInfo
from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
from synapse.types import JsonValue
from synapse.types import JsonMapping, JsonValue
from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import gather_results
@@ -231,7 +231,7 @@ class BulkPushRuleEvaluator:
event: EventBase,
context: EventContext,
event_id_to_event: Mapping[str, EventBase],
) -> tuple[dict, int | None]:
) -> tuple[JsonMapping, int | None]:
"""
Given an event and an event context, get the power level event relevant to the event
and the power level of the sender of the event.
+10 -1
View File
@@ -10,7 +10,7 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
from typing import Any, Mapping
from typing import Any, Iterator, Mapping
from synapse.types import JsonDict, JsonMapping
@@ -213,3 +213,12 @@ class Unsigned:
def for_event(self) -> JsonDict: ...
"""Return a dict of all unsigned fields, including those only kept in
memory, suitable for inclusion in an event."""
class JsonObject(Mapping[str, Any]):
"""Immutable JSON object mapping."""
def __init__(self, content_dict: JsonMapping | None = None): ...
def __len__(self) -> int: ...
def __getitem__(self, key: str) -> Any: ...
def __iter__(self) -> Iterator[str]: ...
def __eq__(self, other: object) -> bool: ...
+2 -2
View File
@@ -17,7 +17,7 @@ from typing import Any
from pydantic import Field, StrictStr, ValidationError, field_validator
from synapse.types import JsonDict
from synapse.types import JsonMapping
from synapse.util.pydantic_models import ParseModel
from synapse.util.stringutils import random_string
@@ -103,7 +103,7 @@ class TopicContent(ParseModel):
return None
def get_plain_text_topic_from_event_content(content: JsonDict) -> str | None:
def get_plain_text_topic_from_event_content(content: JsonMapping) -> str | None:
"""
Given the `content` of an `m.room.topic` event, returns the plain-text topic
representation. Prefers pulling plain-text from the newer `m.topic` field if
+5
View File
@@ -23,6 +23,8 @@ from typing import Any
from immutabledict import immutabledict
from synapse.synapse_rust.events import JsonObject
def freeze(o: Any) -> Any:
if isinstance(o, dict):
@@ -31,6 +33,9 @@ def freeze(o: Any) -> Any:
if isinstance(o, immutabledict):
return o
if isinstance(o, JsonObject):
return o
if isinstance(o, (bytes, str)):
return o
+18 -12
View File
@@ -20,24 +20,30 @@ from typing import (
from immutabledict import immutabledict
from synapse.synapse_rust.events import JsonObject
def _reject_invalid_json(val: Any) -> None:
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise ValueError("Invalid JSON value: '%s'" % val)
def _handle_immutabledict(obj: Any) -> dict[Any, Any]:
"""Helper for json_encoder. Makes immutabledicts serializable by returning
the underlying dict
def _handle_extra_mappings(obj: Any) -> dict[Any, Any]:
"""Helper for json_encoder. Makes immutabledicts and JsonObjects
serializable
"""
if type(obj) is immutabledict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
try:
# Safety: we catch the AttributeError immediately below.
return obj._dict
except AttributeError:
# If all else fails, resort to making a copy of the immutabledict
match obj:
case immutabledict():
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
try:
# Safety: we catch the AttributeError immediately below.
return obj._dict
except AttributeError:
# If all else fails, resort to making a copy of the immutabledict
return dict(obj)
case JsonObject():
# Just convert to a dict.
return dict(obj)
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
@@ -49,7 +55,7 @@ def _handle_immutabledict(obj: Any) -> dict[Any, Any]:
# * produces valid JSON (no NaNs etc)
# * reduces redundant whitespace
json_encoder = json.JSONEncoder(
allow_nan=False, separators=(",", ":"), default=_handle_immutabledict
allow_nan=False, separators=(",", ":"), default=_handle_extra_mappings
)
# Create a custom decoder to reject Python extensions to JSON.
+4 -3
View File
@@ -62,6 +62,7 @@ class EventSigningTestCase(unittest.TestCase):
"origin_server_ts": 1000000,
"signatures": {},
"type": "X",
"content": {},
"unsigned": {"age_ts": 1000000},
}
@@ -74,7 +75,7 @@ class EventSigningTestCase(unittest.TestCase):
self.assertTrue(hasattr(event, "hashes"))
self.assertIn("sha256", event.hashes)
self.assertEqual(
event.hashes["sha256"], "A6Nco6sqoy18PPfPDVdYvoowfc0PVBk9g9OiyT3ncRM"
event.hashes["sha256"], "mq4QfPPpC+QsBd6eqfVsmJIEz8uvMSVK0+AU67PLESk"
)
self.assertTrue(hasattr(event, "signatures"))
@@ -82,8 +83,8 @@ class EventSigningTestCase(unittest.TestCase):
self.assertIn(KEY_NAME, event.signatures["domain"])
self.assertEqual(
event.signatures[HOSTNAME][KEY_NAME],
"PBc48yDVszWB9TRaB/+CZC1B+pDAC10F8zll006j+NN"
"fe4PEMWcVuLaG63LFTK9e4rwJE8iLZMPtCKhDTXhpAQ",
"18rGIkd4JJXxw9m+1j3BtN+TmqmLip4VHvFbyXLngpB"
"LXOqbxlQViQABRzep2cODQ2aa5FnFgz+Llt2P03WiAw",
)
def test_sign_message(self) -> None:
+2 -2
View File
@@ -265,7 +265,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(event.type, "m.room.message")
self.assertEqual(event.room_id, room_id)
self.assertFalse(hasattr(event, "state_key"))
self.assertDictEqual(event.content, content)
self.assertDictEqual(dict(event.content), content)
expected_requester = create_requester(
user_id, authenticated_entity=self.hs.hostname
@@ -301,7 +301,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(event.type, "m.room.power_levels")
self.assertEqual(event.room_id, room_id)
self.assertEqual(event.state_key, "")
self.assertDictEqual(event.content, content)
self.assertDictEqual(dict(event.content), content)
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+2 -2
View File
@@ -59,7 +59,7 @@ from synapse.rest.client import (
sync,
)
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.types import JsonDict, JsonMapping, RoomAlias, UserID, create_requester
from synapse.util.clock import Clock
from synapse.util.stringutils import random_string
@@ -1859,7 +1859,7 @@ class RoomMessagesTestCase(RoomBase):
mock_return_value: str | bool | Codes | tuple[Codes, JsonDict] | bool = (
"NOT_SPAM"
)
mock_content: JsonDict | None = None
mock_content: JsonMapping | None = None
async def check_event_for_spam(
self,
+2 -1
View File
@@ -101,6 +101,7 @@ from synapse.storage.engines import BaseDatabaseEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor, JsonDict
from synapse.util.clock import Clock
from synapse.util.json import json_encoder
from tests.utils import (
LEAVE_DB,
@@ -422,7 +423,7 @@ def make_request(
path = b"/" + path
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
content = json_encoder.encode(content).encode("utf8")
if isinstance(content, str):
content = content.encode("utf8")
+149
View File
@@ -0,0 +1,149 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2026 Element Creations Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
from synapse.synapse_rust.events import JsonObject
from synapse.util import MutableOverlayMapping
from tests import unittest
class JsonObjectMappingTestCase(unittest.TestCase):
def test_new_and_basic_mapping_behavior(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
self.assertEqual(len(obj), 2)
self.assertTrue("a" in obj)
self.assertTrue("b" in obj)
self.assertFalse("c" in obj)
self.assertFalse(123 in obj) # type: ignore[comparison-overlap]
def test_getitem_and_key_errors(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
self.assertEqual(obj["a"], 1)
with self.assertRaises(KeyError):
_ = obj["missing"]
with self.assertRaises(KeyError):
_ = obj[10] # type: ignore[index]
def test_iter_keys_values_items(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
iterator = iter(obj)
first = next(iterator)
second = next(iterator)
self.assertCountEqual((first, second), ("a", "b"))
with self.assertRaises(StopIteration):
next(iterator)
self.assertCountEqual(list(obj.keys()), ["a", "b"])
self.assertCountEqual(list(obj.values()), [1, 2])
self.assertCountEqual(list(obj.items()), [("a", 1), ("b", 2)])
def test_keys_set_like_behavior(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
# Test 'and' operator.
self.assertEqual(obj.keys() & {"a"}, {"a"})
self.assertEqual({"a"} & obj.keys(), {"a"})
self.assertEqual(obj.keys() & {"c"}, set())
self.assertEqual({"c"} & obj.keys(), set())
# Test 'or' operator.
self.assertEqual(obj.keys() | {"a"}, {"a", "b"})
self.assertEqual({"a"} | obj.keys(), {"a", "b"})
self.assertEqual(obj.keys() | {"c"}, {"a", "b", "c"})
self.assertEqual({"c"} | obj.keys(), {"a", "b", "c"})
# Test 'xor' operator.
self.assertEqual(obj.keys() ^ {"a"}, {"b"})
self.assertEqual({"a"} ^ obj.keys(), {"b"})
self.assertEqual(obj.keys() ^ {"c"}, {"a", "b", "c"})
self.assertEqual({"c"} ^ obj.keys(), {"a", "b", "c"})
# Test 'sub' operator.
self.assertEqual(obj.keys() - {"a"}, {"b"})
self.assertEqual({"a"} - obj.keys(), set())
self.assertEqual(obj.keys() - {"c"}, {"a", "b"})
self.assertEqual({"c"} - obj.keys(), {"c"})
def test_values_view(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
values = obj.values()
self.assertEqual(len(values), 2)
self.assertCountEqual(list(values), [1, 2])
self.assertIn(1, values)
self.assertIn(2, values)
self.assertNotIn(3, values)
self.assertNotIn("a", values)
self.assertNotIn(object(), values)
# Iterating twice should yield the same values.
self.assertCountEqual(list(values), [1, 2])
def test_items_view(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
items = obj.items()
self.assertEqual(len(items), 2)
self.assertCountEqual(list(items), [("a", 1), ("b", 2)])
self.assertIn(("a", 1), items)
self.assertIn(("b", 2), items)
self.assertNotIn(("a", 2), items)
self.assertNotIn(("c", 1), items)
self.assertNotIn("a", items)
self.assertNotIn(("a", 1, "extra"), items)
# Iterating twice should yield the same items.
self.assertCountEqual(list(items), [("a", 1), ("b", 2)])
def test_get(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
self.assertEqual(obj.get("a"), 1)
self.assertEqual(obj.get("missing", "fallback"), "fallback")
self.assertEqual(obj.get(5, "fallback"), "fallback") # type: ignore[call-overload]
def test_eq(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
self.assertEqual(obj, {"a": 1, "b": 2})
self.assertNotEqual(obj, {"a": 1})
self.assertNotEqual(obj, ["a", "b"])
def test_str_and_repr(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
self.assertEqual(str(obj), r'{"a":1,"b":2}')
self.assertEqual(repr(obj), r'JsonObject({"a":1,"b":2})')
def test_json_object_constructor(self) -> None:
obj = JsonObject({"a": 1, "b": 2})
# Passing in an existing JsonObject should work.
obj2 = JsonObject(obj)
self.assertEqual(obj2, {"a": 1, "b": 2})
# Other mapping types should also work.
obj3 = JsonObject(MutableOverlayMapping({"a": 1, "b": 2}))
self.assertEqual(obj3, {"a": 1, "b": 2})
# Test that passing a non-mapping raises a TypeError.
with self.assertRaises(TypeError):
JsonObject(123) # type: ignore[arg-type]
+3 -2
View File
@@ -35,7 +35,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
from synapse.types import MutableStateMap, StateMap
from synapse.types import JsonDict, MutableStateMap, StateMap
from synapse.types.state import StateFilter
from synapse.util.macaroons import MacaroonGenerator
@@ -67,13 +67,14 @@ def create_event(
else:
name = "<%s, %s>" % (type, event_id)
d = {
d: JsonDict = {
"event_id": event_id,
"type": type,
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
"prev_events": prev_events or [],
"content": {},
}
if state_key is not None:
+2 -2
View File
@@ -24,7 +24,6 @@ Utilities for running the unit tests
"""
import base64
import json
import sys
import warnings
from binascii import unhexlify
@@ -41,6 +40,7 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from synapse.types import JsonSerializable
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from sys import UnraisableHookArgs
@@ -127,7 +127,7 @@ class FakeResponse: # type: ignore[misc]
@classmethod
def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse":
headers = Headers({"Content-Type": ["application/json"]})
body = json.dumps(payload).encode("utf-8")
body = json_encoder.encode(payload).encode("utf-8")
return cls(code=code, body=body, headers=headers)