From 15c03b96895998c82eed0c5c28d1135a704fa64b Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:46:47 +0100 Subject: [PATCH 1/8] MSC4242: State DAGs (CSAPI) (#19424) This implements [MSC4242: State DAGs](https://github.com/matrix-org/matrix-spec-proposals/pull/4242), without support for federation. A general overview: - It adds a new room version and new event type. - It adds a new field `calculated_auth_event_ids` to internal metadata. - It stores the state DAG via new state DAG edges / forward extremities tables. - It adds new auth rules as per the MSC. - It uses the new `prev_state_events` field instead of `prev_event_ids()` when doing state resolution. Complement tests: https://github.com/matrix-org/complement/pull/841 ### Pull Request Checklist * [x] Pull request is based on the develop branch * [x] Pull request includes a [changelog file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog). The entry should: - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.". - Use markdown where necessary, mostly for `code blocks`. - End with either a period (.) or an exclamation mark (!). - Start with a capital letter. - Feel free to credit yourself, by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry. 
* [x] [Code style](https://element-hq.github.io/synapse/latest/code_style.html) is correct (run the [linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters)) --------- Co-authored-by: Eric Eastwood --- changelog.d/19424.feature | 1 + contrib/grafana/synapse.json | 151 ++++++- rust/src/events/internal_metadata.rs | 33 ++ rust/src/room_versions.rs | 60 ++- synapse/config/experimental.py | 6 + synapse/event_auth.py | 72 +++- synapse/events/__init__.py | 88 ++++- synapse/events/builder.py | 57 ++- synapse/events/utils.py | 4 + synapse/events/validator.py | 7 +- synapse/federation/federation_client.py | 5 + synapse/handlers/admin.py | 12 +- synapse/handlers/message.py | 69 +++- synapse/handlers/room.py | 21 +- synapse/handlers/room_member.py | 41 +- synapse/state/__init__.py | 38 +- synapse/storage/controllers/persist_events.py | 299 +++++++++++++- .../databases/main/event_federation.py | 9 + synapse/storage/databases/main/events.py | 90 ++++- .../storage/databases/main/purge_events.py | 4 + synapse/storage/databases/main/state.py | 12 +- synapse/storage/schema/__init__.py | 1 + .../delta/94/03_state_dag_fwd_extrems.sql | 38 ++ synapse/synapse_rust/events.pyi | 3 + synapse/synapse_rust/room_versions.pyi | 11 + tests/federation/test_federation_server.py | 24 +- tests/handlers/test_federation_event.py | 9 +- tests/storage/test_msc4242_state_dag.py | 371 ++++++++++++++++++ tests/storage/test_redaction.py | 1 + tests/test_event_auth.py | 293 +++++++++++++- 30 files changed, 1735 insertions(+), 95 deletions(-) create mode 100644 changelog.d/19424.feature create mode 100644 synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql create mode 100644 tests/storage/test_msc4242_state_dag.py diff --git a/changelog.d/19424.feature b/changelog.d/19424.feature new file mode 100644 index 0000000000..8f241a87b5 --- /dev/null +++ b/changelog.d/19424.feature @@ -0,0 +1 @@ +Add experimental support for 
[MSC4242](https://github.com/matrix-org/matrix-spec-proposals/pull/4242): State DAGs. Excludes federation support. \ No newline at end of file diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index ceacc10369..a67d1aba0b 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -6809,6 +6809,155 @@ ], "title": "Stale extremity dropping", "type": "timeseries" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "description": "For a given percentage P, the number X where P% of events were persisted to rooms with X state DAG forward extremities or fewer.", + "fieldConfig": { + "defaults": { + "links": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 50 + }, + "id": 181, + "options": { + "alertThreshold": true + }, + "pluginVersion": "9.2.2", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.5, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "50%", + "refId": "A" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.75, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "75%", + "refId": "B" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.90, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + 
"legendFormat": "90%", + "refId": "C" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.99, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "99%", + "refId": "D" + } + ], + "title": "Events persisted, by number of state DAG forward extremities in room (quantiles)", + "type": "timeseries" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "description": "Colour reflects the number of events persisted to rooms with the given number of state DAG forward extremities, or fewer.", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 50 + }, + "id": 127, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 1, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "opacity", + "reverse": false, + "scale": "exponential", + "scheme": "Oranges", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": true + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "show": true, + "yHistogram": true + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 0, + "reverse": false, + "unit": "short" + } + }, + "pluginVersion": "9.2.2", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0)", + "format": "heatmap", + 
"intervalFactor": 1, + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "title": "Events persisted, by number of state DAG forward extremities in room (heatmap)", + "type": "heatmap" } ], "title": "Extremities", @@ -7711,4 +7860,4 @@ "uid": "000000012", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index 21d3b8c435..6fd3d06b00 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -65,6 +65,7 @@ enum EventInternalMetadataData { DelayId(Box), TokenId(i64), DeviceId(Box), + CalculatedAuthEventIDs(Vec), // MSC4242: State DAGs } impl EventInternalMetadataData { @@ -140,6 +141,10 @@ impl EventInternalMetadataData { pyo3::intern!(py, "device_id"), o.into_pyobject(py).unwrap_infallible().into_any(), ), + EventInternalMetadataData::CalculatedAuthEventIDs(o) => ( + pyo3::intern!(py, "calculated_auth_event_ids"), + o.into_pyobject(py).unwrap().into_any(), + ), } } @@ -218,6 +223,11 @@ impl EventInternalMetadataData { .map(String::into_boxed_str) .with_context(|| format!("'{key_str}' has invalid type"))?, ), + "calculated_auth_event_ids" => EventInternalMetadataData::CalculatedAuthEventIDs( + value + .extract() + .with_context(|| format!("'{key_str}' has invalid type"))?, + ), _ => return Ok(None), }; @@ -395,6 +405,10 @@ impl EventInternalMetadataInner { get_property_opt!(self, DelayId).map(|s| s.deref()) } + pub fn get_calculated_auth_event_ids(&self) -> Option<&Vec> { + get_property_opt!(self, CalculatedAuthEventIDs) + } + pub fn get_token_id(&self) -> Option { get_property_opt!(self, TokenId).copied() } @@ -456,6 +470,10 @@ impl EventInternalMetadataInner { pub fn set_device_id(&mut self, obj: String) { set_property!(self, DeviceId, obj.into_boxed_str()); } + + pub fn set_calculated_auth_event_ids(&mut self, obj: Vec) { + set_property!(self, CalculatedAuthEventIDs, obj); + } } #[pyclass(frozen)] @@ -722,6 +740,21 @@ impl 
EventInternalMetadata { Ok(()) } + /// The calculated auth event IDs, if it was set when the event was created. + #[getter] + fn get_calculated_auth_event_ids(&self) -> PyResult> { + let guard = self.read_inner()?; + attr_err( + guard.get_calculated_auth_event_ids().cloned(), + "calculated_auth_event_ids", + ) + } + #[setter] + fn set_calculated_auth_event_ids(&self, obj: Vec) -> PyResult<()> { + self.write_inner()?.set_calculated_auth_event_ids(obj); + Ok(()) + } + /// The delay ID, set only if the event was a delayed event. #[getter] fn get_delay_id(&self) -> PyResult { diff --git a/rust/src/room_versions.rs b/rust/src/room_versions.rs index fbcc32516a..dbc962174d 100644 --- a/rust/src/room_versions.rs +++ b/rust/src/room_versions.rs @@ -47,6 +47,9 @@ impl EventFormatVersions { /// MSC4291 room IDs as hashes: introduced for room HydraV11 #[classattr] const ROOM_V11_HYDRA_PLUS: i32 = 4; + /// MSC4242 state DAGs: adds prev_state_events, removes auth_events + #[classattr] + const ROOM_VMSC4242: i32 = 5; } /// Enum to identify the state resolution algorithms. @@ -146,6 +149,14 @@ pub struct RoomVersion { /// /// In these room versions, we are stricter with event size validation. pub strict_event_byte_limits_room_versions: bool, + /// MSC4242: State DAGs. Creates events with prev_state_events instead of auth_events and derives + /// state from it. Events are always processed in causal order without any gaps in the DAG + /// (prev_state_events are always known), guaranteeing that processed events have a path to the + /// create event. This is an emergent property of state DAGs as asserting that there is a path + /// to the create event every time we insert an event would be prohibitively expensive. + /// This is similar to how doubly-linked lists can potentially not refer to previous items correctly + /// without verifying the list's integrity, but doing it on every insert is too expensive. 
+ pub msc4242_state_dags: bool, } const ROOM_VERSION_V1: RoomVersion = RoomVersion { @@ -170,6 +181,7 @@ const ROOM_VERSION_V1: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V2: RoomVersion = RoomVersion { @@ -194,6 +206,7 @@ const ROOM_VERSION_V2: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V3: RoomVersion = RoomVersion { @@ -218,6 +231,7 @@ const ROOM_VERSION_V3: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V4: RoomVersion = RoomVersion { @@ -242,6 +256,7 @@ const ROOM_VERSION_V4: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V5: RoomVersion = RoomVersion { @@ -266,6 +281,7 @@ const ROOM_VERSION_V5: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V6: RoomVersion = RoomVersion { @@ -290,6 +306,7 @@ const ROOM_VERSION_V6: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V7: RoomVersion = RoomVersion { @@ -314,6 +331,7 @@ const ROOM_VERSION_V7: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V8: RoomVersion = 
RoomVersion { @@ -338,6 +356,7 @@ const ROOM_VERSION_V8: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V9: RoomVersion = RoomVersion { @@ -362,6 +381,7 @@ const ROOM_VERSION_V9: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V10: RoomVersion = RoomVersion { @@ -386,6 +406,7 @@ const ROOM_VERSION_V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; /// MSC3389 (Redaction changes for events with a relation) based on room version "10". @@ -411,6 +432,7 @@ const ROOM_VERSION_MSC3389V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; /// MSC1767 (Extensible Events) based on room version "10". @@ -436,6 +458,7 @@ const ROOM_VERSION_MSC1767V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; /// MSC3757 (Restricting who can overwrite a state event) based on room version "10". 
@@ -461,6 +484,7 @@ const ROOM_VERSION_MSC3757V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V11: RoomVersion = RoomVersion { @@ -485,6 +509,7 @@ const ROOM_VERSION_V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, // Changed from v10 + msc4242_state_dags: false, }; /// MSC3757 (Restricting who can overwrite a state event) based on room version "11". @@ -510,6 +535,7 @@ const ROOM_VERSION_MSC3757V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; const ROOM_VERSION_HYDRA_V11: RoomVersion = RoomVersion { @@ -534,6 +560,7 @@ const ROOM_VERSION_HYDRA_V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: true, // Changed from v11 msc4291_room_ids_as_hashes: true, // Changed from v11 strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; const ROOM_VERSION_V12: RoomVersion = RoomVersion { @@ -558,6 +585,32 @@ const ROOM_VERSION_V12: RoomVersion = RoomVersion { msc4289_creator_power_enabled: true, // Changed from v11 msc4291_room_ids_as_hashes: true, // Changed from v11 strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, +}; + +const ROOM_VERSION_MSC4242V12: RoomVersion = RoomVersion { + identifier: "org.matrix.msc4242.12", + disposition: RoomDisposition::UNSTABLE, + event_format: EventFormatVersions::ROOM_VMSC4242, + state_res: StateResolutionVersions::V2_1, + enforce_key_validity: true, + special_case_aliases_auth: false, + strict_canonicaljson: true, + limit_notifications_power_levels: true, + implicit_room_creator: true, + updated_redaction_rules: true, + restricted_join_rule: true, + restricted_join_rule_fix: true, + 
knock_join_rule: true, + msc3389_relation_redactions: false, + knock_restricted_join_rule: true, + enforce_int_power_levels: true, + msc3931_push_features: &[], + msc3757_enabled: false, + msc4289_creator_power_enabled: true, + msc4291_room_ids_as_hashes: true, + strict_event_byte_limits_room_versions: true, + msc4242_state_dags: true, }; /// Helper class for managing the known room versions, and providing dict-like @@ -800,6 +853,10 @@ impl RoomVersions { fn V12(py: Python<'_>) -> PyResult> { ROOM_VERSION_V12.into_py_any(py) } + #[classattr] + fn MSC4242v12(py: Python<'_>) -> PyResult> { + ROOM_VERSION_MSC4242V12.into_py_any(py) + } } /// Called when registering modules with python. @@ -814,11 +871,12 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> child_module.add_class::()?; // Build KNOWN_EVENT_FORMAT_VERSIONS as a frozenset - let known_ef: [i32; 4] = [ + let known_ef: [i32; 5] = [ EventFormatVersions::ROOM_V1_V2, EventFormatVersions::ROOM_V3, EventFormatVersions::ROOM_V4_PLUS, EventFormatVersions::ROOM_V11_HYDRA_PLUS, + EventFormatVersions::ROOM_VMSC4242, ]; let known_event_format_versions = PyFrozenSet::new(py, known_ef)?; child_module.add("KNOWN_EVENT_FORMAT_VERSIONS", known_event_format_versions)?; diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702c7e3246..f1a7771568 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -479,6 +479,12 @@ class ExperimentalConfig(Config): # Enable room version (and thus applicable push rules from MSC3931/3932) KNOWN_ROOM_VERSIONS.add_room_version(RoomVersions.MSC1767v10) + # MSC4242: State DAGs + self.msc4242_enabled: bool = experimental.get("msc4242_enabled", False) + if self.msc4242_enabled: + # Enable the room version + KNOWN_ROOM_VERSIONS.add_room_version(RoomVersions.MSC4242v12) + # MSC3391: Removing account data. 
self.msc3391_enabled = experimental.get("msc3391_enabled", False) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index bf239e660d..ca528ae235 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -61,7 +61,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersion, ) -from synapse.events import is_creator +from synapse.events import FrozenEventVMSC4242, is_creator from synapse.state import CREATE_KEY from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( @@ -186,6 +186,70 @@ async def check_state_independent_auth_rules( # 1.5 Otherwise, allow return + # State DAGs 2. Considering the event's prev_state_events: + if event.room_version.msc4242_state_dags: + prev_state_events_ids = set(cast(FrozenEventVMSC4242, event).prev_state_events) + # Fetch all of the `prev_state_events` + prev_state_events = {} + # Try to load the `prev_state_events` from `batched_auth_events` initially as + # that can save us a database hit. 
+ if batched_auth_events is not None: + prev_state_events = { + event_id: value + for event_id in prev_state_events_ids + if (value := batched_auth_events.get(event_id)) is not None + } + # Fetch the rest of the `prev_state_events` + missing_prev_state_events_ids = prev_state_events_ids - set( + prev_state_events.keys() + ) + fetched_prev_state_events = await store.get_events( + missing_prev_state_events_ids, + redact_behaviour=EventRedactBehaviour.as_is, + allow_rejected=True, + ) + prev_state_events.update(fetched_prev_state_events) + if len(prev_state_events) != len(prev_state_events_ids): + # we should have all the `prev_state_events` by now, so if we do not, that suggests + # a Synapse programming error + known_prev_state_event_ids = set(prev_state_events) + raise AssertionError( + f"Event {event.event_id} has unknown prev_state_events " + + f"({len(prev_state_events)}/{len(prev_state_events_ids)} known)" + + f"{prev_state_events_ids - known_prev_state_event_ids} missing " + + f"out of {prev_state_events_ids}" + ) + for prev_state_event in prev_state_events.values(): + # 2.1 If there are entries which do not belong in the same room, reject. + if prev_state_event.room_id != event.room_id: + raise AuthError( + 403, + "During auth for event %s in room %s, found event %s in prev_state_events " + "which belongs to a different room %s" + % ( + event.event_id, + event.room_id, + prev_state_event.event_id, + prev_state_event.room_id, + ), + ) + # 2.2 If there are entries which do not have a state_key, reject. + if not prev_state_event.is_state(): + raise AuthError( + 403, + f"During auth for event {event.event_id} in room {event.room_id}, event has a " + + f"prev_state_event which is not state: {prev_state_event.event_id}", + ) + # 2.3 If there are entries which were themselves rejected under the checks performed on + # receipt of a PDU, reject. 
+ if prev_state_event.rejected_reason is not None: + raise AuthError( + 403, + f"During auth for event {event.event_id} in room {event.room_id}, event has a " + + f"prev_state_event which is rejected ({prev_state_event.rejected_reason}): " + + f"{prev_state_event.event_id}", + ) + # 2. Reject if event has auth_events that: ... auth_events: ChainMap[str, EventBase] = ChainMap() if batched_auth_events: @@ -450,6 +514,12 @@ def _check_create(event: "EventBase") -> None: if event.prev_event_ids(): raise AuthError(403, "Create event has prev events") + # State DAGs 1.2 If it has any prev_state_events, reject. + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + if len(event.prev_state_events) > 0: + raise AuthError(403, "Create event has prev state events") + if event.room_version.msc4291_room_ids_as_hashes: # 1.2 If the create event has a room_id, reject if "room_id" in event: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f48d5c4f1d..f4a5624d1a 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -44,10 +44,7 @@ from synapse.api.constants import ( ) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata -from synapse.types import ( - JsonDict, - StrCollection, -) +from synapse.types import JsonDict, StateKey, StrCollection from synapse.util.caches import intern_dict from synapse.util.duration import Duration from synapse.util.frozenutils import freeze @@ -575,9 +572,60 @@ class FrozenEventV4(FrozenEventV3): return [*self._dict["auth_events"], create_event_id] +class FrozenEventVMSC4242(FrozenEventV4): + """FrozenEventVMSC4242, which differs from FrozenEventV4 only in the addition of prev_state_events""" + + format_version = EventFormatVersions.ROOM_VMSC4242 + prev_state_events: DictProperty[list[str]] = DictProperty("prev_state_events") + + def __init__( + self, + event_dict: JsonDict, + 
room_version: RoomVersion, + internal_metadata_dict: JsonDict | None = None, + rejected_reason: str | None = None, + ): + # Similar to how we assert event_id isn't in V2+ events, we do the same with auth_events. + # We don't expect `auth_events` in the wire format because we calculate it from prev_state_events. + assert "auth_events" not in event_dict + super().__init__( + event_dict=event_dict, + room_version=room_version, + internal_metadata_dict=internal_metadata_dict, + rejected_reason=rejected_reason, + ) + + def auth_event_ids(self) -> StrCollection: + """Returns the list of _calculated_ auth event IDs. + + Returns: + The list of event IDs of this event's auth events + """ + # Catches cases where we accidentally call auth_event_ids() prior to calculating what they + # actually are. The exception being the m.room.create event which has no auth events. + if self.type != EventTypes.Create: + assert len(self.internal_metadata.calculated_auth_event_ids) > 0 + return self.internal_metadata.calculated_auth_event_ids + + def __repr__(self) -> str: + rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" + + return ( + f"<{self.__class__.__name__} " + f"{rejection}" + f"event_id={self.event_id}, " + f"type={self.get('type')}, " + f"state_key={self.get('state_key')}, " + f"prev_events={self.get('prev_events')}, " + f"prev_state_events={self.get('prev_state_events')}, " + f"outlier={self.internal_metadata.is_outlier()}" + ">" + ) + + def _event_type_from_format_version( format_version: int, -) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3]: +) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3 | FrozenEventVMSC4242]: """Returns the python type to use to construct an Event object for the given event format version. 
@@ -594,6 +642,8 @@ def _event_type_from_format_version( return FrozenEventV2 elif format_version == EventFormatVersions.ROOM_V4_PLUS: return FrozenEventV3 + elif format_version == EventFormatVersions.ROOM_VMSC4242: + return FrozenEventVMSC4242 elif format_version == EventFormatVersions.ROOM_V11_HYDRA_PLUS: return FrozenEventV4 else: @@ -655,6 +705,24 @@ def relation_from_event(event: EventBase) -> _EventRelation | None: return _EventRelation(parent_id, rel_type, aggregation_key) +def event_exists_in_state_dag( + event: Union["EventBase", "EventBuilder", "EventMetadata", "StateKey"], +) -> bool: + """Given an event, returns true if this event should form part of the state DAG. + Only valid for room versions which use a state DAG (MSC4242).""" + state_key = None + if isinstance(event, EventMetadata): + state_key = event.state_key + elif isinstance(event, tuple): # StateKey + # can't use StateKey else you get: + # "Subscripted generics cannot be used with class and instance checks" + state_key = event[1] + else: + state_key = event.state_key if event.is_state() else None + + return state_key is not None + + def is_creator(create: EventBase, user_id: str) -> bool: """ Return true if the provided user ID is the room creator. 
@@ -689,3 +757,13 @@ class StrippedStateEvent: state_key: str sender: str content: dict[str, Any] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: str | None + rejection_reason: str | None diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 2cd1bf6106..78eb98e1e5 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -132,6 +132,7 @@ class EventBuilder: prev_event_ids: list[str], auth_event_ids: list[str] | None, depth: int | None = None, + prev_state_events: list[str] | None = None, ) -> EventBase: """Transform into a fully signed and hashed event @@ -143,10 +144,51 @@ class EventBuilder: depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. - + prev_state_events: The event IDs to use as prev_state_events. + Only applicable on MSC4242 state DAG rooms. If this is supplied, auth_event_ids + must not be specified unless this event is part of a batch such that the builder + will be unable to compute the auth_event_ids due to the events not being persisted + yet. Returns: The signed and hashed event. """ + # If the caller specifies this, make sure the room version supports it. + if prev_state_events: + assert self.room_version.msc4242_state_dags + if self.room_version.msc4242_state_dags: + assert prev_state_events is not None + if self.room_id: + state_ids = await self._state.compute_state_after_events( + self.room_id, + prev_state_events, + state_filter=StateFilter.from_types( + auth_types_for_event(self.room_version, self) + ), + await_full_state=False, + ) + # When we create rooms we only insert the create+member events, and batch the rest. + # Therefore, we may not have state_ids from compute_state_after_events as the + # prev_state_events are unknown. 
If this happens, the caller provides the auth events + # to use instead. + calculated_auth_event_ids: list[ + str + ] = [] # assume it's the create event which has [] + if len(state_ids) == 0 and len(prev_state_events) > 0: + # it's a batched event, so we should have been provided the auth_events + assert auth_event_ids and len(auth_event_ids) > 0 + calculated_auth_event_ids = auth_event_ids + else: + calculated_auth_event_ids = ( + self._event_auth_handler.compute_auth_events(self, state_ids) + ) + else: + # event is a state DAG event and is the create event (room_id is not provided), + # therefore there are no auth_events. + calculated_auth_event_ids = [] + assert self.type == EventTypes.Create and self.state_key == "" + self.internal_metadata.calculated_auth_event_ids = calculated_auth_event_ids + auth_event_ids = calculated_auth_event_ids + # Create events always have empty auth_events. if self.type == EventTypes.Create and self.is_state() and self.state_key == "": auth_event_ids = [] @@ -155,6 +197,8 @@ class EventBuilder: if auth_event_ids is None: # Every non-create event must have a room ID assert self.room_id is not None + # this block must not be hit for MSC4242 rooms as it resolves state with prev_events + assert not self.room_version.msc4242_state_dags state_ids = await self._state.compute_state_after_events( self.room_id, prev_event_ids, @@ -231,7 +275,6 @@ class EventBuilder: # rejected by other servers (and so that they can be persisted in # the db) depth = min(depth, MAX_DEPTH) - event_dict: dict[str, Any] = { "auth_events": auth_events, "prev_events": prev_events, @@ -241,8 +284,6 @@ class EventBuilder: "unsigned": self.unsigned, "depth": depth, } - if self.room_id is not None: - event_dict["room_id"] = self.room_id if self.room_version.msc4291_room_ids_as_hashes: # In MSC4291: the create event has no room ID as the create event ID /is/ the room ID. 
@@ -262,6 +303,14 @@ class EventBuilder: auth_event_ids.remove(create_event_id) event_dict["auth_events"] = auth_event_ids + if self.room_version.msc4242_state_dags: + # Auth events are removed entirely on state DAG rooms + event_dict.pop("auth_events") + assert prev_state_events is not None + event_dict["prev_state_events"] = prev_state_events + if self.room_id is not None: + event_dict["room_id"] = self.room_id + if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ff0476f5fb..f038fb5578 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -156,6 +156,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic # Earlier room versions from had additional allowed keys. if not room_version.updated_redaction_rules: allowed_keys.extend(["prev_state", "membership", "origin"]) + # Custom room versions add new allowed keys and remove others + if room_version.msc4242_state_dags: + allowed_keys.extend(["prev_state_events"]) + allowed_keys.remove("auth_events") event_type = event_dict["type"] diff --git a/synapse/events/validator.py b/synapse/events/validator.py index b27f8a942a..ff22b2287f 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -63,14 +63,17 @@ class EventValidator: if event.format_version == EventFormatVersions.ROOM_V1_V2: EventID.from_string(event.event_id) - required = [ + required = { "auth_events", "content", "hashes", "prev_events", "sender", "type", - ] + } + if event.room_version.msc4242_state_dags: + required.remove("auth_events") + required.add("prev_state_events") for k in required: if k not in event: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 55151ca549..78a1900c73 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1108,6 +1108,11 @@ class FederationClient(FederationBase): SynapseError: if the 
chosen remote server returns a 300/400 code, or no servers successfully handle the request. """ + # See related restriction in /createRoom requests in handlers/room.py + if room_version.msc4242_state_dags: + raise UnsupportedRoomVersionError( + "Homeserver does not support this room version over federation" + ) async def send_request(destination: str) -> SendJoinResult: response = await self._do_send_join( diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d2c1f98d7c..51a752472f 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -32,7 +32,7 @@ import attr from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.events.utils import FilteredEvent from synapse.types import ( JsonMapping, @@ -494,9 +494,16 @@ class AdminHandler: event_dict["redacts"] = event.event_id try: + prev_state_events = None + if room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + prev_state_events = event.prev_state_events + assert prev_state_events is not None, ( + "Parent event of redaction has no `prev_state_events` which should be impossible as `prev_state_events` is a required field in MSC4242 rooms" + ) # set the prev event to the offending message to allow for redactions # to be processed in the case where the user has been kicked/banned before - # redactions are requested + # redactions are requested. 
( redaction, _, @@ -505,6 +512,7 @@ class AdminHandler: event_dict, prev_event_ids=[event.event_id], ratelimit=False, + prev_state_events=prev_state_events, ) except Exception as ex: logger.info( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0aa0a16127..4032c7eca9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -53,7 +53,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, FrozenEventVMSC4242, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import ( EventContext, @@ -589,6 +589,7 @@ class EventCreationHandler: state_map: StateMap[str] | None = None, for_batch: bool = False, current_state_group: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """ @@ -644,6 +645,10 @@ class EventCreationHandler: current_state_group: the current state group, used only for creating events for batch persisting + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). + delay_id: The delay ID of this event, if it was a delayed event. Raises: @@ -748,6 +753,7 @@ class EventCreationHandler: state_map=state_map, for_batch=for_batch, current_state_group=current_state_group, + prev_state_events=prev_state_events, ) # In an ideal world we wouldn't need the second part of this condition. 
However, @@ -976,6 +982,7 @@ class EventCreationHandler: ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, int]: """ @@ -1005,6 +1012,9 @@ class EventCreationHandler: depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). delay_id: The delay ID of this event, if it was a delayed event. Returns: @@ -1102,6 +1112,7 @@ class EventCreationHandler: ignore_shadow_ban=ignore_shadow_ban, outlier=outlier, depth=depth, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -1116,6 +1127,7 @@ class EventCreationHandler: ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, int]: room_id = event_dict["room_id"] @@ -1145,6 +1157,7 @@ class EventCreationHandler: state_event_ids=state_event_ids, outlier=outlier, depth=depth, + prev_state_events=prev_state_events, delay_id=delay_id, ) context = await unpersisted_context.persist(event) @@ -1240,6 +1253,7 @@ class EventCreationHandler: state_map: StateMap[str] | None = None, for_batch: bool = False, current_state_group: int | None = None, + prev_state_events: list[str] | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """Create a new event for a local client. 
If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -1281,9 +1295,30 @@ class EventCreationHandler: current_state_group: the current state group, used only for creating events for batch persisting + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). + If unset, populates them from the current state dag forward extremities. + Returns: Tuple of created event, UnpersistedEventContext """ + if builder.room_version.msc4242_state_dags: + assert auth_event_ids is None + # (kegan) I can't find any call-site which uses this. We can't risk letting in + # untrusted input, so for now assert that we aren't told about any state. + assert state_event_ids is None + + if builder.room_id: + if prev_state_events is None: + prev_state_events = list( + await self.store.get_state_dag_extremities(builder.room_id) + ) + else: + # create event doesn't need prev_state_events to be fetched, but it must be non-None. + assert builder.type == EventTypes.Create and builder.state_key == "" + prev_state_events = [] + # Strip down the state_event_ids to only what we need to auth the event. 
# For example, we don't need extra m.room.member that don't match event.sender if state_event_ids is not None: @@ -1357,7 +1392,10 @@ class EventCreationHandler: assert state_map is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( - prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth + prev_event_ids=prev_event_ids, + auth_event_ids=auth_ids, + depth=depth, + prev_state_events=prev_state_events, ) context: UnpersistedEventContextBase = ( @@ -1374,6 +1412,7 @@ class EventCreationHandler: prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, + prev_state_events=prev_state_events, ) # Pass on the outlier property from the builder to the event @@ -1563,6 +1602,20 @@ class EventCreationHandler: auth_event = event_id_to_event.get(event_id) if auth_event: batched_auth_events[event_id] = auth_event + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + # State DAG rooms will check that the prev_state_events are not rejected. + # To do that, we need to make sure we pass in the prev_state_events as + # batched_auth_events, else we will fail the event due to the + # prev_state_events not existing in the database. + for prev_state_event_id in event.prev_state_events: + prev_state_event = event_id_to_event.get( + prev_state_event_id + ) + if prev_state_event: + batched_auth_events[prev_state_event_id] = ( + prev_state_event + ) await self._event_auth_handler.check_auth_rules_from_context( event, batched_auth_events ) @@ -1817,7 +1870,10 @@ class EventCreationHandler: # set for a while, so that the expiry time is reset. 
state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() + event.room_id, + event_ids=event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids(), ) if state_entry.state_group: @@ -2360,9 +2416,16 @@ class EventCreationHandler: # case. prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + prev_state_events = None + if original_event.room_version.msc4242_state_dags: + prev_state_events = list( + await self.store.get_state_dag_extremities(builder.room_id) + ) + event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=None, + prev_state_events=prev_state_events, ) # we rebuild the event context, to be on the safe side. If nothing else, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9074d7916b..f110be0a2f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -65,7 +65,7 @@ from synapse.api.filtering import Filter from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase +from synapse.events import EventBase, event_exists_in_state_dag from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import FilteredEvent, copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations @@ -1237,6 +1237,10 @@ class RoomCreationHandler: creation_content = config.get("creation_content", {}) # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier + # We do not currently support federating state DAG rooms. + # See related restriction in /send_join requests in federation_client.py. 
+ if room_version.msc4242_state_dags: + creation_content[EventContentFields.FEDERATE] = False # trusted private chats have the invited users marked as additional creators if ( @@ -1486,6 +1490,11 @@ class RoomCreationHandler: # the most recently created event prev_event: list[str] = [] + # This should be the most recently created state event as we create each event + prev_state_events: list[str] | None = ( + [] if room_version.msc4242_state_dags else None + ) + # a map of event types, state keys -> event_ids. We collect these mappings this as events are # created (but not persisted to the db) to determine state for future created events # (as this info can't be pulled from the db) @@ -1512,6 +1521,7 @@ class RoomCreationHandler: """ nonlocal depth nonlocal prev_event + nonlocal prev_state_events # Create the event dictionary. event_dict = {"type": etype, "content": content} @@ -1525,6 +1535,7 @@ class RoomCreationHandler: creator, event_dict, prev_event_ids=prev_event, + prev_state_events=prev_state_events, depth=depth, # Take a copy to ensure each event gets a unique copy of # state_map since it is modified below. 
@@ -1535,7 +1546,8 @@ class RoomCreationHandler: depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - + if room_version.msc4242_state_dags and event_exists_in_state_dag(new_event): + prev_state_events = [new_event.event_id] return new_event, new_unpersisted_context preset_config, config = self._room_preset_config(room_config) @@ -1568,6 +1580,8 @@ class RoomCreationHandler: ignore_shadow_ban=True, ) last_sent_event_id = ev.event_id + if room_version.msc4242_state_dags: + prev_state_events = [ev.event_id] member_event_id, _ = await self.room_member_handler.update_membership( creator, @@ -1579,8 +1593,11 @@ class RoomCreationHandler: new_room=True, prev_event_ids=[last_sent_event_id], depth=depth, + prev_state_events=prev_state_events, ) prev_event = [member_event_id] + if room_version.msc4242_state_dags: + prev_state_events = [member_event_id] # update the depth and state map here as the membership event has been created # through a different code path diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2e678e90e..236c8ca03c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -36,9 +36,11 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, + NotFoundError, PartialStateConflictError, ShadowBanError, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event @@ -408,6 +410,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """ @@ -494,6 +497,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth=depth, require_consent=require_consent, outlier=outlier, + prev_state_events=prev_state_events, 
delay_id=delay_id, ) context = await unpersisted_context.persist(event) @@ -590,6 +594,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """Update a user's membership in a room. @@ -684,6 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids=state_event_ids, depth=depth, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -707,6 +713,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """Helper for update_membership. @@ -951,10 +958,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) - latest_event_ids = await self.store.get_prev_events_for_room(room_id) + is_state_dags = False + try: + room_version = await self.store.get_room_version(room_id) + is_state_dags = room_version.msc4242_state_dags + except (NotFoundError, UnsupportedRoomVersionError): + pass + + if is_state_dags: + latest_event_ids = list(await self.store.get_state_dag_extremities(room_id)) + else: + latest_event_ids = await self.store.get_prev_events_for_room(room_id) is_partial_state_room = await self.store.is_partial_state_room(room_id) partial_state_before_join = await self.state_handler.compute_state_after_events( @@ -1165,6 +1183,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # see: https://github.com/matrix-org/synapse/issues/7139 if len(latest_event_ids) == 0: latest_event_ids = [invite.event_id] + if invite.room_version.msc4242_state_dags: + prev_state_events = [invite.event_id] 
# or perhaps this is a remote room that a local user has knocked on elif current_membership_type == Membership.KNOCK: @@ -1210,6 +1230,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -2108,10 +2129,21 @@ class RoomMemberMasterHandler(RoomMemberHandler): # # the prev_events consist solely of the previous membership event. prev_event_ids = [previous_membership_event.event_id] - auth_event_ids = ( - list(previous_membership_event.auth_event_ids()) + prev_event_ids - ) + auth_event_ids = None + # Authorise the leave by referencing the previous membership + prev_state_event_ids = None + if previous_membership_event.room_version.msc4242_state_dags: + prev_state_event_ids = [ + previous_membership_event.event_id, + ] + else: + auth_event_ids = ( + list(previous_membership_event.auth_event_ids()) + prev_event_ids + ) + # State DAG rooms should not have auth events specified + # Normal rooms should not have prev state event IDs specified + assert not (prev_state_event_ids is not None and auth_event_ids is not None) # Try several times, it could fail with PartialStateConflictError # in handle_new_client_event, cf comment in except block. 
max_retries = 5 @@ -2127,6 +2159,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, outlier=True, + prev_state_events=prev_state_event_ids, ) context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a92233c863..2f0e3f2c3e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -37,7 +37,7 @@ from prometheus_client import Counter, Histogram from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.events.snapshot import ( EventContext, UnpersistedEventContext, @@ -239,31 +239,6 @@ class StateHandler: ) return await ret.get_state(self._state_storage_controller, state_filter) - async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: StrCollection - ) -> set[str]: - """ - Get the users IDs who are currently in a room. - - Note: This is much slower than using the equivalent method - `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, - so this should only be used when wanting the users at a particular point - in the room. - - Args: - room_id: The ID of the room. - latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. - Returns: - Set of user IDs in the room. 
- """ - - assert latest_event_ids is not None - - logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") - entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_user_ids_from_state(room_id, state) - async def get_hosts_in_room_at_events( self, room_id: str, event_ids: StrCollection ) -> frozenset[str]: @@ -303,7 +278,8 @@ class StateHandler: membership events. `False` if `state_ids_before_event` is the full state. `None` when `state_ids_before_event` is not provided. In this case, the - flag will be calculated based on `event`'s prev events. + flag will be calculated based on `event`'s `prev_events` or `prev_state_events` + for state DAG rooms. state_group_before_event: the current state group at the time of event, if known Returns: @@ -337,7 +313,11 @@ class StateHandler: # (This is slightly racy - the prev-events might get fixed up before we use # their states - but I don't think that really matters; it just means we # might redundantly recalculate the state for this event later.) 
- prev_event_ids = event.prev_event_ids() + prev_event_ids = frozenset( + event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids() + ) incomplete_prev_events = await self.store.get_partial_state_events( prev_event_ids ) @@ -355,7 +335,7 @@ class StateHandler: entry = await self.resolve_state_groups_for_events( event.room_id, - event.prev_event_ids(), + prev_event_ids, await_full_state=False, ) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 2948227807..7cc6a39639 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -35,6 +35,7 @@ from typing import ( Generic, Iterable, TypeVar, + cast, ) import attr @@ -43,7 +44,9 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase +from synapse.api.errors import SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase, FrozenEventVMSC4242, event_exists_in_state_dag from synapse.events.snapshot import EventContext, EventPersistencePair from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -68,6 +71,7 @@ from synapse.types import ( from synapse.types.state import StateFilter from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure +from synapse.util.stringutils import shortstr if TYPE_CHECKING: from synapse.server import HomeServer @@ -111,6 +115,14 @@ stale_forward_extremities_counter = Histogram( buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) +# The number of forward extremities for each new event. 
+msc4242_state_dag_forward_extremities_counter = Histogram( + "synapse_storage_msc4242_state_dag_forward_extremities_persisted", + "Number of forward extremities for each new event in the state DAG", + labelnames=[SERVER_NAME_LABEL], + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + state_resolutions_during_persistence = Counter( "synapse_storage_events_state_resolutions_during_persistence", "Number of times we had to do state res to calculate new current state", @@ -529,7 +541,15 @@ class EventsPersistenceStorageController: Returns: map from (type, state_key) to event id for the current state in the room """ - latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id) + room_version = await self.main_store.get_room_version_id(room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + if room_version_obj.msc4242_state_dags: + latest_event_ids = await self.main_store.get_state_dag_extremities(room_id) + else: + latest_event_ids = await self.main_store.get_latest_event_ids_in_room( + room_id + ) + state_groups = set( ( await self.main_store._get_state_group_for_events(latest_event_ids) @@ -551,7 +571,6 @@ class EventsPersistenceStorageController: # Avoid a circular import. from synapse.state import StateResolutionStore - room_version = await self.main_store.get_room_version_id(room_id) res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, @@ -615,28 +634,52 @@ class EventsPersistenceStorageController: for x in range(0, len(events_and_contexts), 100) ] + # Get the room version for the first event. This room version is the same for all events + # as events_and_contexts is all for one room. + assert len(events_and_contexts) > 0 + room_version = events_and_contexts[0][0].room_version + for chunk in chunks: # We can't easily parallelize these since different chunks # might contain the same event. 
:( new_forward_extremities = None state_delta_for_room = None + new_state_dag_extrems = None if not backfilled: - with Measure( - self._clock, - name="_calculate_state_and_extrem", - server_name=self.server_name, - ): - # Work out the new "current state" for the room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - ( - new_forward_extremities, - state_delta_for_room, - ) = await self._calculate_new_forward_extremities_and_state_delta( - room_id, chunk - ) + if room_version.msc4242_state_dags: + with Measure( + self._clock, + name="_process_state_dag_forward_extremities_and_state_delta", + server_name=self.server_name, + ): + assert all( + isinstance(ev, FrozenEventVMSC4242) for ev, _ in chunk + ) + ( + new_forward_extremities, # for prev_events + state_delta_for_room, # for state groups + new_state_dag_extrems, # for prev_state_events + ) = await self._process_state_dag_forward_extremities_and_state_delta( + room_id, + cast(list[tuple[FrozenEventVMSC4242, EventContext]], chunk), + ) + else: + with Measure( + self._clock, + name="_calculate_state_and_extrem", + server_name=self.server_name, + ): + # Work out the new "current state" for the room. + # We do this by working out what the new extremities are and then + # calculating the state from that. 
+ ( + new_forward_extremities, + state_delta_for_room, + ) = await self._calculate_new_forward_extremities_and_state_delta( + room_id, chunk + ) with Measure( self._clock, @@ -666,6 +709,7 @@ class EventsPersistenceStorageController: use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, new_event_links=new_event_links, + new_state_dag_forward_extremities=new_state_dag_extrems, ) return replaced_events @@ -793,6 +837,216 @@ class EventsPersistenceStorageController: return (new_forward_extremities, delta) + async def _process_state_dag_forward_extremities_and_state_delta( + self, + room_id: str, + event_contexts: list[tuple[FrozenEventVMSC4242, EventContext]], + ) -> tuple[set[str] | None, DeltaState | None, set[str] | None]: + """Process the forwards extremities for state DAG rooms. + Returns: + - the new room dag extremities which should be written when these events are persisted. + - the state delta for the room, if applicable. + - the new state dag extremities which should be written when these events are persisted. + + NB: this does not write them because if it did, new events may see them _before_ the events + get persisted, causing failures in retrieving state groups. + """ + # Update forward extremities + # ...for the state DAG + existing_state_dag_fwd_extrems = ( + await self.main_store.get_state_dag_extremities(room_id) + ) + new_state_dag_fwd_extrems = await self._calculate_new_state_dag_extremities( + room_id, + existing_state_dag_fwd_extrems, + event_contexts, + ) + # ...and the room DAG + existing_room_dag_fwd_extrems = ( + await self.main_store.get_latest_event_ids_in_room(room_id) + ) + new_room_dag_fwd_extrems = await self._calculate_new_extremities( + room_id, + cast(list[EventPersistencePair], event_contexts), + existing_room_dag_fwd_extrems, + ) + assert new_room_dag_fwd_extrems, ( + f"No room dag forward extremities left in room {room_id}!" 
+ ) + + # See if we need to calculate a state delta + if new_state_dag_fwd_extrems == existing_state_dag_fwd_extrems: + # No change in state extremities, so no new state to calculate + return new_room_dag_fwd_extrems, None, new_state_dag_fwd_extrems + + with Measure( + self._clock, + name="persist_events.state_dag.get_new_state_after_events", + server_name=self.server_name, + ): + (current_state, delta_ids, _) = await self._get_new_state_after_events( + room_id, + cast(list[EventPersistencePair], event_contexts), + existing_state_dag_fwd_extrems, + new_state_dag_fwd_extrems, + # do not prune forward extremities in the state DAG + # else we lose eventual delivery + should_prune=False, + ) + + # Following logic cargoculted from _calculate_new_forward_extremities_and_state_delta + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + delta = None + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + delta = DeltaState([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, + name="persist_events.calculate_state_delta", + server_name=self.server_name, + ): + delta = await self._calculate_state_delta(room_id, current_state) + + if delta: + # If we have a change of state then lets check + # whether we're actually still a member of the room, + # or if our last user left. If we're no longer in + # the room then we delete the current state and + # extremities. 
+ is_still_joined = await self._is_server_still_joined( + room_id, + cast(list[EventPersistencePair], event_contexts), + delta, + ) + if not is_still_joined: + logger.info("Server no longer in room %s", room_id) + delta.no_longer_in_room = True + + return new_room_dag_fwd_extrems, delta, new_state_dag_fwd_extrems + + async def _calculate_new_state_dag_extremities( + self, + room_id: str, + existing_fwd_extrems: frozenset[str], + event_contexts: list[tuple[FrozenEventVMSC4242, EventContext]], + ) -> set[str]: + """Calculate the new state dag forward extremities. Modifies existing_fwd_extrems. + + Assumes that event_contexts are only state events which should be in the state DAG. + + Raises: + SynapseError: if the new events include unknown prev_state_events + AssertionError: if there are no state DAG forward extremities remaining in the room + """ + # Events are always processed in causal order without any gaps in the DAG + # (prev_state_events are always known), guaranteeing that processed events have a path to the + # create event. This is an emergent property of state DAGs as asserting that there is a path + # to the create event every time we insert an event would be prohibitively expensive. + # This is similar to how doubly-linked lists can potentially not refer to previous items correctly + # without verifying the list's integrity, but doing it on every insert is too expensive. + + # filter out events which don't belong in the state dag. + new_state_events_contexts = [ + (e, ctx) for e, ctx in event_contexts if event_exists_in_state_dag(e) + ] + if len(new_state_events_contexts) == 0: + # if there are no state events being persisted, then the fwd extremities of the state dag + # do not change. + return set(existing_fwd_extrems) + + # This logic is very similar to _calculate_new_extremities with a few key differences: + # - We do not "Remove any events which are prev_events of any existing events." 
because the + # state DAG mandates that events are processed in causal order, so there MUST NOT be any + # existing, processed events which have the to-be-persisted events as prev_state_events. + # - We don't care if they are an "outlier" in the main room dag, so long as they AREN'T + # an outlier on the state dag, which this function checks, so we don't check outlier-ness. + # - We allow *soft-failed* events to become forward extremities, as per the MSC. We do not + # allow *rejected* events to become forward extremities though. + + rejected_events = [ev for ev, ctx in new_state_events_contexts if ctx.rejected] + new_state_events = [ + ev for ev, ctx in new_state_events_contexts if not ctx.rejected + ] + # We want to check that we are not missing any prev_state_events. + # To do this, we include rejected events in this check because other events may point to them. + # If we didn't include them, we might incorrectly say we are missing events when we are not. + all_new_state_events = set(rejected_events + new_state_events) + + # First, verify that we know all prev_state_events. If we fail this check then we don't have + # a complete DAG and that is bad, so bail out. + + # Start with them all missing. + missing_prev_state_events = { + e_id for event in all_new_state_events for e_id in event.prev_state_events + } + + # remove prev events which appear in all_events + missing_prev_state_events.difference_update( + event.event_id for event in all_new_state_events + ) + # the rest of these events should be present in the DB. Some of them may be forward extremities, + # some may not be, that's ok. 
+ seen_events = await self.main_store.have_seen_events( + room_id, + missing_prev_state_events, + ) + missing_prev_state_events.difference_update(seen_events) + + if len(missing_prev_state_events) > 0: + logger.error( + "_calculate_new_state_dag_extremities: missing the following prev_state_events in room %s : %s", + room_id, + missing_prev_state_events, + ) + logger.error( + "_calculate_new_state_dag_extremities: was handling %s", + shortstr([ev.event_id for ev in all_new_state_events]), + ) + raise SynapseError( + code=500, + msg=f"missing {len(missing_prev_state_events)} prev_state_events in room {room_id}", + ) + + # Now calculate the forward extremities. + + # start with the existing forward extremities + result = set(existing_fwd_extrems) + + # add all the new events to the list + result.update(event.event_id for event in new_state_events) + + # Now remove all events which are prev_state_events of any of the new events + result.difference_update( + e_id for event in new_state_events for e_id in event.prev_state_events + ) + + # Finally handle the case where the new events have rejected/soft-failed `prev_state_events`. + # If they do we need to remove them and their `prev_state_events`, + # otherwise we end up with dangling extremities. + # Specifically, this handles the case where (F=fwd extrem, SF=soft-failed, N=new event) + # F <-- SF <-- SF <-- N + # where we want to remove F as a forward extremity and replace with N. + existing_prevs = await self.persist_events_store._get_prevs_before_rejected( + (e_id for event in new_state_events for e_id in event.prev_state_events), + include_soft_failed=False, + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + if result != existing_fwd_extrems: + msc4242_state_dag_forward_extremities_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe(len(result)) + + # There should always be at least one forward extremity. 
+ assert result, f"No state dag forward extremities left in room {room_id}!" + return result + async def _calculate_new_extremities( self, room_id: str, @@ -859,6 +1113,7 @@ class EventsPersistenceStorageController: events_context: list[EventPersistencePair], old_latest_event_ids: AbstractSet[str], new_latest_event_ids: set[str], + should_prune: bool = True, ) -> tuple[StateMap[str] | None, StateMap[str] | None, set[str]]: """Calculate the current state dict after adding some new events to a room @@ -873,9 +1128,15 @@ class EventsPersistenceStorageController: old_latest_event_ids: the old forward extremities for the room. - new_latest_event_ids : + new_latest_event_ids: the new forward extremities for the room. + should_prune: + if true, attempt to prune the forward extremities. + Pruning means we will not communicate some new events to other servers, + which can compromise eventual delivery, so graphs which are fully synchronised + e.g. state DAGs should not prune. + Returns: Returns a tuple of two state maps and a set of new forward extremities. @@ -1015,7 +1276,7 @@ class EventsPersistenceStorageController: # If the returned state matches the state group of one of the new # forward extremities then we check if we are able to prune some state # extremities. 
- if res.state_group and res.state_group in new_state_groups: + if should_prune and res.state_group and res.state_group in new_state_groups: new_latest_event_ids = await self._prune_extremities( room_id, new_latest_event_ids, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index cc7083b605..415926eb0a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1493,6 +1493,15 @@ class EventFederationWorkerStore( ) return frozenset(event_ids) + async def get_state_dag_extremities(self, room_id: str) -> frozenset[str]: + event_ids = await self.db_pool.simple_select_onecol( + table="msc4242_state_dag_forward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + desc="get_state_dag_extremities", + ) + return frozenset(event_ids) + async def get_min_depth(self, room_id: str) -> int | None: """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6d3bc15777..12c918eca6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -48,7 +48,9 @@ from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions from synapse.events import ( EventBase, + FrozenEventVMSC4242, StrippedStateEvent, + event_exists_in_state_dag, is_creator, relation_from_event, ) @@ -295,6 +297,7 @@ class PersistEventsStore: new_event_links: dict[str, NewEventChainLinks], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, + new_state_dag_forward_extremities: set[str] | None = None, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -315,6 +318,8 @@ class PersistEventsStore: from being updated by these events. 
This should be set to True for backfilled events because backfilled events in the past do not affect the current local state. + new_state_dag_forward_extremities: A set of event IDs that are the new forward + extremities for the state DAG for this room. MSC4242 only. Returns: Resolves when the events have been persisted @@ -379,6 +384,7 @@ class PersistEventsStore: new_forward_extremities=new_forward_extremities, new_event_links=new_event_links, sliding_sync_table_changes=sliding_sync_table_changes, + new_state_dag_forward_extremities=new_state_dag_forward_extremities, ) persist_event_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc( len(events_and_contexts) @@ -962,8 +968,10 @@ class PersistEventsStore: return results - async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> set[str]: - """Get soft-failed ancestors to remove from the extremities. + async def _get_prevs_before_rejected( + self, event_ids: Iterable[str], include_soft_failed: bool = True + ) -> set[str]: + """Get soft-failed/rejected ancestors to remove from the extremities. Given a set of events, find all those that have been soft-failed or rejected. Returns those soft failed/rejected events and their prev @@ -976,7 +984,8 @@ class PersistEventsStore: Args: event_ids: Events to find prev events for. Note that these must have already been persisted. - + include_soft_failed: Soft-failed events are included in the search. If false, only + rejected events are included. Returns: The previous events. 
""" @@ -1016,7 +1025,7 @@ class PersistEventsStore: continue soft_failed = db_to_json(metadata).get("soft_failed") - if soft_failed or rejected: + if (include_soft_failed and soft_failed) or rejected: to_recursively_check.append(prev_event_id) existing_prevs.add(prev_event_id) @@ -1038,6 +1047,7 @@ class PersistEventsStore: new_forward_extremities: set[str] | None, new_event_links: dict[str, NewEventChainLinks], sliding_sync_table_changes: SlidingSyncTableChanges | None, + new_state_dag_forward_extremities: set[str] | None = None, ) -> None: """Insert some number of room events into the necessary database tables. @@ -1146,6 +1156,11 @@ class PersistEventsStore: max_stream_order=max_stream_order, ) + if new_state_dag_forward_extremities: + self._set_state_dag_extremities_txn( + txn, room_id, new_state_dag_forward_extremities + ) + self._persist_transaction_ids_txn(txn, events_and_contexts) # Insert into event_to_state_groups. @@ -2475,6 +2490,29 @@ class PersistEventsStore: ], ) + def _set_state_dag_extremities_txn( + self, txn: LoggingTransaction, room_id: str, new_extrems: Collection[str] + ) -> None: + self.db_pool.simple_delete_txn( + txn, + table="msc4242_state_dag_forward_extremities", + keyvalues={ + "room_id": room_id, + }, + ) + self.db_pool.simple_insert_many_txn( + txn, + table="msc4242_state_dag_forward_extremities", + keys=("room_id", "event_id"), + values=[ + ( + room_id, + event_id, + ) + for event_id in new_extrems + ], + ) + @classmethod def _filter_events_and_contexts_for_duplicates( cls, events_and_contexts: list[EventPersistencePair] @@ -2859,6 +2897,12 @@ class PersistEventsStore: self._handle_event_relations(txn, event) + if event.room_version.msc4242_state_dags and event_exists_in_state_dag( + event + ): + assert isinstance(event, FrozenEventVMSC4242) + self._store_state_dag_edges(txn, event) + # Store the labels for this event. 
labels = event.content.get(EventContentFields.LABELS) if labels: @@ -2935,6 +2979,36 @@ class PersistEventsStore: txn.async_call_after(external_prefill) txn.call_after(local_prefill) + def _store_state_dag_edges( + self, txn: LoggingTransaction, event: FrozenEventVMSC4242 + ) -> None: + # the create event has no edge but we still need to persist it as get_state_dag just + # yanks all rows in this table. It's a bit gross to store NULL as the prev_state_event_id + # though. + if len(event.prev_state_events) == 0 and event.type == EventTypes.Create: + self.db_pool.simple_insert_txn( + txn, + table="msc4242_state_dag_edges", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "prev_state_event_id": None, + }, + ) + return + assert len(event.prev_state_events) > 0 + self.db_pool.simple_upsert_many_txn( + txn, + table="msc4242_state_dag_edges", + key_names=["room_id", "event_id", "prev_state_event_id"], + key_values=[ + (event.room_id, event.event_id, prev_state_event) + for prev_state_event in event.prev_state_events + ], + value_names=(), + value_values=(), + ) + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: assert event.redacts is not None self.db_pool.simple_upsert_txn( @@ -3456,7 +3530,13 @@ class PersistEventsStore: """ state_groups = {} for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): + # state dag rooms allow outliers to have state, as `/get_missing_events` state dag events are nominally + # outliers (not present in the timeline) but do need state persisted so we can calculate + # what the auth_events are for the event. 
+ if ( + not event.room_version.msc4242_state_dags + and event.internal_metadata.is_outlier() + ): # double-check that we don't have any events that claim to be outliers # *and* have partial state (which is meaningless: we should have no # state at all for an outlier) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index d55ea5cf7d..fe8079c201 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -71,6 +71,10 @@ purge_room_tables_with_room_id_column = ( # so must be deleted first. "sliding_sync_joined_rooms", "sliding_sync_membership_snapshots", + # Note: msc4242_state_dag_forward_extremities/edges have a foreign key to the `events` table + # so must be deleted first. + "msc4242_state_dag_forward_extremities", + "msc4242_state_dag_edges", "events", "federation_inbound_events_staging", "receipts_graph", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index cfde107b48..87523e6f18 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -38,7 +38,7 @@ import attr from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion -from synapse.events import EventBase +from synapse.events import EventBase, EventMetadata from synapse.events.snapshot import EventContext from synapse.logging.opentracing import trace from synapse.replication.tcp.streams import UnPartialStatedEventStream @@ -78,16 +78,6 @@ class Sentinel: ROOM_UNKNOWN_SENTINEL = Sentinel() -@attr.s(slots=True, frozen=True, auto_attribs=True) -class EventMetadata: - """Returned by `get_metadata_for_events`""" - - room_id: str - event_type: str - state_key: str | None - rejection_reason: str | None - - def _retrieve_and_check_room_version(room_id: 
str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index e3095a9d0d..1afc6d0b2a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -174,6 +174,7 @@ Changes in SCHEMA_VERSION = 93 Changes in SCHEMA_VERSION = 94 - Add `recheck` column (boolean, default true) to the `redactions` table. + - MSC4242: Add state DAG tables. """ diff --git a/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql b/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql new file mode 100644 index 0000000000..bc5c738ba5 --- /dev/null +++ b/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql @@ -0,0 +1,38 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2026 Element Creations, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE TABLE IF NOT EXISTS msc4242_state_dag_forward_extremities( + -- we always expect the room to exist. If it gets removed, delete fwd extremities. + room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE, + event_id TEXT NOT NULL REFERENCES events(event_id) ON DELETE CASCADE, + -- it doesn't make sense to reference the same event multiple times, and this uniqueness + -- index is also used to delete events once they are no longer forward extremities. + UNIQUE (event_id) +); +-- When creating events, we want to select all forward extremities for a room which this index helps with. 
+CREATE INDEX msc4242_state_dag_room ON msc4242_state_dag_forward_extremities(room_id); + + +CREATE TABLE IF NOT EXISTS msc4242_state_dag_edges( + -- Deleting the room deletes the state DAG. + room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE, + -- the event IDs being referenced must exist (hence REFERENCES) and we do not want to accidentally delete + -- the event and create a hole in the state DAG. It is not possible for a state + -- DAG room to function with an holey DAG, so these events _cannot_ be purged. To purge them, the + -- entire room would need to be deleted. + event_id TEXT NOT NULL REFERENCES events(event_id), + -- one of the `prev_state_events` for this event ID. We must have it since we must have the entire state DAG. + -- can be NULL for the create event. + prev_state_event_id TEXT REFERENCES events(event_id) +); +CREATE UNIQUE INDEX msc4242_state_dag_edges_key ON msc4242_state_dag_edges(room_id, event_id, prev_state_event_id); diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 29432bdd56..fe0ca04420 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -60,6 +60,9 @@ class EventInternalMetadata: device_id: str """The device ID of the user who sent this event, if any.""" + # MSC4242 state dags + calculated_auth_event_ids: list[str] + def get_dict(self) -> JsonDict: ... def is_outlier(self) -> bool: ... def copy(self) -> "EventInternalMetadata": ... 
diff --git a/synapse/synapse_rust/room_versions.pyi b/synapse/synapse_rust/room_versions.pyi index 909e3a1c26..9bbb538f18 100644 --- a/synapse/synapse_rust/room_versions.pyi +++ b/synapse/synapse_rust/room_versions.pyi @@ -31,6 +31,8 @@ class EventFormatVersions: """MSC1884-style format: introduced for room v4""" ROOM_V11_HYDRA_PLUS: int """MSC4291 room IDs as hashes: introduced for room HydraV11""" + ROOM_VMSC4242: int + """MSC4242 state DAGs: adds prev_state_events, removes auth_events""" KNOWN_EVENT_FORMAT_VERSIONS: frozenset[int] @@ -113,6 +115,14 @@ class RoomVersion: rather than in codepoints. If true, this room version uses stricter event size validation.""" + msc4242_state_dags: bool + """MSC4242: State DAGs. Creates events with prev_state_events instead of auth_events and derives + state from it. Events are always processed in causal order without any gaps in the DAG + (prev_state_events are always known), guaranteeing that processed events have a path to the + create event. This is an emergent property of state DAGs as asserting that there is a path + to the create event every time we insert an event would be prohibitively expensive. + This is similar to how doubly-linked lists can potentially not refer to previous items correctly + without verifying the list's integrity, but doing it on every insert is too expensive.""" class RoomVersions: V1: RoomVersion @@ -132,6 +142,7 @@ class RoomVersions: MSC3757v11: RoomVersion HydraV11: RoomVersion V12: RoomVersion + MSC4242v12: RoomVersion class KnownRoomVersionsMapping(Mapping[str, RoomVersion]): def add_room_version(self, room_version: RoomVersion) -> None: ... 
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 275e5dfa1d..a40e0b0680 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -432,7 +432,14 @@ class UnstableGetExtremitiesTests(unittest.FederatingHomeserverTestCase): self.assertEqual(channel.json_body["error"], "Server is banned from room") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) + # FIXME: Exclude MSC4242 room versions whilst it lacks federation support + @parameterized.expand( + [ + (k,) + for k in KNOWN_ROOM_VERSIONS.keys() + if k != RoomVersions.MSC4242v12.identifier + ] + ) @override_config( {"use_frozen_dicts": True, "experimental_features": {"msc4370_enabled": True}} ) @@ -440,7 +447,14 @@ class UnstableGetExtremitiesTests(unittest.FederatingHomeserverTestCase): """Test GET /extremities with USE_FROZEN_DICTS=True""" self._test_get_extremities_common(room_version) - @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) + # FIXME: Exclude MSC4242 room versions whilst it lacks federation support + @parameterized.expand( + [ + (k,) + for k in KNOWN_ROOM_VERSIONS.keys() + if k != RoomVersions.MSC4242v12.identifier + ] + ) @override_config( {"use_frozen_dicts": False, "experimental_features": {"msc4370_enabled": True}} ) @@ -573,12 +587,18 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): @override_config({"use_frozen_dicts": True}) def test_send_join_with_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=True""" + if room_version == RoomVersions.MSC4242v12.identifier: + # TODO: This room version doesn't work over federation in this PR. 
+ return self._test_send_join_common(room_version) @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) @override_config({"use_frozen_dicts": False}) def test_send_join_without_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=False""" + if room_version == RoomVersions.MSC4242v12.identifier: + # TODO: This room version doesn't work over federation in this PR. + return self._test_send_join_common(room_version) def test_send_join_partial_state(self) -> None: diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 3d856b9346..1aaa86e2e8 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -1121,7 +1121,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): return {"pdus": [missing_event.get_pdu_json()]} async def get_room_state_ids( - destination: str, room_id: str, event_id: str + destination: str, + room_id: str, + event_id: str, ) -> JsonDict: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, missing_event.event_id) @@ -1131,7 +1133,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): } async def get_room_state( - room_version: RoomVersion, destination: str, room_id: str, event_id: str + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, ) -> StateRequestResponse: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, missing_event.event_id) diff --git a/tests/storage/test_msc4242_state_dag.py b/tests/storage/test_msc4242_state_dag.py new file mode 100644 index 0000000000..8775e5c8eb --- /dev/null +++ b/tests/storage/test_msc4242_state_dag.py @@ -0,0 +1,371 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. 
+# +# Copyright (C) 2026 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +from typing import Iterable +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEventVMSC4242, make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.util.clock import Clock + +from tests.unittest import HomeserverTestCase, override_config + + +class MSC4242StateDagsTests(HomeserverTestCase): + user_id = "@user1:server" + servlets = [room.register_servlets] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + hs = self.setup_test_homeserver("server") + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.room_id = self.helper.create_room_as( + self.user_id, + room_version=RoomVersions.MSC4242v12.identifier, + ) + + self.store = hs.get_datastores().main + self._storage_controllers = self.hs.get_storage_controllers() + + def _get_prev_state_events(self, event_id: str) -> list[str]: + ev = self.helper.get_event(self.room_id, event_id) + prev_state_events: list[str] | None = ev.get("prev_state_events", None) + assert prev_state_events is not None + return prev_state_events + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_forward_extremities_are_calculated(self) -> None: + """ + Check that forward extremities are set as prev_state_events and that they don't 
change + for non-state events. + """ + # they don't change for messages + first_event_id = self.helper.send(self.room_id, body="test1")["event_id"] + first_prev_state_events = self._get_prev_state_events(first_event_id) + assert len(first_prev_state_events) == 1 + second_id = self.helper.send(self.room_id, body="test2")["event_id"] + second_prev_state_events = self._get_prev_state_events(second_id) + assert len(second_prev_state_events) == 1 + self.assertIncludes( + set(first_prev_state_events), set(second_prev_state_events), exact=True + ) + + # send an auth event, which should change the prev_state_events on *subsequent* events + join_rule_state_event_id = self.helper.send_state( + self.room_id, + EventTypes.JoinRules, + { + "join_rule": "knock", + }, + tok="nope", + )["event_id"] + join_rule_prev_state_event_ids = self._get_prev_state_events( + join_rule_state_event_id + ) + self.assertIncludes( + set(second_prev_state_events), + set(join_rule_prev_state_event_ids), + exact=True, + ) + + # prev_state_events should always point to the join rule now + third_event_id = self.helper.send(self.room_id, body="test3")["event_id"] + third_prev_state_events = self._get_prev_state_events(third_event_id) + self.assertIncludes( + set(third_prev_state_events), {join_rule_state_event_id}, exact=True + ) + # and non-auth state should also update prev_state_events + name_state_event_id = self.helper.send_state( + self.room_id, + EventTypes.Name, + { + "name": "State DAGs!", + }, + tok="nope", + )["event_id"] + name_prev_state_event_ids = self._get_prev_state_events(name_state_event_id) + self.assertIncludes( + set(name_prev_state_event_ids), {join_rule_state_event_id}, exact=True + ) + fourth_event_id = self.helper.send(self.room_id, body="test4")["event_id"] + fourth_prev_state_events = self._get_prev_state_events(fourth_event_id) + self.assertIncludes( + set(fourth_prev_state_events), {name_state_event_id}, exact=True + ) + + +class 
MSC4242EventPersistenceStateDagsStoreTestCase(HomeserverTestCase): + servlets = [ + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + persistence = hs.get_storage_controllers().persistence + assert persistence is not None + self.persistence = persistence + self.room_id = "!foo:bar" + self.seen_event_ids: set[str] = set() + self.persistence.main_store = Mock(spec=["have_seen_events"]) + self.persistence.main_store.have_seen_events.side_effect = ( + self._have_seen_events + ) + self.rejected_event_ids_and_their_prevs: set[str] = set() + self.persistence.persist_events_store = Mock( + spec=["_get_prevs_before_rejected"] + ) + self.persistence.persist_events_store._get_prevs_before_rejected.side_effect = ( + self._get_prevs_before_rejected + ) + + async def _have_seen_events( + self, room_id: str, event_ids: Iterable[str] + ) -> set[str]: + unknown_events = set(event_ids) + return self.seen_event_ids.intersection(unknown_events) + + async def _get_prevs_before_rejected( + self, event_ids: Iterable[str], include_soft_failed: bool = True + ) -> set[str]: + return self.rejected_event_ids_and_their_prevs + + def _make_event( + self, + id: str, + prev_state_events: list[str], + rejected: bool = False, + ) -> tuple[FrozenEventVMSC4242, EventContext]: + ev = make_event_from_dict( + { + "prev_state_events": prev_state_events, + "content": { + "membership": "join", + }, + "sender": "@unimportant:info", + "state_key": "@unimportant:info", + "type": "m.room.member", + "room_id": self.room_id, + }, + room_version=RoomVersions.MSC4242v12, + ) + assert isinstance(ev, FrozenEventVMSC4242) + ev._event_id = id + ctx = Mock() + ctx.rejected = rejected + return ev, ctx + + def _test( + self, + current_fwds: list[str], + new_events: list[tuple[FrozenEventVMSC4242, EventContext]], + want_new_extrems: set[str], + want_raises: bool = False, + ) -> None: + """ + Tests the logic of 
_calculate_new_state_dag_extremities. + + Tests that the new extremities calculated as a result of processing current_fwds and new_events + matches want_new_extrems or raises if want_raises is True. + """ + coroutine = self.persistence._calculate_new_state_dag_extremities( + self.room_id, + frozenset(current_fwds), + new_events, + ) + if want_raises: + f = self.get_failure(coroutine, SynapseError) + assert f is not None + return + + new_extrems = set(self.get_success(coroutine)) + self.assertIncludes( + new_extrems, + set(want_new_extrems), + exact=True, + message=f"want_new_extrems={want_new_extrems} got={new_extrems}", + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_simple(self) -> None: + # Simple linear chain + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$2"]), + self._make_event("$4", ["$3"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_fork(self) -> None: + # Simple fork so we end up with two forward extrems + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$2"]), + self._make_event("$4", ["$2"]), + ], + want_new_extrems={"$3", "$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge(self) -> None: + # Simple fork so we end up with two forward extrems + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$1"]), + self._make_event("$4", ["$2", "$3"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_fork_on_existing(self) 
-> None: + # Fork where we are adding to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"]), # append to the forward extrem + self._make_event("$5", ["$1"]), # append to the root + ], + want_new_extrems={"$4", "$5"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge_on_existing(self) -> None: + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3", "$2"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge_on_not_current(self) -> None: + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$1", "$2"]), + ], + want_new_extrems={"$3", "$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_append_with_rejected(self) -> None: + # rejected events cannot be forward extremities + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + ], + want_new_extrems={"$3"}, + ) + + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + self._make_event("$5", ["$4"], rejected=True), + ], + want_new_extrems={"$3"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_append_with_rejected_in_chain( + self, + ) -> None: + # rejected events cannot be forward extremities, but events that come after them can. + # this shouldn't cause multiple forward extremities. 
+ self.seen_event_ids = {"$1", "$2", "$3"} + self.rejected_event_ids_and_their_prevs = {"$4", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + self._make_event("$5", ["$4"]), + ], + want_new_extrems={"$5"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_missing_prevs_raises(self) -> None: + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$unknown"]), + self._make_event("$4", ["$3"]), + ], + want_new_extrems={"$4"}, + want_raises=True, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_complex(self) -> None: + """ + 1 + | \ + 2 4 + | + 3 + + Exists already, then becomes... + + 1______ + | \\ | + 2 4 5R + | | | + 3--7 6R + | \\ / \ + 10R 8 9 + + """ + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3", "$4"} + self.rejected_event_ids_and_their_prevs = {"$1", "$5", "$6", "$3", "$10"} + self._test( + current_fwds=["$3", "$4"], + new_events=[ + self._make_event("$5", ["$1"], rejected=True), + self._make_event("$6", ["$5"], rejected=True), + self._make_event("$7", ["$4", "$3"]), + self._make_event("$8", ["$6", "$7"]), + self._make_event("$9", ["$6"]), + self._make_event("$10", ["$3"], rejected=True), + ], + want_new_extrems={"$8", "$9"}, + ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 67a5c31c44..c346245706 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -232,6 +232,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): prev_event_ids: list[str], auth_event_ids: list[str] | None, depth: int | None = None, + prev_state_events: list[str] | None = None, ) -> EventBase: built_event = await self._base_builder.build( prev_event_ids=prev_event_ids, 
auth_event_ids=auth_event_ids diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 934a2fd307..9258f0d4dc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -20,15 +20,16 @@ # import unittest +from collections import namedtuple from typing import Any, Collection, Iterable from parameterized import parameterized from synapse import event_auth -from synapse.api.constants import EventContentFields +from synapse.api.constants import EventContentFields, RejectedReason from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.events import EventBase, make_event_from_dict +from synapse.events import EventBase, event_exists_in_state_dag, make_event_from_dict from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, get_domain_from_id @@ -374,6 +375,195 @@ class EventAuthTestCase(unittest.TestCase): auth_events, ) + def test_msc4242_state_dag_rules(self) -> None: + """Tests additional rules in place for state DAG rooms. + + 1. m.room.create => if it has any prev_state_events, reject. + 2. Considering the event's prev_state_events: + i. If there are entries which do not belong in the same room, reject. + ii. If there are entries which do not have a state_key, reject. + iii. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. 
+ """ + creator = "@creator:example.com" + room_version = RoomVersions.MSC4242v12 + + create_event = make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator}, + "prev_events": [], + "prev_state_events": [], + }, + room_version, + ) + create_event_2 = make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator, "another": "room"}, + "prev_events": [], + "prev_state_events": [], + }, + room_version, + ) + room_id = create_event.room_id + another_room_id = create_event_2.room_id + join_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [create_event.event_id], + "prev_state_events": [create_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id]}, + ) + event_in_another_room = make_event_from_dict( + { + "room_id": another_room_id, + "type": "m.room.join_rules", + "sender": creator, + "state_key": "", + "content": {"join_rule": "public"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + ) + msg_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.message", + "sender": creator, + "content": {"msgtype": "m.text", "body": "I am a message"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + ) + rejected_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "REJECTED"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + 
{"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + rejected_reason=RejectedReason.AUTH_ERROR, + ) + RejectingTestCase = namedtuple( + "RejectingTestCase", "name events_in_store test_event" + ) + rejecting_test_cases = [ + RejectingTestCase( + name="create event has prev_state_events", + events_in_store=[], + test_event=make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator}, + "prev_events": [], + "prev_state_events": [create_event.event_id], + }, + room_version, + {}, + ), + ), + RejectingTestCase( + name="prev_state_event belongs in a different room", + events_in_store=[create_event, join_event, event_in_another_room], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev_state_event is in another room"}, + "prev_events": [join_event.event_id], + "prev_state_events": [event_in_another_room.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + RejectingTestCase( + name="prev_state_event is a message event", + events_in_store=[create_event, join_event, msg_event], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev state event is a message"}, + "prev_events": [msg_event.event_id], + "prev_state_events": [msg_event.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + RejectingTestCase( + name="prev_state_event was rejected", + events_in_store=[create_event, join_event, rejected_event], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev state event was rejected"}, + "prev_events": [join_event.event_id], + 
"prev_state_events": [rejected_event.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + ] + + for test_case in rejecting_test_cases: + event_store = _StubEventSourceStore() + event_store.add_events(test_case.events_in_store) + + with self.assertRaises( + AuthError, msg=f"test case {test_case.name} was not rejected" + ): + get_awaitable_result( + event_auth.check_state_independent_auth_rules( + event_store, test_case.test_event + ) + ) + @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)]) def test_notifications( self, room_version: RoomVersion, allow_modification: bool @@ -769,6 +959,105 @@ class EventAuthTestCase(unittest.TestCase): with self.assertRaises(SynapseError): event_auth._check_power_levels(event.room_version, event, {}) + def test_event_exists_in_state_dag(self) -> None: + events_that_exist_in_state_dag = [ + { + "type": "m.room.create", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.join_rules", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.power_levels", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.server_acl", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "@alice:somewhere", + "content": {}, + }, + { + "type": "m.room.third_party_invite", + "state_key": "flibble", + "content": {}, + }, + { + "type": "m.room.create", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.join_rules", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.power_levels", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.name", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "hello_world", + "content": {}, + }, + ] + events_that_dont_exist_in_state_dag = [ + { + "type": "m.room.message", + "content": {}, + }, + { + "type": 
"m.room.create", + "content": {}, + }, + { + "type": "m.room.join_rules", + "content": {}, + }, + { + "type": "m.room.power_levels", + "content": {}, + }, + ] + + def check_events(events: list[dict], should_exist: bool) -> None: + for ev in events: + base = { + "room_id": TEST_ROOM_ID, + "sender": "@test:test.com", + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + } + base.update(ev) + event = make_event_from_dict(base, RoomVersions.V10) + got = event_exists_in_state_dag(event) + self.assertEqual( + got, should_exist, f"{ev} should_exist={should_exist} but got {got}" + ) + + check_events(events_that_exist_in_state_dag, should_exist=True) + check_events(events_that_dont_exist_in_state_dag, should_exist=False) + # helpers for making events TEST_DOMAIN = "example.com" From bdb1cf7416b46a637b3dae323cb05b4d94fafc82 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 17 Apr 2026 10:57:38 +0100 Subject: [PATCH 2/8] Bump authlib from 1.6.9 to 1.6.11 (#19703) --- poetry.lock | 59 +++++++++++++++++++++++++------------------------- pyproject.toml | 4 ++-- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/poetry.lock b/poetry.lock index ef5c13684d..fd8f1a43c9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -26,15 +26,15 @@ files = [ [[package]] name = "authlib" -version = "1.6.9" +version = "1.6.11" description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." 
optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"oidc\" or extra == \"jwt\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"jwt\" or extra == \"oidc\"" files = [ - {file = "authlib-1.6.9-py2.py3-none-any.whl", hash = "sha256:f08b4c14e08f0861dc18a32357b33fbcfd2ea86cfe3fe149484b4d764c4a0ac3"}, - {file = "authlib-1.6.9.tar.gz", hash = "sha256:d8f2421e7e5980cc1ddb4e32d3f5fa659cfaf60d8eaf3281ebed192e4ab74f04"}, + {file = "authlib-1.6.11-py2.py3-none-any.whl", hash = "sha256:c8687a9a26451c51a34a06fa17bb97cb15bba46a6a626755e2d7f50da8bff3e3"}, + {file = "authlib-1.6.11.tar.gz", hash = "sha256:64db35b9b01aeccb4715a6c9a6613a06f2bd7be2ab9d2eb89edd1dfc7580a38f"}, ] [package.dependencies] @@ -62,7 +62,7 @@ description = "Backport of CPython tarfile module" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.12\" and platform_machine != \"ppc64le\" and platform_machine != \"s390x\"" +markers = "platform_machine != \"ppc64le\" and platform_machine != \"s390x\" and python_version < \"3.12\"" files = [ {file = "backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34"}, {file = "backports_tarfile-1.2.0.tar.gz", hash = "sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991"}, @@ -531,7 +531,7 @@ description = "XML bomb protection for Python stdlib modules" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, @@ -556,7 +556,7 @@ description = "XPath 1.0/2.0/3.0/3.1 parsers and selectors for ElementTree 
and l optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "elementpath-4.8.0-py3-none-any.whl", hash = "sha256:5393191f84969bcf8033b05ec4593ef940e58622ea13cefe60ecefbbf09d58d9"}, {file = "elementpath-4.8.0.tar.gz", hash = "sha256:5822a2560d99e2633d95f78694c7ff9646adaa187db520da200a8e9479dc46ae"}, @@ -606,7 +606,7 @@ description = "Python wrapper for hiredis" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"redis\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"redis\"" files = [ {file = "hiredis-3.3.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:f525734382a47f9828c9d6a1501522c78d5935466d8e2be1a41ba40ca5bb922b"}, {file = "hiredis-3.3.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:6e2e1024f0a021777740cb7c633a0efb2c4a4bc570f508223a8dcbcf79f99ef9"}, @@ -889,7 +889,7 @@ description = "Read metadata from Python packages" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.12\" and platform_machine != \"ppc64le\" and platform_machine != \"s390x\"" +markers = "platform_machine != \"ppc64le\" and platform_machine != \"s390x\" and python_version < \"3.12\"" files = [ {file = "importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151"}, {file = "importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb"}, @@ -930,7 +930,7 @@ description = "Jaeger Python OpenTracing Tracer implementation" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "jaeger-client-4.8.0.tar.gz", hash = "sha256:3157836edab8e2c209bd2d6ae61113db36f7ee399e66b1dcbb715d87ab49bfe0"}, ] @@ -1122,7 +1122,7 @@ 
description = "A strictly RFC 4510 conforming LDAP V3 pure Python client library optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"matrix-synapse-ldap3\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"matrix-synapse-ldap3\"" files = [ {file = "ldap3-2.9.1-py2.py3-none-any.whl", hash = "sha256:5869596fc4948797020d3f03b7939da938778a0f9e2009f7a072ccf92b8e8d70"}, {file = "ldap3-2.9.1.tar.gz", hash = "sha256:f3e7fc4718e3f09dda568b57100095e0ce58633bcabbed8667ce3f8fbaa4229f"}, @@ -1239,7 +1239,7 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"url-preview\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"url-preview\"" files = [ {file = "lxml-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e77dd455b9a16bbd2a5036a63ddbd479c19572af81b624e79ef422f929eef388"}, {file = "lxml-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d444858b9f07cefff6455b983aea9a67f7462ba1f6cbe4a21e8bf6791bf2153"}, @@ -1553,7 +1553,7 @@ description = "An LDAP3 auth provider for Synapse" optional = true python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"matrix-synapse-ldap3\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"matrix-synapse-ldap3\"" files = [ {file = "matrix_synapse_ldap3-0.4.0-py3-none-any.whl", hash = "sha256:bf080037230d2af5fd3639cb87266de65c1cad7a68ea206278c5b4bf9c1a17f3"}, {file = "matrix_synapse_ldap3-0.4.0.tar.gz", hash = "sha256:cff52ba780170de5e6e8af42863d2648ee23f3bf0a9fea6db52372f9fc00be2b"}, @@ -1834,7 +1834,7 @@ description = "OpenTracing API for Python. 
See documentation at http://opentraci optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "opentracing-2.4.0.tar.gz", hash = "sha256:a173117e6ef580d55874734d1fa7ecb6f3655160b8b8974a2a1e98e5ec9c840d"}, ] @@ -2032,7 +2032,7 @@ description = "psycopg2 - Python-PostgreSQL Database Adapter" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"postgres\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"postgres\"" files = [ {file = "psycopg2-2.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:103e857f46bb76908768ead4e2d0ba1d1a130e7b8ed77d3ae91e8b33481813e8"}, {file = "psycopg2-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:210daed32e18f35e3140a1ebe059ac29209dd96468f2f7559aa59f75ee82a5cb"}, @@ -2050,7 +2050,7 @@ description = ".. image:: https://travis-ci.org/chtd/psycopg2cffi.svg?branch=mas optional = true python-versions = "*" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\" and (extra == \"postgres\" or extra == \"all\")" +markers = "platform_python_implementation == \"PyPy\" and (extra == \"all\" or extra == \"postgres\")" files = [ {file = "psycopg2cffi-2.9.0.tar.gz", hash = "sha256:7e272edcd837de3a1d12b62185eb85c45a19feda9e62fa1b120c54f9e8d35c52"}, ] @@ -2066,7 +2066,7 @@ description = "A Simple library to enable psycopg2 compatability" optional = true python-versions = "*" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\" and (extra == \"postgres\" or extra == \"all\")" +markers = "platform_python_implementation == \"PyPy\" and (extra == \"all\" or extra == \"postgres\")" files = [ {file = "psycopg2cffi-compat-1.1.tar.gz", hash = "sha256:d25e921748475522b33d13420aad5c2831c743227dc1f1f2585e0fdb5c914e05"}, ] @@ -2348,7 +2348,7 @@ description = "A development tool to measure, monitor and analyze the memory beh optional = true python-versions 
= ">=3.6" groups = ["main"] -markers = "extra == \"cache-memory\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"cache-memory\"" files = [ {file = "Pympler-1.0.1-py3-none-any.whl", hash = "sha256:d260dda9ae781e1eab6ea15bacb84015849833ba5555f141d2d9b7b7473b307d"}, {file = "Pympler-1.0.1.tar.gz", hash = "sha256:993f1a3599ca3f4fcd7160c7545ad06310c9e12f70174ae7ae8d4e25f6c5d3fa"}, @@ -2480,7 +2480,7 @@ description = "Python implementation of SAML Version 2 Standard" optional = true python-versions = ">=3.9,<4.0" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "pysaml2-7.5.0-py3-none-any.whl", hash = "sha256:bc6627cc344476a83c757f440a73fda1369f13b6fda1b4e16bca63ffbabb5318"}, {file = "pysaml2-7.5.0.tar.gz", hash = "sha256:f36871d4e5ee857c6b85532e942550d2cf90ea4ee943d75eb681044bbc4f54f7"}, @@ -2505,7 +2505,7 @@ description = "Extensions to the standard Python datetime module" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -2533,7 +2533,7 @@ description = "World timezone definitions, modern and historical" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "pytz-2026.1.post1-py2.py3-none-any.whl", hash = "sha256:f2fd16142fda348286a75e1a524be810bb05d444e5a081f37f7affc635035f7a"}, {file = "pytz-2026.1.post1.tar.gz", hash = "sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1"}, @@ -2938,7 +2938,7 @@ 
description = "Python client for Sentry (https://sentry.io)" optional = true python-versions = ">=3.6" groups = ["main"] -markers = "extra == \"sentry\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"sentry\"" files = [ {file = "sentry_sdk-2.57.0-py2.py3-none-any.whl", hash = "sha256:812c8bf5ff3d2f0e89c82f5ce80ab3a6423e102729c4706af7413fd1eb480585"}, {file = "sentry_sdk-2.57.0.tar.gz", hash = "sha256:4be8d1e71c32fb27f79c577a337ac8912137bba4bcbc64a4ec1da4d6d8dc5199"}, @@ -3138,7 +3138,7 @@ description = "Tornado IOLoop Backed Concurrent Futures" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "threadloop-1.0.2-py2-none-any.whl", hash = "sha256:5c90dbefab6ffbdba26afb4829d2a9df8275d13ac7dc58dccb0e279992679599"}, {file = "threadloop-1.0.2.tar.gz", hash = "sha256:8b180aac31013de13c2ad5c834819771992d350267bddb854613ae77ef571944"}, @@ -3154,7 +3154,7 @@ description = "Python bindings for the Apache Thrift RPC system" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"opentracing\"" files = [ {file = "thrift-0.22.0.tar.gz", hash = "sha256:42e8276afbd5f54fe1d364858b6877bc5e5a4a5ed69f6a005b94ca4918fe1466"}, ] @@ -3220,7 +3220,6 @@ files = [ {file = "tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a"}, {file = "tomli-2.4.0.tar.gz", hash = "sha256:aa89c3f6c277dd275d8e243ad24f3b5e701491a860d5121f2cdd399fbb31fc9c"}, ] -markers = {main = "python_version < \"3.14\""} [[package]] name = "tornado" @@ -3229,7 +3228,7 @@ description = "Tornado is a Python web framework and asynchronous networking lib optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"opentracing\" or extra == \"all\"" +markers = "extra == \"all\" or extra == 
\"opentracing\"" files = [ {file = "tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa"}, {file = "tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521"}, @@ -3361,7 +3360,7 @@ description = "non-blocking redis client for python" optional = true python-versions = "*" groups = ["main"] -markers = "extra == \"redis\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"redis\"" files = [ {file = "txredisapi-1.4.11-py3-none-any.whl", hash = "sha256:ac64d7a9342b58edca13ef267d4fa7637c1aa63f8595e066801c1e8b56b22d0b"}, {file = "txredisapi-1.4.11.tar.gz", hash = "sha256:3eb1af99aefdefb59eb877b1dd08861efad60915e30ad5bf3d5bf6c5cedcdbc6"}, @@ -3622,7 +3621,7 @@ description = "An XML Schema validator and decoder" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"saml2\" or extra == \"all\"" +markers = "extra == \"all\" or extra == \"saml2\"" files = [ {file = "xmlschema-2.5.1-py3-none-any.whl", hash = "sha256:ec2b2a15c8896c1fcd14dcee34ca30032b99456c3c43ce793fdb9dca2fb4b869"}, {file = "xmlschema-2.5.1.tar.gz", hash = "sha256:4f7497de6c8b6dc2c28ad7b9ed6e21d186f4afe248a5bea4f54eedab4da44083"}, @@ -3643,7 +3642,7 @@ description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.12\" and platform_machine != \"ppc64le\" and platform_machine != \"s390x\"" +markers = "platform_machine != \"ppc64le\" and platform_machine != \"s390x\" and python_version < \"3.12\"" files = [ {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, @@ -3756,4 +3755,4 @@ url-preview = ["lxml"] [metadata] lock-version = "2.1" 
python-versions = ">=3.10.0,<4.0.0" -content-hash = "ef0540b89c417a69668f551688bd0974256ea7a580044f3954a76bdf0d8fe7c9" +content-hash = "8d994f1fc65664b2a04e1de78df4d1f06f3d99b39f95db16763790f2ee0aff11" diff --git a/pyproject.toml b/pyproject.toml index c156d4f899..9a6e487744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ saml2 = [ "defusedxml>=0.7.1", # via pysaml2 "pytz>=2018.3", # via pysaml2 ] -oidc = ["authlib>=0.15.1"] +oidc = ["authlib>=1.6.11"] url-preview = ["lxml>=4.6.3"] sentry = ["sentry-sdk>=0.7.2"] opentracing = [ @@ -179,7 +179,7 @@ all = [ # saml2 "pysaml2>=4.5.0", # oidc and jwt - "authlib>=0.15.1", + "authlib>=1.6.11", # url-preview "lxml>=4.6.3", # sentry From 647fb5919050a16ac02c7c6903a620d1dd8a1727 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 17 Apr 2026 03:01:23 -0700 Subject: [PATCH 3/8] Add Admin API endpoints to manage user reports (#19657) Adds [Admin API](https://element-hq.github.io/synapse/latest/usage/administration/admin_api/index.html) endpoints to list, fetch and delete user reports from the homeserver. Follows on from #18120, which added the endpoints to report users. --- changelog.d/19657.feature | 2 + synapse/rest/admin/__init__.py | 6 + synapse/rest/admin/user_reports.py | 173 +++++++ synapse/storage/databases/main/room.py | 126 +++++ tests/rest/admin/test_user_reports.py | 644 +++++++++++++++++++++++++ 5 files changed, 951 insertions(+) create mode 100644 changelog.d/19657.feature create mode 100644 synapse/rest/admin/user_reports.py create mode 100644 tests/rest/admin/test_user_reports.py diff --git a/changelog.d/19657.feature b/changelog.d/19657.feature new file mode 100644 index 0000000000..f87ef9fed8 --- /dev/null +++ b/changelog.d/19657.feature @@ -0,0 +1,2 @@ +Adds [Admin API](https://element-hq.github.io/synapse/latest/usage/administration/admin_api/index.html) endpoints to +list, fetch and delete user reports. 
\ No newline at end of file diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index b209404cd1..0774b6ed40 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -96,6 +96,10 @@ from synapse.rest.admin.statistics import ( LargestRoomsStatistics, UserMediaStatisticsRestServlet, ) +from synapse.rest.admin.user_reports import ( + UserReportDetailRestServlet, + UserReportsRestServlet, +) from synapse.rest.admin.username_available import UsernameAvailableRestServlet from synapse.rest.admin.users import ( AccountDataRestServlet, @@ -312,6 +316,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LargestRoomsStatistics(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) + UserReportsRestServlet(hs).register(http_server) + UserReportDetailRestServlet(hs).register(http_server) AccountDataRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/user_reports.py b/synapse/rest/admin/user_reports.py new file mode 100644 index 0000000000..119dc86517 --- /dev/null +++ b/synapse/rest/admin/user_reports.py @@ -0,0 +1,173 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# + + +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING + +from synapse.api.constants import Direction +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UserReportsRestServlet(RestServlet): + """ + List all reported users that are known to the homeserver. Results are returned + in a dictionary containing report information. Supports pagination. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/user_reports + returns: + 200 OK with list of reports if success otherwise an error. + + Args: + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `dir` can be used to define the order of results. + The `user_id` query parameter filters by the user ID of the reporter of the target user. + The `target_user_id` query parameter filters by user id of the target user. 
+ Returns: + A list of user reprots and an integer representing the total number of user + reports that exist given this query + """ + + PATTERNS = admin_patterns("/user_reports$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS) + user_id = parse_string(request, "user_id") + target_user_id = parse_string(request, "target_user_id") + + if start < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The start parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The limit parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + user_reports, total = await self._store.get_user_reports_paginate( + start, limit, direction, user_id, target_user_id + ) + ret = {"user_reports": user_reports, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(user_reports) + + return HTTPStatus.OK, ret + + +class UserReportDetailRestServlet(RestServlet): + """ + Get a specific user report that is known to the homeserver. Results are returned + in a dictionary containing report information. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/user_reports/ + returns: + 200 OK with details report if success otherwise an error. + + Args: + The parameter `report_id` is the ID of the user report in the database. 
+ Returns: + JSON blob of information about the user report + """ + + PATTERNS = admin_patterns("/user_reports/(?P[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, request: SynapseRequest, report_id: str + ) -> tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + ret = await self._store.get_user_report(resolved_report_id) + if not ret: + raise NotFoundError("User report not found") + + id, received_ts, target_user_id, user_id, reason = ret + response = { + "id": id, + "received_ts": received_ts, + "target_user_id": target_user_id, + "user_id": user_id, + "reason": reason, + } + + return HTTPStatus.OK, response + + async def on_DELETE( + self, request: SynapseRequest, report_id: str + ) -> tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." 
+ ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if await self._store.delete_user_report(resolved_report_id): + return HTTPStatus.OK, {} + + raise NotFoundError("User report not found") diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 8dbecc16e4..a0c42082f0 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2277,6 +2277,132 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return True + async def get_user_report( + self, report_id: int + ) -> tuple[int, int, str, str, str] | None: + """Retrieve a user report + + Args: + report_id: ID of user report in database + Returns: + JSON dict of information from a user report or None if the + report does not exist. + """ + + return await self.db_pool.simple_select_one( + table="user_reports", + keyvalues={"id": report_id}, + retcols=("id", "received_ts", "target_user_id", "user_id", "reason"), + allow_none=True, + desc="get_user_report", + ) + + async def get_user_reports_paginate( + self, + start: int, + limit: int, + direction: Direction = Direction.BACKWARDS, + user_id: str | None = None, + target_user_id: str | None = None, + ) -> tuple[list[JsonDict], int]: + """Retrieve a paginated list of user reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards) + user_id: search for user_id of the reporter. Ignored if user_id is None + target_user_id: search for user_id of the target. 
Ignored if target_user_id is None + Returns: + Tuple of: + json list of user reports + total number of user reports matching the filter criteria + """ + + def _get_user_reports_paginate_txn( + txn: LoggingTransaction, + ) -> tuple[list[dict[str, Any]], int]: + filters = [] + args: list[object] = [] + + if user_id: + filters.append("user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if target_user_id: + filters.append("target_user_id LIKE ?") + args.extend(["%" + target_user_id + "%"]) + + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = f""" + SELECT COUNT(*) as total_user_reports + FROM user_reports {where_clause} + """ + txn.execute(sql, args) + count = cast(tuple[int], txn.fetchone())[0] + + sql = f""" + SELECT + id, + received_ts, + target_user_id, + user_id, + reason + FROM user_reports + {where_clause} + ORDER BY received_ts {order} + LIMIT ? + OFFSET ? + """ + + args += [limit, start] + txn.execute(sql, args) + + user_reports = [] + for row in txn: + user_reports.append( + { + "id": row[0], + "received_ts": row[1], + "target_user_id": row[2], + "user_id": row[3], + "reason": row[4], + } + ) + + return user_reports, count + + return await self.db_pool.runInteraction( + "get_user_reports_paginate", _get_user_reports_paginate_txn + ) + + async def delete_user_report(self, report_id: int) -> bool: + """Remove a user report from database. + + Args: + report_id: Report to delete + + Returns: + Whether the report was successfully deleted or not. 
+ """ + try: + await self.db_pool.simple_delete_one( + table="user_reports", + keyvalues={"id": report_id}, + desc="delete_user_report", + ) + except StoreError: + # Deletion failed because report does not exist + return False + + return True + async def set_room_is_public(self, room_id: str, is_public: bool) -> None: await self.db_pool.simple_update_one( table="rooms", diff --git a/tests/rest/admin/test_user_reports.py b/tests/rest/admin/test_user_reports.py new file mode 100644 index 0000000000..94ae242d86 --- /dev/null +++ b/tests/rest/admin/test_user_reports.py @@ -0,0 +1,644 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# + + +from twisted.internet.testing import MemoryReactor + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login, reporting, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.clock import Clock + +from tests import unittest + + +class UserReportsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + reporting.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.users = {} + for i in range(10): + self.users[i] = self.register_user(f"user{i}", "pass") + + # users 1 and 2 report all other users + reporter_1_tok = self.login(self.users[0], "pass") + reporter_2_tok = self.login(self.users[1], "pass") + for num, user in self.users.items(): + if num <= 1: + continue + if num % 2 == 0: + self._report_user(user, reporter_1_tok) + else: + self._report_user(user, reporter_2_tok) + + self.url = "/_synapse/admin/v1/user_reports" + + def test_no_auth(self) -> None: + """ + Try to get a user report without authentication. + """ + channel = self.make_request("GET", self.url, {}) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. 
+ """ + rando_tok = self.login(self.users[4], "pass") + channel = self.make_request( + "GET", + self.url, + access_token=rando_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self) -> None: + """ + Testing list of reported users + """ + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + def test_limit(self) -> None: + """ + Testing list of reported users with limit + """ + + channel = self.make_request( + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["user_reports"]) + + def test_from(self) -> None: + """ + Testing list of reported users with a defined starting point (from) + """ + + channel = self.make_request( + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 3) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + def test_limit_and_from(self) -> None: + """ + Testing list of reported users with a defined starting point and limit + """ + + channel = self.make_request( + "GET", + self.url + "?from=2&limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, 
msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(channel.json_body["next_token"], 7) + self.assertEqual(len(channel.json_body["user_reports"]), 5) + self._check_fields(channel.json_body["user_reports"]) + + def test_filter_by_target_user_id(self) -> None: + """ + Testing list of reported users with a filter of target_user_id + """ + + channel = self.make_request( + "GET", + self.url + "?target_user_id=%s" % self.users[3], + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 1) + self.assertEqual(len(channel.json_body["user_reports"]), 1) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + for report in channel.json_body["user_reports"]: + self.assertEqual(report["target_user_id"], self.users[3]) + + def test_filter_user(self) -> None: + """ + Testing list of reported users with a filter of reporting user + """ + + channel = self.make_request( + "GET", + self.url + "?user_id=%s" % self.users[0], + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 4) + self.assertEqual(len(channel.json_body["user_reports"]), 4) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + for report in channel.json_body["user_reports"]: + self.assertEqual(report["user_id"], self.users[0]) + + def test_filter_user_and_target_user(self) -> None: + """ + Testing list of reported users with a filter of reporting user and target_user + """ + + channel = self.make_request( + "GET", + self.url + "?user_id=%s&target_user_id=%s" % (self.users[1], self.users[7]), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 1) + 
self.assertEqual(len(channel.json_body["user_reports"]), 1) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["user_reports"]) + + for report in channel.json_body["user_reports"]: + self.assertEqual(report["user_id"], self.users[1]) + self.assertEqual(report["target_user_id"], self.users[7]) + + def test_valid_search_order(self) -> None: + """ + Testing search order. Order by timestamps. + """ + + # fetch the most recent first, largest timestamp + channel = self.make_request( + "GET", + self.url + "?dir=b", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + report = 1 + while report < len(channel.json_body["user_reports"]): + self.assertGreaterEqual( + channel.json_body["user_reports"][report - 1]["received_ts"], + channel.json_body["user_reports"][report]["received_ts"], + ) + report += 1 + + # fetch the oldest first, smallest timestamp + channel = self.make_request( + "GET", + self.url + "?dir=f", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + report = 1 + while report < len(channel.json_body["user_reports"]): + self.assertLessEqual( + channel.json_body["user_reports"][report - 1]["received_ts"], + channel.json_body["user_reports"][report]["received_ts"], + ) + report += 1 + + def test_invalid_search_order(self) -> None: + """ + Testing that an invalid search order returns a 400 + """ + + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter 'dir' must be one of ['b', 'f']", + 
channel.json_body["error"], + ) + + def test_limit_is_negative(self) -> None: + """ + Testing that a negative limit parameter returns a 400 + """ + + channel = self.make_request( + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self) -> None: + """ + Testing that a negative from parameter returns a 400 + """ + + channel = self.make_request( + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self) -> None: + """ + Testing that `next_token` appears at the right place + """ + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=8", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 8) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=6", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 
8) + self.assertEqual(len(channel.json_body["user_reports"]), 6) + self.assertEqual(channel.json_body["next_token"], 6) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", + self.url + "?from=6", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 8) + self.assertEqual(len(channel.json_body["user_reports"]), 2) + self.assertNotIn("next_token", channel.json_body) + + def _report_user(self, target_user: str, reporter_tok: str) -> None: + """Report a user""" + channel = self.make_request( + "POST", + "_matrix/client/v3/users/%s/report" % (target_user), + {"reason": "stone-cold bummer"}, + access_token=reporter_tok, + shorthand=False, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + def _check_fields(self, content: list[JsonDict]) -> None: + """Checks that all attributes are present in a user report""" + for c in content: + self.assertIn("id", c) + self.assertIn("received_ts", c) + self.assertIn("user_id", c) + self.assertIn("target_user_id", c) + self.assertIn("reason", c) + + +class UserReportDetailTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + reporting.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.bad_user = self.register_user("user", "pass") + self.bad_user_tok = self.login("user", "pass") + + channel = self.make_request( + "POST", + "_matrix/client/v3/users/%s/report" % (self.bad_user), + {"reason": "stone-cold bummer"}, + access_token=self.admin_user_tok, + shorthand=False, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # first created 
user report gets `id`=2 + self.url = "/_synapse/admin/v1/user_reports/2" + + def test_no_auth(self) -> None: + """ + Try to get user report without authentication. + """ + channel = self.make_request("GET", self.url, {}) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + "GET", + self.url, + access_token=self.bad_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self) -> None: + """ + Testing get a reported user + """ + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self._check_fields(channel.json_body) + + def test_invalid_report_id(self) -> None: + """ + Testing that an invalid `report_id` returns a 400. 
+ """ + + # `report_id` is negative + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self) -> None: + """ + Testing that a non-existent `report_id` returns a 404. 
+ """ + + channel = self.make_request( + "GET", + "/_synapse/admin/v1/user_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("User report not found", channel.json_body["error"]) + + def _check_fields(self, content: JsonDict) -> None: + """Checks that all attributes are present in a user report""" + self.assertIn("id", content) + self.assertIn("received_ts", content) + self.assertIn("target_user_id", content) + self.assertIn("user_id", content) + self.assertIn("reason", content) + + +class DeleteUserReportTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.bad_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # create report + self.get_success( + self._store.add_user_report( + self.bad_user, + self.admin_user, + "super bummer", + self.clock.time_msec(), + ) + ) + + self.url = "/_synapse/admin/v1/user_reports/2" + + def test_no_auth(self) -> None: + """ + Try to delete user report without authentication. + """ + channel = self.make_request("DELETE", self.url) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. 
+ """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.other_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_success(self) -> None: + """ + Testing delete a report. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + # check that report was deleted + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_invalid_report_id(self) -> None: + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, 
channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self) -> None: + """ + Testing that a not existing `report_id` returns a 404. + """ + + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/user_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("User report not found", channel.json_body["error"]) From 2a8285931e98984d867404c4dfe4ce8fc5e35b16 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Apr 2026 11:54:22 +0100 Subject: [PATCH 4/8] Prune old rows in `device_lists_changes_in_room` table. (#19473) Fixes #13043 The usages of the table mostly already correctly handled if we don't have old entries, as that was needed when we first added the table. I arbitrarily set the prune time to 30 days. The only use for old entries is for sync streams that haven't synced since then, and we should very rarely see sync streams that haven't been used in 30 days. Reviewable commit-by-commit. 
--------- Co-authored-by: Olivier 'reivilibre' Co-authored-by: Olivier 'reivilibre' --- changelog.d/19473.misc | 1 + synapse/handlers/device.py | 12 +- synapse/storage/databases/main/devices.py | 365 ++++++++++++++---- .../94/03_device_lists_room_timestamp.sql | 18 + .../94/04_device_lists_changes_max_pruned.sql | 34 ++ tests/handlers/test_device.py | 357 ++++++++++++++++- tests/storage/test_devices.py | 104 +++++ 7 files changed, 816 insertions(+), 75 deletions(-) create mode 100644 changelog.d/19473.misc create mode 100644 synapse/storage/schema/main/delta/94/03_device_lists_room_timestamp.sql create mode 100644 synapse/storage/schema/main/delta/94/04_device_lists_changes_max_pruned.sql diff --git a/changelog.d/19473.misc b/changelog.d/19473.misc new file mode 100644 index 0000000000..596d8a6b26 --- /dev/null +++ b/changelog.d/19473.misc @@ -0,0 +1 @@ +Reduce database disk space usage by pruning old rows from `device_lists_changes_in_room`. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9a371651fb..2225466648 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,6 +58,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, JsonMapping, + MultiWriterStreamToken, ScheduledTask, StrCollection, StreamKeyType, @@ -1193,7 +1194,16 @@ class DeviceWriterHandler(DeviceHandler): changes = await self.store.get_device_list_changes_in_room( room_id, device_lists_stream_id ) - local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)} + if changes is not None: + local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)} + else: + # The `device_lists_stream_id` is too old, so we need to fall back + # to looking for changes for all local users. 
+ local_users = await self.store.get_local_users_in_room(room_id) + local_changes = await self.store.get_device_changes_for_users( + MultiWriterStreamToken(stream=device_lists_stream_id), local_users + ) + if not local_changes: return diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e9ecf46411..339fb8a6f7 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -79,6 +79,19 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" +# Background update name for adding an index on +# `device_lists_changes_in_room.inserted_ts`. +BG_UPDATE_ADD_INSERTED_TS_INDEX = "device_lists_changes_in_room_inserted_ts_idx" + + +# Prunes entries out of the `device_lists_changes_in_room` table that are more +# than this old. +PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE = Duration(days=30) + +# The number of rows to delete at once when pruning old entries out of the +# `device_lists_changes_in_room` table. +PRUNE_DEVICE_LISTS_BATCH_SIZE = 1000 + class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): _device_list_id_gen: MultiWriterIdGenerator @@ -194,6 +207,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self.clock.looping_call( self._prune_old_outbound_device_pokes, Duration(hours=1) ) + self.clock.looping_call( + self._prune_device_lists_changes_in_room, + Duration(hours=1), + ) def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] @@ -1143,6 +1160,35 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The set of user_ids whose devices have changed since `from_key` (exclusive) until `to_key` (inclusive). 
""" + return { + user_id + for user_id, _ in await self.get_device_changes_for_users( + from_key, user_ids, to_key + ) + } + + @cancellable + async def get_device_changes_for_users( + self, + from_key: MultiWriterStreamToken, + user_ids: Collection[str], + to_key: MultiWriterStreamToken | None = None, + ) -> set[tuple[str, str]]: + """Get set of user/device ID tuple whose devices have changed since `from_key` that + are in the given list of user_ids. + + Args: + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. If None then no upper limit is applied. + + Returns: + The set of user/device ID tuples whose devices have changed since `from_key` + (exclusive) until `to_key` (inclusive). + """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. user_ids_to_check = self._device_list_stream_cache.get_entities_changed( @@ -1156,18 +1202,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if to_key is None: to_key = self.get_device_stream_token() - def _get_users_whose_devices_changed_txn( + def get_device_changes_for_users_txn( txn: LoggingTransaction, from_key: MultiWriterStreamToken, to_key: MultiWriterStreamToken, - ) -> set[str]: + ) -> set[tuple[str, str]]: sql = """ - SELECT user_id, stream_id, instance_name + SELECT user_id, device_id, stream_id, instance_name FROM device_lists_stream WHERE ? < stream_id AND stream_id <= ? 
AND %s """ - changes: set[str] = set() + changes: set[tuple[str, str]] = set() # Query device changes with a batch of users at a time for chunk in batch_iter(user_ids_to_check, 100): @@ -1179,8 +1225,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): [from_key.stream, to_key.get_max_stream_pos()] + args, ) changes.update( - user_id - for (user_id, stream_id, instance_name) in txn + (user_id, device_id) + for (user_id, device_id, stream_id, instance_name) in txn if MultiWriterStreamToken.is_stream_position_in_range( low=from_key, high=to_key, @@ -1192,8 +1238,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return changes return await self.db_pool.runInteraction( - "get_users_whose_devices_changed", - _get_users_whose_devices_changed_txn, + "get_device_changes_for_users", + get_device_changes_for_users_txn, from_key, to_key, ) @@ -1699,17 +1745,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return devices - @cached() - async def _get_min_device_lists_changes_in_room(self) -> int: - """Returns the minimum stream ID that we have entries for - `device_lists_changes_in_room` + def _get_max_pruned_device_lists_changes_in_room_txn( + self, txn: LoggingTransaction + ) -> int: + """Returns the maximum stream ID that has been pruned from + `device_lists_changes_in_room`. + + Any queries for stream IDs less than this value cannot be answered + completely, as the data has been deleted. 
""" - return await self.db_pool.simple_select_one_onecol( - table="device_lists_changes_in_room", + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="device_lists_changes_in_room_max_pruned_stream_id", keyvalues={}, - retcol="COALESCE(MIN(stream_id), 0)", - desc="get_min_device_lists_changes_in_room", + retcol="stream_id", + allow_none=False, ) @cancellable @@ -1728,55 +1779,54 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if not room_ids: return set() - min_stream_id = await self._get_min_device_lists_changes_in_room() - - # Return early if there are no rows to process in device_lists_changes_in_room - if min_stream_id > from_token.stream: - return None - changed_room_ids = self._device_list_room_stream_cache.get_entities_changed( room_ids, from_token.stream ) if not changed_room_ids: return set() - sql = """ - SELECT user_id, stream_id, instance_name - FROM device_lists_changes_in_room - WHERE {clause} AND stream_id > ? AND stream_id <= ? - """ - def _get_device_list_changes_in_rooms_txn( txn: LoggingTransaction, - chunk: list[str], - ) -> set[str]: - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", chunk + ) -> set[str] | None: + # Check if the from_token is too old (i.e. data has been pruned). 
+ max_pruned_stream_id = ( + self._get_max_pruned_device_lists_changes_in_room_txn(txn) ) - args.append(from_token.stream) - args.append(to_token.get_max_stream_pos()) + if max_pruned_stream_id > from_token.stream: + return None - txn.execute(sql.format(clause=clause), args) - return { - user_id - for (user_id, stream_id, instance_name) in txn - if MultiWriterStreamToken.is_stream_position_in_range( - low=from_token, - high=to_token, - instance_name=instance_name, - pos=stream_id, + changes: set[str] = set() + + for chunk in batch_iter(changed_room_ids, 1000): + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", chunk ) - } + args.append(from_token.stream) + args.append(to_token.get_max_stream_pos()) - changes = set() - for chunk in batch_iter(changed_room_ids, 1000): - changes |= await self.db_pool.runInteraction( - "get_device_list_changes_in_rooms", - _get_device_list_changes_in_rooms_txn, - chunk, - ) + sql = f""" + SELECT user_id, stream_id, instance_name + FROM device_lists_changes_in_room + WHERE {clause} AND stream_id > ? AND stream_id <= ? + """ + txn.execute(sql, args) + changes.update( + user_id + for (user_id, stream_id, instance_name) in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=from_token, + high=to_token, + instance_name=instance_name, + pos=stream_id, + ) + ) - return changes + return changes + + return await self.db_pool.runInteraction( + "get_device_list_changes_in_rooms", + _get_device_list_changes_in_rooms_txn, + ) async def get_all_device_list_changes(self, from_id: int, to_id: int) -> set[str]: """Return the set of rooms where devices have changed since the given @@ -1785,46 +1835,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): Will raise an exception if the given stream ID is too old. 
""" - min_stream_id = await self._get_min_device_lists_changes_in_room() - - if min_stream_id > from_id: - raise Exception("stream ID is too old") - - sql = """ - SELECT DISTINCT room_id FROM device_lists_changes_in_room - WHERE stream_id > ? AND stream_id <= ? - """ - def _get_all_device_list_changes_txn( txn: LoggingTransaction, - ) -> set[str]: + ) -> set[str] | None: + # Check if the from_token is too old (i.e. data has been pruned). + max_pruned_stream_id = ( + self._get_max_pruned_device_lists_changes_in_room_txn(txn) + ) + if max_pruned_stream_id > from_id: + logger.warning( + "Given stream ID is too old %d < %d", + from_id, + max_pruned_stream_id, + ) + return None + + sql = """ + SELECT DISTINCT room_id FROM device_lists_changes_in_room + WHERE stream_id > ? AND stream_id <= ? + """ + txn.execute(sql, (from_id, to_id)) return {room_id for (room_id,) in txn} - return await self.db_pool.runInteraction( + room_ids = await self.db_pool.runInteraction( "get_all_device_list_changes", _get_all_device_list_changes_txn, ) + if room_ids is None: + raise Exception(f"Given stream ID is too old {from_id}") + + return room_ids + async def get_device_list_changes_in_room( self, room_id: str, min_stream_id: int - ) -> Collection[tuple[str, str]]: + ) -> Collection[tuple[str, str]] | None: """Get all device list changes that happened in the room since the given stream ID. Returns: Collection of user ID/device ID tuples of all devices that have - changed - """ - - sql = """ - SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room - WHERE room_id = ? AND stream_id > ? + changed, or None if the given stream ID is too old and so a complete + list cannot be calculated. """ def get_device_list_changes_in_room_txn( txn: LoggingTransaction, - ) -> Collection[tuple[str, str]]: + ) -> Collection[tuple[str, str]] | None: + # Check if the from_token is too old (i.e. data has been pruned). 
+ max_pruned_stream_id = ( + self._get_max_pruned_device_lists_changes_in_room_txn(txn) + ) + if max_pruned_stream_id > min_stream_id: + return None + + sql = """ + SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room + WHERE room_id = ? AND stream_id > ? + """ + txn.execute(sql, (room_id, min_stream_id)) return cast(Collection[tuple[str, str]], txn.fetchall()) @@ -2160,6 +2230,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): encoded_context = json_encoder.encode(context) + now = self.clock.time_msec() + # The `device_lists_changes_in_room.stream_id` column matches the # corresponding `stream_id` of the update in the `device_lists_stream` # table, i.e. all rows persisted for the same device update will have @@ -2175,6 +2247,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "instance_name", "converted_to_destinations", "opentracing_context", + "inserted_ts", ), values=[ ( @@ -2186,6 +2259,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # We only need to calculate outbound pokes for local users not self.hs.is_mine_id(user_id), encoded_context, + now, ) for room_id in room_ids for device_id, stream_id in zip(device_ids, stream_ids) @@ -2401,6 +2475,144 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): desc="set_device_change_last_converted_pos", ) + @wrap_as_background_process("prune_device_lists_changes_in_room") + async def _prune_device_lists_changes_in_room(self) -> None: + """Delete old entries out of the `device_lists_changes_in_room`, so that + the table doesn't grow indefinitely. + """ + + # Let's only do this pruning if the index on inserted_ts has been + # created, otherwise this query will be very inefficient. 
+ has_index_been_created = ( + await self.db_pool.updates.has_completed_background_update( + BG_UPDATE_ADD_INSERTED_TS_INDEX + ) + ) + if not has_index_been_created: + return + + prune_before_ts = ( + self.clock.time_msec() - PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE.as_millis() + ) + + # Get stream ID corresponding to the prune_before_ts timestamp. We can + # delete all rows with a stream ID less than or equal to this, as they + # will be older than the cutoff. + # + # Some rows will have a NULL inserted_ts (due to being inserted before + # the column was added), but we can assume that the timestamp will + # monotonically increase with stream ID, so we can safely ignore those + # rows when calculating the cutoff stream ID. This means that we may end + # up keeping some rows with a non-NULL inserted_ts that are older than + # the cutoff, but that's better than accidentally deleting rows that are + # newer than the cutoff. + cutoff_sql = """ + SELECT stream_id FROM device_lists_changes_in_room + WHERE inserted_ts <= ? AND inserted_ts IS NOT NULL + ORDER BY inserted_ts DESC + LIMIT 1 + """ + + def get_prune_before_stream_id_txn(txn: LoggingTransaction) -> int | None: + txn.execute(cutoff_sql, (prune_before_ts,)) + row = txn.fetchone() + return row[0] if row else None + + prune_before_stream_id = await self.db_pool.runInteraction( + "prune_device_lists_changes_in_room_get_stream_id", + get_prune_before_stream_id_txn, + ) + + if prune_before_stream_id is None: + return + + # Get the max stream ID in the table so we avoid deleting it. We need + # to keep the latest row so that we can calculate the maximum stream ID + # used. 
+ max_stream_id = await self.db_pool.simple_select_one_onecol( + table="device_lists_changes_in_room", + keyvalues={}, + retcol="MAX(stream_id)", + desc="prune_device_lists_changes_in_room_get_max_stream_id", + ) + if prune_before_stream_id >= max_stream_id: + prune_before_stream_id = max_stream_id - 1 + + logger.debug( + "Pruning device_lists_changes_in_room before stream ID %d (timestamp %d)", + prune_before_stream_id, + prune_before_ts, + ) + + # Now delete all rows with stream_id less than the + # prune_before_stream_id. + # + # We also delete in batches to avoid massive churn when initially + # clearing out all the old entries. + # + # We set a minimum stream ID so that when we delete in batches the + # database doesn't have to scan through all the (dead) tuples that were just + # deleted to find the next batch to delete. + + # The minimum stream ID to delete in the next batch, c.f. comment above. + # We default to 0 here as that is less than all possible stream IDs. + min_stream_id = 0 + + def prune_device_lists_changes_in_room_txn(txn: LoggingTransaction) -> int: + nonlocal min_stream_id + + delete_sql = """ + DELETE FROM device_lists_changes_in_room + WHERE stream_id IN ( + SELECT stream_id FROM device_lists_changes_in_room + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + ) + RETURNING stream_id + """ + txn.execute( + delete_sql, + (min_stream_id, prune_before_stream_id, PRUNE_DEVICE_LISTS_BATCH_SIZE), + ) + + # We can't use rowcount as that is incorrect on SQLite when using + # RETURNING. 
+ num_deleted = 0 + for row in txn: + num_deleted += 1 + min_stream_id = max(min_stream_id, row[0]) + + return num_deleted + + num_rows_deleted = 0 + while True: + batch_deleted = await self.db_pool.runInteraction( + "prune_device_lists_changes_in_room", + prune_device_lists_changes_in_room_txn, + ) + num_rows_deleted += batch_deleted + if batch_deleted < PRUNE_DEVICE_LISTS_BATCH_SIZE: + break + + # Sleep for a short time to avoid hammering the database too much if + # there are a lot of rows to delete. + await self.clock.sleep(Duration(milliseconds=100)) + + if num_rows_deleted: + # Update the max pruned stream ID tracking table so that the + # safety check knows data up to this point has been deleted. + await self.db_pool.simple_update_one( + table="device_lists_changes_in_room_max_pruned_stream_id", + keyvalues={}, + updatevalues={"stream_id": prune_before_stream_id}, + desc="prune_device_lists_changes_in_room_update_max_pruned", + ) + + logger.info( + "Pruned %d rows from device_lists_changes_in_room", num_rows_deleted + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): _instance_name: str @@ -2459,6 +2671,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): columns=["room_id", "stream_id"], ) + # Add indexes to speed up pruning of device_lists_changes_in_room + self.db_pool.updates.register_background_index_update( + BG_UPDATE_ADD_INSERTED_TS_INDEX, + index_name="device_lists_changes_in_room_inserted_ts_idx", + table="device_lists_changes_in_room", + columns=["inserted_ts"], + where_clause="inserted_ts IS NOT NULL", + ) + async def _drop_device_list_streams_non_unique_indexes( self, progress: JsonDict, batch_size: int ) -> int: diff --git a/synapse/storage/schema/main/delta/94/03_device_lists_room_timestamp.sql b/synapse/storage/schema/main/delta/94/03_device_lists_room_timestamp.sql new file mode 100644 index 0000000000..b0ae1eaaf6 --- /dev/null +++ b/synapse/storage/schema/main/delta/94/03_device_lists_room_timestamp.sql @@ -0,0 +1,18 @@ +-- +-- This 
file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 Element Creations, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +ALTER TABLE device_lists_changes_in_room ADD COLUMN inserted_ts BIGINT; + +-- Add a background update to add index +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (9403, 'device_lists_changes_in_room_inserted_ts_idx', '{}'); diff --git a/synapse/storage/schema/main/delta/94/04_device_lists_changes_max_pruned.sql b/synapse/storage/schema/main/delta/94/04_device_lists_changes_max_pruned.sql new file mode 100644 index 0000000000..73836841da --- /dev/null +++ b/synapse/storage/schema/main/delta/94/04_device_lists_changes_max_pruned.sql @@ -0,0 +1,34 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2026 Element Creations Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + + +-- Tracks the maximum stream_id that has been deleted (pruned) from the +-- device_lists_changes_in_room table. This is used to determine whether it's +-- safe to read from that table for a given stream_id — if the requested +-- stream_id is < the value here, the data has been pruned and the table cannot +-- provide a complete answer. 
+-- +-- We need a separate table, rather than looking at the minimum stream_id in the +-- device_lists_changes_in_room table, because not all valid stream IDs will +-- have entries in the table. This could lead to situations where the minimum +-- stream ID was potentially much more recent than when we actually pruned. This +-- would cause us to incorrectly think that the table was not safe to read from, +-- when in fact it was. +CREATE TABLE IF NOT EXISTS device_lists_changes_in_room_max_pruned_stream_id ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, + stream_id BIGINT NOT NULL +); + +-- We assume that nothing has been deleted from the device_lists_changes_in_room +-- table, so we can set the initial value to 0. +INSERT INTO device_lists_changes_in_room_max_pruned_stream_id (stream_id) VALUES (0); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index f99e3cd4a2..9e44b1dc1e 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -21,20 +21,42 @@ # from unittest import mock +from unittest.mock import AsyncMock, Mock, patch +import signedjson.key +from parameterized import parameterized +from signedjson.types import SigningKey + +from twisted.internet import defer from twisted.internet.defer import ensureDeferred from twisted.internet.testing import MemoryReactor -from synapse.api.constants import RoomEncryptionAlgorithms +from synapse.api.constants import EventTypes, JoinRules, RoomEncryptionAlgorithms from synapse.api.errors import NotFoundError, SynapseError +from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import EventBase, FrozenEventV3 +from synapse.federation.federation_client import SendJoinResult +from synapse.federation.transport.client import ( + StateRequestResponse, + TransportLayerClient, +) +from synapse.federation.units import Transaction from 
synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceWriterHandler from synapse.rest import admin from synapse.rest.client import devices, login, register from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import ( + JsonDict, + StateMap, + UserID, + create_requester, + get_domain_from_id, +) from synapse.util.clock import Clock +from synapse.util.duration import Duration from synapse.util.task_scheduler import TaskScheduler from tests import unittest @@ -581,3 +603,334 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertTrue(len(res["next_batch"]) > 1) self.assertEqual(len(res["events"]), 1) self.assertEqual(res["events"][0]["content"]["body"], "foo") + + +@patch("synapse.crypto.keyring.Keyring.process_request", AsyncMock(return_value=None)) +class DeviceUnPartialStateTestCase(unittest.HomeserverTestCase): + """Tests that local device list changes during partial state are sent to + remote servers when the room un-partials.""" + + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + # The two remote servers to fake + REMOTE1_SERVER_NAME = "remote1" + REMOTE1_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") + REMOTE1_USER = f"@user:{REMOTE1_SERVER_NAME}" + + REMOTE2_SERVER_NAME = "remote2" + REMOTE2_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") + REMOTE2_USER = f"@user:{REMOTE2_SERVER_NAME}" + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable federation so that get_device_updates_by_remote works. + config["federation_sender_instances"] = ["master"] + return config + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # Mock the federation transport client to prevent actual network calls. 
+ self.federation_transport_client = AsyncMock(TransportLayerClient) + + self.federation_transport_client.send_transaction.return_value = {} + + hs = self.setup_test_homeserver( + federation_transport_client=self.federation_transport_client, + ) + handler = hs.get_device_handler() + assert isinstance(handler, DeviceWriterHandler) + self.device_handler = handler + self.store = hs.get_datastores().main + + return hs + + def _build_public_room(self) -> StateMap[EventBase]: + """Build a public room DAG that has REMOTE1 in it""" + + room_id = f"!room:{self.REMOTE1_SERVER_NAME}" + room_version = RoomVersions.V10 + + events: list[EventBase] = [] + + # First we make the create event + create_event_dict: JsonDict = { + "auth_events": [], + "content": { + "creator": self.REMOTE1_USER, + "room_version": room_version.identifier, + }, + "depth": 0, + "origin_server_ts": 0, + "prev_events": [], + "room_id": room_id, + "sender": self.REMOTE1_USER, + "state_key": "", + "type": EventTypes.Create, + } + + add_hashes_and_signatures( + room_version, + create_event_dict, + self.REMOTE1_SERVER_NAME, + self.REMOTE1_SERVER_SIGNATURE_KEY, + ) + + create_event = FrozenEventV3(create_event_dict, room_version, {}, None) + events.append(create_event) + + room_version = self.hs.config.server.default_room_version + join_event_dict: JsonDict = { + "auth_events": [ + create_event.event_id, + ], + "content": {"membership": "join"}, + "depth": 1, + "origin_server_ts": 100, + "prev_events": [create_event.event_id], + "sender": self.REMOTE1_USER, + "state_key": self.REMOTE1_USER, + "room_id": room_id, + "type": EventTypes.Member, + } + add_hashes_and_signatures( + room_version, + join_event_dict, + self.hs.hostname, + self.hs.signing_key, + ) + join_event = FrozenEventV3(join_event_dict, room_version, {}, None) + events.append(join_event) + + # Then set the join rules to public + join_rules_event_dict: JsonDict = { + "auth_events": [create_event.event_id, join_event.event_id], + "content": 
{"join_rule": JoinRules.PUBLIC}, + "depth": 2, + "origin_server_ts": 200, + "prev_events": [join_event.event_id], + "room_id": room_id, + "sender": self.REMOTE1_USER, + "state_key": "", + "type": EventTypes.JoinRules, + } + + add_hashes_and_signatures( + room_version, + join_rules_event_dict, + self.REMOTE1_SERVER_NAME, + self.REMOTE1_SERVER_SIGNATURE_KEY, + ) + join_rules_event = FrozenEventV3(join_rules_event_dict, room_version, {}, None) + events.append(join_rules_event) + + return {(event.type, event.state_key): event for event in events} + + def _build_signed_join_event( + self, + room_id: str, + user: str, + signing_key: SigningKey, + state: StateMap[EventBase], + ) -> FrozenEventV3: + """Build a join event for the local user, signed by the local server.""" + + latest_event = max(state.values(), key=lambda e: e.depth) + + room_version = self.hs.config.server.default_room_version + join_event_dict: JsonDict = { + "auth_events": [ + state[(EventTypes.Create, "")].event_id, + state[(EventTypes.JoinRules, "")].event_id, + ], + "content": {"membership": "join"}, + "depth": latest_event.depth + 1, + "origin_server_ts": latest_event.origin_server_ts + 100, + "prev_events": [latest_event.event_id], + "sender": user, + "state_key": user, + "room_id": room_id, + "type": EventTypes.Member, + } + add_hashes_and_signatures( + room_version, + join_event_dict, + get_domain_from_id(user), + signing_key, + ) + return FrozenEventV3(join_event_dict, room_version, {}, None) + + @parameterized.expand([("not_pruned", False), ("pruned", True)]) + @patch( + "synapse.storage.databases.main.devices.PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE", + Duration(minutes=1), + ) + def test_local_device_changes_sent_to_new_servers_on_un_partial_state( + self, _test_suffix: str, prune_device_lists_change_in_room: bool + ) -> None: + """When a room un-partials, local device list changes made during the + partial state period should be sent to remote servers that were NOT + known at the time of the 
partial join. + + We do this by creating a room with one remote server, partially + joining it, then receiving a join event from a second remote server. The + second remote server should receive a device list update EDU for any + local device changes that happened during the partial state period. + + We parameterize this test over whether during the unpartial process we + prune the `device_list_changes_in_room` table, to check that the + unpartial process correctly handles the case. + """ + + local_user = self.register_user("alice", "password") + self.login("alice", "password") + + # Build the remote room's state events. + room_state = self._build_public_room() + + # Before joining, we mock out the federation endpoints that are used + # during the unpartial process, so that we can control when the + # unpartial process completes. + get_room_state_ids_deferred: defer.Deferred[JsonDict] = defer.Deferred() + get_room_state_deferred: defer.Deferred[StateRequestResponse] = defer.Deferred() + self.federation_transport_client.get_room_state_ids = Mock( + side_effect=[get_room_state_ids_deferred] + ) + self.federation_transport_client.get_room_state = Mock( + side_effect=[get_room_state_deferred] + ) + + # Now make the local server partially join the room. + room_id = room_state[(EventTypes.Create, "")].room_id + room_version = room_state[(EventTypes.Create, "")].room_version + + local_join_event = self._build_signed_join_event( + room_id, local_user, self.hs.signing_key, room_state + ) + + # Mock the federation client endpoints for the partial join. + mock_make_membership_event = AsyncMock( + return_value=(self.REMOTE1_SERVER_NAME, local_join_event, room_version) + ) + mock_send_join = AsyncMock( + return_value=SendJoinResult( + local_join_event, + self.REMOTE1_SERVER_NAME, + state=list(room_state.values()), + auth_chain=list(room_state.values()), + partial_state=True, + # Only REMOTE1_SERVER_NAME is known at join time. 
+ servers_in_room={self.REMOTE1_SERVER_NAME}, + ) + ) + + fed_handler = self.hs.get_federation_handler() + fed_client = self.hs.get_federation_client() + with ( + patch.object( + fed_client, "make_membership_event", mock_make_membership_event + ), + patch.object(fed_client, "send_join", mock_send_join), + ): + self.get_success( + fed_handler.do_invite_join( + [self.REMOTE1_SERVER_NAME], room_id, local_user, {} + ) + ) + + # The room should now be in partial state. + self.assertTrue(self.get_success(self.store.is_partial_state_room(room_id))) + + # A local device change happens while the room is in partial state. + self.get_success( + self.store.add_device_change_to_streams( + local_user, ["NEW_DEVICE"], [room_id] + ) + ) + + if prune_device_lists_change_in_room: + # Add a device change for another room, as we won't prune the most + # recent change. + self.get_success( + self.store.add_device_change_to_streams( + "@other:user", ["device1"], ["!some:room"] + ) + ) + + # Now prune the device list changes for the room. This simulates the + # case where the unpartial process prunes the + # `device_list_changes_in_room` table before processing the device + # list changes. + self.reactor.advance(120) # Advance past the pruning threshold + self.get_success(self.store._prune_device_lists_changes_in_room()) + + # Assert we actually pruned the device list changes for the room. 
+ room_ids = self.get_success( + self.store.db_pool.simple_select_onecol( + table="device_lists_changes_in_room", + keyvalues={}, + retcol="room_id", + ) + ) + self.assertCountEqual(room_ids, ["!some:room"]) + + # Join the second server + new_state = dict(room_state) + new_state[(EventTypes.Member, local_user)] = local_join_event + join_event_2 = self._build_signed_join_event( + room_id, + self.REMOTE2_USER, + self.REMOTE2_SERVER_SIGNATURE_KEY, + new_state, + ) + + self.get_success( + self.hs.get_federation_event_handler().on_receive_pdu( + self.REMOTE2_SERVER_NAME, join_event_2 + ) + ) + + # Some EDUs may get sent out immediately, such as presence updates. + # However, we only care about the device list update EDU sent by the + # unpartialling process. Let's wait a few seconds and reset the mock. + self.reactor.advance(5) + self.federation_transport_client.send_transaction.reset_mock() + + # We now unblock the unpartial process by returning the room state and + # state ids. This should trigger the device list update to be sent to + # REMOTE2_SERVER_NAME. + self.federation_transport_client.get_room_state_ids.assert_called_once_with( + self.REMOTE1_SERVER_NAME, + room_id, + event_id=local_join_event.prev_event_ids()[0], + ) + + get_room_state_ids_deferred.callback( + { + "pdu_ids": [event.event_id for event in room_state.values()], + "auth_event_ids": [], + } + ) + get_room_state_deferred.callback( + StateRequestResponse( + state=list(room_state.values()), + auth_events=[], + ) + ) + + # The device list EDU isn't necessarily sent out immediately + self.reactor.advance(30) + + # Check that only one transaction was sent, and that it contains the + # device list update EDU for the new device to REMOTE2_SERVER_NAME. 
+ self.federation_transport_client.send_transaction.assert_called_once() + args, _ = self.federation_transport_client.send_transaction.call_args + transaction: Transaction = args[0] + + self.assertEqual(transaction.destination, self.REMOTE2_SERVER_NAME) + self.assertEqual(len(transaction.edus), 1) + + edu = transaction.edus[0] + self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["content"]["device_id"], "NEW_DEVICE") diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 1d1979e19f..b153c74980 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -19,15 +19,21 @@ # # +import itertools from typing import Collection +from unittest.mock import patch from twisted.internet.testing import MemoryReactor import synapse.api.errors from synapse.api.constants import EduTypes from synapse.server import HomeServer +from synapse.storage.databases.main.devices import ( + PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE, +) from synapse.types import JsonDict from synapse.util.clock import Clock +from synapse.util.duration import Duration from tests.unittest import HomeserverTestCase @@ -351,3 +357,101 @@ class DeviceStoreTestCase(HomeserverTestCase): synapse.api.errors.StoreError, ) self.assertEqual(404, exc.value.code) + + @patch("synapse.storage.databases.main.devices.PRUNE_DEVICE_LISTS_BATCH_SIZE", 5) + def test_prune_old_device_lists_changes_in_room(self) -> None: + """Test that old entries in the `device_lists_changes_in_room` table are pruned properly.""" + + # Pretend the user is in a few rooms. + room_ids = [f"!room{i}:test" for i in range(20)] + + # Create a generator for device IDs so we can easily create many unique + # device IDs without having to keep track of the count ourselves. 
+ device_id_gen = (f"device_id{i}" for i in itertools.count()) + + def get_devices_in_room_status() -> tuple[int, str]: + """Helper function to get the count of entries in + `device_lists_changes_in_room` and the minimum device_id.""" + return self.get_success( + self.store.db_pool.simple_select_one( + table="device_lists_changes_in_room", + keyvalues={}, + retcols=("COUNT(*)", "MIN(device_id)"), + ) + ) + + # First we add some initial entries to the `device_lists_changes_in_room`. + self.get_success( + self.store.add_device_change_to_streams( + user_id="@user_id:test", + device_ids=[next(device_id_gen) for _ in range(10)], + room_ids=room_ids, + ) + ) + + # Advance the reactor a while, but not long enough to trigger pruning. + self.reactor.advance(Duration(hours=1).as_secs()) + + # The `device_lists_changes_in_room` table should now have 10 * + # len(room_ids) entries, and the minimum device_id should be + # `device_id0`. + count, min_device_id = get_devices_in_room_status() + self.assertEqual(count, 10 * len(room_ids)) + self.assertEqual(min_device_id, "device_id0") + + # Record the max pruned stream ID before pruning, so we can check + # that this correctly updates after pruning. + starting_max_pruned_id = self.get_success( + self.store.db_pool.runInteraction( + "get_max_pruned_device_lists_changes_in_room", + self.store._get_max_pruned_device_lists_changes_in_room_txn, + ) + ) + + # Now we add some more entries. + self.get_success( + self.store.add_device_change_to_streams( + user_id="@user_id:test", + device_ids=[next(device_id_gen) for _ in range(10)], + room_ids=room_ids, + ) + ) + + # Advance the reactor a while more, so that the first batch of entries is + # now old enough to be pruned. + self.reactor.advance( + (PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE - Duration(minutes=30)).as_secs() + ) + + # Advance repeatedly a bit so that the pruning process can run to completion. 
+ for _ in range(10): + self.reactor.advance(Duration(milliseconds=110).as_secs()) + + # Check that the old entries have been pruned, and the new entries are still there. + count, min_device_id = get_devices_in_room_status() + self.assertEqual(count, 10 * len(room_ids)) + self.assertEqual(min_device_id, "device_id10") + + # We should always keep the most recent entries, even if they are old enough to be pruned. + self.reactor.advance( + (PRUNE_DEVICE_LISTS_CHANGES_IN_ROOM_AGE + Duration(minutes=30)).as_secs() + ) + + # Advance repeatedly a bit so that the pruning process can run to completion. + for _ in range(10): + self.reactor.advance(Duration(milliseconds=110).as_secs()) + + count, min_device_id = get_devices_in_room_status() + # We should always keep the most recent entries so that we can + # calculate the maximum stream ID used. + self.assertEqual(count, len(room_ids)) + self.assertEqual(min_device_id, "device_id19") + + # Check that the max pruned stream ID has been advanced after pruning. 
+ max_pruned_id = self.get_success( + self.store.db_pool.runInteraction( + "get_max_pruned_device_lists_changes_in_room", + self.store._get_max_pruned_device_lists_changes_in_room_txn, + ) + ) + self.assertGreater(max_pruned_id, starting_max_pruned_id) From 67b4d8e7e3c92c41ae34464ae4baf595765b0262 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 17 Apr 2026 09:50:37 -0500 Subject: [PATCH 5/8] Add docs for what to document about a new stream (#19696) Spawning from the follow-up necessary when adding a new stream (https://github.com/element-hq/synapse/pull/19694) --- changelog.d/19696.doc | 1 + docs/development/synapse_architecture/streams.md | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/19696.doc diff --git a/changelog.d/19696.doc b/changelog.d/19696.doc new file mode 100644 index 0000000000..c531359444 --- /dev/null +++ b/changelog.d/19696.doc @@ -0,0 +1 @@ +Update the developer stream docs for creating a new stream to highlight places that require documentation updates. diff --git a/docs/development/synapse_architecture/streams.md b/docs/development/synapse_architecture/streams.md index b4057199a9..e7ab79091e 100644 --- a/docs/development/synapse_architecture/streams.md +++ b/docs/development/synapse_architecture/streams.md @@ -179,6 +179,13 @@ necessary registration and event handling. - don't forget the super call - add local-only [invalidations to your writer transactions](https://github.com/element-hq/synapse/blob/4367fb2d078c52959aeca0fe6874539c53e8360d/synapse/storage/databases/main/thread_subscriptions.py#L201) +**Update docs:** +- Update the [*Stream + writers*](https://github.com/element-hq/synapse/blob/develop/docs/workers.md#stream-writers) + section in the worker docs with a new section for the stream +- If this stream can only be handled by specific workers, add a new section to the + [upgrade notes](https://github.com/element-hq/synapse/blob/develop/docs/upgrade.md). 
+ **For streams to be used in sync:** - add a new field to [`StreamToken`](https://github.com/element-hq/synapse/blob/4367fb2d078c52959aeca0fe6874539c53e8360d/synapse/types/__init__.py#L1003) - add a new [`StreamKeyType`](https://github.com/element-hq/synapse/blob/4367fb2d078c52959aeca0fe6874539c53e8360d/synapse/types/__init__.py#L999) From a9361c4f51ab35a7dd4955dfdb3db7c57ae2bc9b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 17 Apr 2026 16:27:41 +0100 Subject: [PATCH 6/8] Bail out if `admin_unsafely_bypass_quarantine` was used by a non-admin (#19639) --- changelog.d/19639.bugfix | 1 + synapse/rest/client/media.py | 1 + tests/rest/admin/test_admin.py | 45 +++++++++++++++++++++++++++++++++- 3 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 changelog.d/19639.bugfix diff --git a/changelog.d/19639.bugfix b/changelog.d/19639.bugfix new file mode 100644 index 0000000000..2d2928c1ee --- /dev/null +++ b/changelog.d/19639.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.145 where a non-admin could bypass admin checks for downloading remote quarantined media. This relied on the media already being previously present on the homeserver. 
\ No newline at end of file diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 4db3b01576..15f58acb95 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -253,6 +253,7 @@ class DownloadResource(RestServlet): ), send_cors=True, ) + return set_cors_headers(request) set_corp_headers(request) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 77d824dcd8..b77a72ec4a 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -18,12 +18,15 @@ # [This file includes modifications made by New Vector Limited] # # +from __future__ import annotations import urllib.parse -from typing import cast +from typing import Any, cast +from unittest.mock import Mock from parameterized import parameterized +from twisted.internet.defer import Deferred from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource @@ -70,6 +73,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): resources["/_matrix/media"] = self.hs.get_media_repository_resource() return resources + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.fetches: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + # A remote fetch of media that was not intentional. + # Used to check that remote media fetches do NOT happen. 
+ def unexpected_remote_fetch(*args: Any, **kwargs: Any) -> Deferred[Any]: + self.fetches.append((args, kwargs)) + return Deferred() + + client = Mock() + client.federation_get_file = unexpected_remote_fetch + client.get_file = unexpected_remote_fetch + + return self.setup_test_homeserver( + clock=clock, + federation_http_client=client, + ) + def _ensure_quarantined( self, user_tok: str, @@ -176,6 +197,28 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ), ) + def test_non_admin_bypass_does_not_fetch_remote_media(self) -> None: + self.register_user("nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("nonadmin", "pass") + + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/example.com/remote_media" + "?admin_unsafely_bypass_quarantine=true", + shorthand=False, + access_token=non_admin_user_tok, + await_result=False, + ) + self.pump() + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + channel.json_body["error"], + "Must be a server admin to bypass quarantine", + ) + # Check that a remote fetch attempt did not occur. + self.assertEqual(self.fetches, []) + @parameterized.expand( [ # Attempt quarantine media APIs as non-admin From 3cdae2e27828daf9508144fe76423e45cb796853 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 21 Apr 2026 11:39:39 +0100 Subject: [PATCH 7/8] Fix race in new pruning of device lists tables. (#19709) Follows on from #19473. We should be recording where we have deleted up to in the same transaction as we perform the delete, rather than at the end. This code only starts deleting rows after a month (and the original PR isn't in a release yet), so no server should have run into this problem yet. Also let's log more regularly, as the initial set of deletions will likely take a long time. 
--- changelog.d/19709.misc | 1 + synapse/storage/databases/main/devices.py | 44 ++++++++++++++--------- 2 files changed, 28 insertions(+), 17 deletions(-) create mode 100644 changelog.d/19709.misc diff --git a/changelog.d/19709.misc b/changelog.d/19709.misc new file mode 100644 index 0000000000..596d8a6b26 --- /dev/null +++ b/changelog.d/19709.misc @@ -0,0 +1 @@ +Reduce database disk space usage by pruning old rows from `device_lists_changes_in_room`. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 339fb8a6f7..8670d68f38 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -2583,36 +2583,46 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): num_deleted += 1 min_stream_id = max(min_stream_id, row[0]) + if num_deleted: + # Update the max pruned stream ID tracking table so that the + # safety check knows data up to this point has been deleted. + self.db_pool.simple_update_one_txn( + txn, + table="device_lists_changes_in_room_max_pruned_stream_id", + keyvalues={}, + updatevalues={"stream_id": min_stream_id}, + ) + return num_deleted - num_rows_deleted = 0 + progress_num_rows_deleted = 0 while True: batch_deleted = await self.db_pool.runInteraction( "prune_device_lists_changes_in_room", prune_device_lists_changes_in_room_txn, ) - num_rows_deleted += batch_deleted - if batch_deleted < PRUNE_DEVICE_LISTS_BATCH_SIZE: + + finished = batch_deleted < PRUNE_DEVICE_LISTS_BATCH_SIZE + + progress_num_rows_deleted += batch_deleted + + # Periodically report progress in the logs. We do this either when + # we've deleted a significant number of rows or when we've finished + # deleting all rows in this round. 
+ if finished or progress_num_rows_deleted > 10000: + logger.info( + "Pruned %d rows from device_lists_changes_in_room", + progress_num_rows_deleted, + ) + progress_num_rows_deleted = 0 + + if finished: break # Sleep for a short time to avoid hammering the database too much if # there are a lot of rows to delete. await self.clock.sleep(Duration(milliseconds=100)) - if num_rows_deleted: - # Update the max pruned stream ID tracking table so that the - # safety check knows data up to this point has been deleted. - await self.db_pool.simple_update_one( - table="device_lists_changes_in_room_max_pruned_stream_id", - keyvalues={}, - updatevalues={"stream_id": prune_before_stream_id}, - desc="prune_device_lists_changes_in_room_update_max_pruned", - ) - - logger.info( - "Pruned %d rows from device_lists_changes_in_room", num_rows_deleted - ) - class DeviceBackgroundUpdateStore(SQLBaseStore): _instance_name: str From c8ce96f504a671bd97d0aeb9225bbf967bd900d6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Apr 2026 11:43:59 +0100 Subject: [PATCH 8/8] Reinstate removed EventBase methods (#19712) Both `__getitem__` and `.user_id` were removed in #19680 to simplify the event class. However, `EventBase` is exposed to modules who might also make use of those methods, so let's reinstate them (but otherwise not reinstate the usage of them in the code). --- changelog.d/19712.misc | 1 + synapse/events/__init__.py | 9 +++++++++ tests/module_api/test_api.py | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+) create mode 100644 changelog.d/19712.misc diff --git a/changelog.d/19712.misc b/changelog.d/19712.misc new file mode 100644 index 0000000000..c8fa79bf47 --- /dev/null +++ b/changelog.d/19712.misc @@ -0,0 +1 @@ +Small simplifications to the events class. 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f4a5624d1a..3c46d02e92 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -34,6 +34,7 @@ from typing import ( ) import attr +from typing_extensions import deprecated from unpaddedbase64 import encode_base64 from synapse.api.constants import ( @@ -219,6 +220,9 @@ class EventBase(metaclass=abc.ABCMeta): state_key: DictProperty[str] = DictProperty("state_key") type: DictProperty[str] = DictProperty("type") + # This is a deprecated property, use `sender` instead. Only used by modules. + user_id: DictProperty[str] = DictProperty("sender") + @property def event_id(self) -> str: raise NotImplementedError() @@ -360,6 +364,11 @@ class EventBase(metaclass=abc.ABCMeta): ">" ) + # Using `__getitem__` is deprecated. Only used by modules. + @deprecated("Use attribute access instead") + def __getitem__(self, field: str) -> Any | None: + return self._dict[field] + class FrozenEvent(EventBase): format_version = EventFormatVersions.ROOM_V1_V2 # All events of this type are V1 diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 12c8942bc8..f1b20a12ec 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -828,6 +828,24 @@ class ModuleApiTestCase(BaseModuleApiTestCase): # Ensure the pushers were deleted after the callback. self.assertEqual(len(self.hs.get_pusherpool().pushers[user_id].values()), 0) + def test_event_deprecated_methods(self) -> None: + """Test that deprecated methods on events are still functional.""" + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + room_id = self.helper.create_room_as(tok=tok) + + state = self.get_success( + self.hs.get_storage_controllers().state.get_current_state(room_id) + ) + create_event = state[(EventTypes.Create, "")] + + # `.user_id` is a deprecated alias for `.sender`. 
+ self.assertEqual(create_event.user_id, user_id) + + # The event supports looking up keys via `__getitem__` although deprecated + self.assertEqual(create_event["room_id"], room_id) + class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase): """For testing ModuleApi functionality in a multi-worker setup"""