diff --git a/CHANGES.md b/CHANGES.md
index 3425bcca8c..d9b3f8b2c1 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,11 @@
+# Synapse 1.152.1 (2026-05-07)
+
+## Security Fixes
+
+- Prevent CPU starvation (Denial of Service) under worker lock contention, additionally capping the `WorkerLock` timeout interval to a maximum of 60 seconds. Contributed by Famedly. ([\#19394](https://github.com/element-hq/synapse/issues/19394), ELEMENTSEC-2026-1706, [GHSA-8q93-326v-3m7g](https://github.com/element-hq/synapse/security/advisories/GHSA-8q93-326v-3m7g), CVE pending)
+- Prevent pagination ending when a page is full of rejected events. (ELEMENTSEC-2025-1636, [GHSA-6qf2-7x63-mm6v](https://github.com/element-hq/synapse/security/advisories/GHSA-6qf2-7x63-mm6v), CVE pending)
+
+
 # Synapse 1.152.0 (2026-04-28)
 
 No significant changes since 1.152.0rc1.
diff --git a/changelog.d/19394.bugfix b/changelog.d/19394.bugfix
new file mode 100644
index 0000000000..4ca92cfb32
--- /dev/null
+++ b/changelog.d/19394.bugfix
@@ -0,0 +1 @@
+Capped the `WorkerLock` timeout interval to a maximum of 60 seconds, to prevent the retry back-off from growing into excessively large values. Contributed by Famedly.
diff --git a/changelog.d/19706.misc b/changelog.d/19706.misc
new file mode 100644
index 0000000000..205abd09d4
--- /dev/null
+++ b/changelog.d/19706.misc
@@ -0,0 +1 @@
+Port `Event.signatures` field to Rust.
diff --git a/changelog.d/19708.misc b/changelog.d/19708.misc
new file mode 100644
index 0000000000..308c2b04d0
--- /dev/null
+++ b/changelog.d/19708.misc
@@ -0,0 +1 @@
+Port `Event.unsigned` field to Rust.
diff --git a/changelog.d/19755.misc b/changelog.d/19755.misc
new file mode 100644
index 0000000000..6ad478e531
--- /dev/null
+++ b/changelog.d/19755.misc
@@ -0,0 +1 @@
+Reduce `WORKER_LOCK_MAX_RETRY_INTERVAL` to 5 seconds to reduce idle time after a lock is released.
diff --git a/changelog.d/19756.misc b/changelog.d/19756.misc
new file mode 100644
index 0000000000..2450505b53
--- /dev/null
+++ b/changelog.d/19756.misc
@@ -0,0 +1 @@
+Force keyword-only arguments for `Duration` to prevent a footgun: callers must now specify which time unit they want to use.
diff --git a/debian/changelog b/debian/changelog
index ff9dfe3e13..cfefe953e3 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.152.1) stable; urgency=medium
+
+  * New Synapse release 1.152.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Thu, 07 May 2026 13:29:05 +0100
+
 matrix-synapse-py3 (1.152.0) stable; urgency=medium
 
   * New Synapse release 1.152.0.
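Background for the lock fix below: the old `_get_next_retry_interval` used `max(5, next * 2)` where a cap was intended, so the retry interval doubled without bound; the patch renames it to `_increment_timeout_interval` and clamps with `min`, keeping the jitter. A minimal standalone sketch of the corrected scheme (illustrative only, not the patched code itself; the constant name is ours, the 0.1 s start and 5 s cap mirror `WaitingLock` in `synapse/handlers/worker_lock.py` further down):

```python
import random

# Mirrors WORKER_LOCK_MAX_RETRY_INTERVAL from the worker_lock.py hunk below.
MAX_TIMEOUT_S = 5.0


def next_timeout_interval(current: float) -> float:
    """Double the timeout, cap it, then apply a +/-10% jitter factor."""
    capped = min(MAX_TIMEOUT_S, current * 2)
    return capped * random.uniform(0.9, 1.1)


# Starting from the 0.1s initial interval used by WaitingLock, the sequence
# saturates at ~5s instead of doubling without bound:
# ~0.2, 0.4, 0.8, 1.6, 3.2, 5, 5, ... seconds.
interval = 0.1
for _ in range(8):
    interval = next_timeout_interval(interval)
    print(f"{interval:.2f}s")
```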
diff --git a/pyproject.toml b/pyproject.toml
index f4c94acc1c..7ead67c8f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "matrix-synapse"
-version = "1.152.0"
+version = "1.152.1"
 description = "Homeserver for the Matrix decentralised comms protocol"
 readme = "README.rst"
 authors = [
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index e6b378a092..5bdd194707 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -43,7 +43,7 @@ pyo3-log = "0.13.1"
 pythonize = "0.27.0"
 regex = "1.6.0"
 sha2 = "0.10.8"
-serde = { version = "1.0.144", features = ["derive"] }
+serde = { version = "1.0.144", features = ["derive", "rc"] }
 serde_json = { version = "1.0.85", features = ["raw_value"] }
 ulid = "1.1.2"
 icu_segmenter = "2.0.0"
@@ -58,10 +58,6 @@ tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] }
 once_cell = "1.18.0"
 itertools = "0.14.0"
 
-[features]
-extension-module = ["pyo3/extension-module"]
-default = ["extension-module"]
-
 [build-dependencies]
 blake2 = "0.10.4"
 hex = "0.4.3"
diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs
index 209efb917b..5f505abb91 100644
--- a/rust/src/events/mod.rs
+++ b/rust/src/events/mod.rs
@@ -27,11 +27,15 @@ use pyo3::{
 
 pub mod filter;
 mod internal_metadata;
+pub mod signatures;
+pub mod unsigned;
 
 /// Called when registering modules with python.
 pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
     let child_module = PyModule::new(py, "events")?;
     child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
+    child_module.add_class::<signatures::Signatures>()?;
+    child_module.add_class::<unsigned::Unsigned>()?;
     child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;
 
     m.add_submodule(&child_module)?;
diff --git a/rust/src/events/signatures.rs b/rust/src/events/signatures.rs
new file mode 100644
index 0000000000..0f2acd5c9b
--- /dev/null
+++ b/rust/src/events/signatures.rs
@@ -0,0 +1,348 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2026 Element Creations Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ *
+ */
+
+//! Class for representing event signatures
+
+use std::{
+    collections::HashMap,
+    sync::{Arc, RwLock},
+};
+
+use pyo3::{
+    exceptions::{PyKeyError, PyRuntimeError},
+    pyclass, pymethods,
+    types::{PyAnyMethods, PyDict, PyMapping, PyMappingMethods},
+    Bound, IntoPyObject, PyAny, PyResult, Python,
+};
+use serde::{Deserialize, Serialize};
+
+/// A class representing the signatures on an event.
+#[pyclass(frozen, skip_from_py_object)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(transparent)]
+pub struct Signatures {
+    inner: Arc<RwLock<HashMap<String, HashMap<String, String>>>>,
+}
+
+#[pymethods]
+impl Signatures {
+    #[new]
+    #[pyo3(signature = (signatures = None))]
+    fn py_new(signatures: Option<HashMap<String, HashMap<String, String>>>) -> Self {
+        let mut signatures = signatures.unwrap_or_default();
+
+        // Prune any entries that have no signatures.
+        signatures.retain(|_, server_sigs| !server_sigs.is_empty());
+
+        Self {
+            inner: Arc::new(RwLock::new(signatures)),
+        }
+    }
+
+    /// Check if the signatures contain a signature for the given server name.
+    fn __contains__(&self, key: Bound<'_, PyAny>) -> PyResult<bool> {
+        let Ok(key) = key.extract::<&str>() else {
+            return Ok(false);
+        };
+
+        let signatures = self
+            .inner
+            .read()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+        Ok(signatures.contains_key(key))
+    }
+
+    /// Get the number of servers that have signatures.
+    fn __len__(&self) -> PyResult<usize> {
+        let signatures = self
+            .inner
+            .read()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+        Ok(signatures.len())
+    }
+
+    /// Get the signature for the given server name and key ID, if it exists.
+    fn get_signature(&self, server_name: &str, key_id: &str) -> PyResult<Option<String>> {
+        let signatures = self
+            .inner
+            .read()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+
+        Ok(signatures
+            .get(server_name)
+            .and_then(|server_sigs| server_sigs.get(key_id).cloned()))
+    }
+
+    /// Get the signatures for the given server name.
+    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult<HashMap<String, String>> {
+        let Some(server_name) = key.extract::<&str>().ok() else {
+            return Err(PyKeyError::new_err(key.to_string()));
+        };
+
+        let signatures = self
+            .inner
+            .read()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+
+        if let Some(server_sigs) = signatures.get(server_name) {
+            Ok(server_sigs.clone())
+        } else {
+            Err(PyKeyError::new_err(server_name.to_string()))
+        }
+    }
+
+    /// Add a signature for the given server name and key ID.
+    fn add_signature(
+        &self,
+        server_name: String,
+        key_id: String,
+        signature: String,
+    ) -> PyResult<()> {
+        let mut signatures = self
+            .inner
+            .write()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+
+        signatures
+            .entry(server_name)
+            .or_default()
+            .insert(key_id, signature);
+
+        Ok(())
+    }
+
+    /// Update the signatures with the given signatures.
+    ///
+    /// Will overwrite all existing signatures for the server names provided.
+    fn update(&self, other: &Bound<'_, PyMapping>) -> PyResult<()> {
+        let mut signatures = self
+            .inner
+            .write()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+
+        for list_entry in other.items()? {
+            let (server_name, server_sigs) = list_entry.extract::<(String, Bound<'_, PyMapping>)>()?;
+
+            let mut entry = HashMap::new();
+            for list_entry in server_sigs.items()? {
+                let (key, value) = list_entry.extract::<(String, String)>()?;
+                entry.insert(key, value);
+            }
+
+            // Only insert the entry if it has at least one signature.
+            if !entry.is_empty() {
+                signatures.insert(server_name, entry);
+            } else {
+                signatures.remove(&server_name);
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Return a copy of the signatures as a dictionary.
+    fn as_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
+        let signatures = self
+            .inner
+            .read()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+
+        (&*signatures).into_pyobject(py)
+    }
+
+    fn __repr__(&self) -> PyResult<String> {
+        let signatures = self
+            .inner
+            .read()
+            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
+
+        Ok(format!("Signatures({signatures:?})"))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pythonize::pythonize;
+
+    use super::*;
+
+    /// Helper that reads the inner map directly.
+    fn read_inner(sigs: &Signatures) -> HashMap<String, HashMap<String, String>> {
+        sigs.inner.read().expect("lock poisoned").clone()
+    }
+
+    /// Helper to create a server signatures map from a list of (key_id, sig)
+    /// pairs.
+    fn make_server_sigs(data: &[(&str, &str)]) -> HashMap<String, String> {
+        let mut map = HashMap::new();
+        for (key_id, sig) in data {
+            map.insert((*key_id).to_owned(), (*sig).to_owned());
+        }
+        map
+    }
+
+    /// Helper to create a `Signatures` object from a list of (server_name,
+    /// key_id, sig) tuples.
+    fn create_signatures(data: &[(&str, &str, &str)]) -> Signatures {
+        let mut map: HashMap<String, HashMap<String, String>> = HashMap::new();
+        for (server_name, key_id, sig) in data {
+            map.entry((*server_name).to_owned())
+                .or_default()
+                .insert((*key_id).to_owned(), (*sig).to_owned());
+        }
+        Signatures::py_new(Some(map))
+    }
+
+    #[test]
+    fn test_new_empty() {
+        let sigs = Signatures::py_new(None);
+        assert!(read_inner(&sigs).is_empty());
+        assert_eq!(sigs.__len__().unwrap(), 0);
+    }
+
+    #[test]
+    fn test_new_with_data() {
+        let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]);
+        assert_eq!(sigs.__len__().unwrap(), 1);
+        assert_eq!(
+            sigs.get_signature("example.com", "ed25519:key1").unwrap(),
+            Some("sig1".to_string())
+        );
+    }
+
+    #[test]
+    fn test_new_prunes_servers_with_no_signatures() {
+        let mut data = HashMap::new();
+        data.insert("empty.example.com".to_string(), HashMap::new());
+        data.insert(
+            "example.com".to_string(),
+            make_server_sigs(&[("ed25519:key1", "sig1")]),
+        );
+
+        let sigs = Signatures::py_new(Some(data));
+
+        let inner = read_inner(&sigs);
+        assert_eq!(inner.len(), 1);
+        assert!(inner.contains_key("example.com"));
+        assert!(!inner.contains_key("empty.example.com"));
+    }
+
+    #[test]
+    fn test_add_signature() {
+        let sigs = Signatures::py_new(None);
+        sigs.add_signature(
+            "example.com".to_string(),
+            "ed25519:key1".to_string(),
+            "sig1".to_string(),
+        )
+        .unwrap();
+
+        let inner = read_inner(&sigs);
+        assert_eq!(inner.len(), 1);
+        assert_eq!(
+            inner.get("example.com").and_then(|m| m.get("ed25519:key1")),
+            Some(&"sig1".to_string())
+        );
+    }
+
+    #[test]
+    fn test_add_signature_to_existing_server() {
+        let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]);
+        sigs.add_signature(
+            "example.com".to_string(),
+            "ed25519:key2".to_string(),
+            "sig2".to_string(),
+        )
+        .unwrap();
+
+        let inner = read_inner(&sigs);
+        assert_eq!(inner.len(), 1);
+        assert_eq!(
+            inner.get("example.com").and_then(|m| m.get("ed25519:key1")),
+            Some(&"sig1".to_string())
+        );
+        assert_eq!(
+            inner.get("example.com").and_then(|m| m.get("ed25519:key2")),
+            Some(&"sig2".to_string())
+        );
+    }
+
+    #[test]
+    fn test_update_signatures_clobbers_existing() {
+        let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]);
+
+        // Create a new signatures map with a different signature for the same
+        // server.
+        let mut other = HashMap::new();
+        other.insert(
+            "example.com".to_string(),
+            make_server_sigs(&[("ed25519:key2", "sig2")]),
+        );
+
+        // Update the signatures with the new map.
+        Python::initialize();
+        Python::attach(|py| {
+            let value = pythonize(py, &other).unwrap();
+            let value = value.cast::<PyMapping>().unwrap();
+
+            sigs.update(value).unwrap();
+        });
+
+        // Check that the old signature has been replaced with the new one.
+        let inner = read_inner(&sigs);
+        assert_eq!(inner.len(), 1);
+        assert_eq!(inner["example.com"].len(), 1);
+        assert_eq!(inner["example.com"]["ed25519:key2"], "sig2");
+    }
+
+    #[test]
+    fn test_serialize() {
+        let mut data = HashMap::new();
+        data.insert(
+            "example.com".to_string(),
+            make_server_sigs(&[("ed25519:key1", "sig1")]),
+        );
+        let sigs = Signatures::py_new(Some(data));
+
+        let json = serde_json::to_string(&sigs).unwrap();
+        assert_eq!(json, r#"{"example.com":{"ed25519:key1":"sig1"}}"#);
+    }
+
+    #[test]
+    fn test_serialize_empty() {
+        let sigs = Signatures::py_new(None);
+        let json = serde_json::to_string(&sigs).unwrap();
+        assert_eq!(json, "{}");
+    }
+
+    #[test]
+    fn test_deserialize() {
+        let json = r#"{"example.com":{"ed25519:key1":"sig1"}}"#;
+        let sigs: Signatures = serde_json::from_str(json).unwrap();
+
+        let inner = read_inner(&sigs);
+        assert_eq!(inner.len(), 1);
+        assert_eq!(
+            inner.get("example.com").and_then(|m| m.get("ed25519:key1")),
+            Some(&"sig1".to_string())
+        );
+    }
+
+    #[test]
+    fn test_deserialize_empty() {
+        let sigs: Signatures = serde_json::from_str("{}").unwrap();
+        assert!(read_inner(&sigs).is_empty());
+    }
+}
diff --git a/rust/src/events/unsigned.rs b/rust/src/events/unsigned.rs
new file mode 100644
index 0000000000..c41ed7e6e1
--- /dev/null
+++ b/rust/src/events/unsigned.rs
@@ -0,0 +1,429 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2026 Element Creations Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ *
+ */
+
+use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
+
+use pyo3::{
+    exceptions::{PyKeyError, PyRuntimeError, PyTypeError},
+    pyclass, pymethods,
+    types::{PyAnyMethods, PyList, PyListMethods, PyMapping},
+    Bound, IntoPyObjectExt, PyAny, PyResult, Python,
+};
+use pythonize::{depythonize, pythonize};
+use serde::{Deserialize, Serialize};
+
+#[pyclass(frozen, skip_from_py_object)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+#[serde(transparent)]
+pub struct Unsigned {
+    inner: Arc<RwLock<UnsignedInner>>,
+}
+
+/// The fields in the unsigned data of an event that are persisted in the
+/// database.
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+struct PersistedUnsignedFields {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    age_ts: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    replaces_state: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    invite_room_state: Option<Vec<serde_json::Value>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    knock_room_state: Option<Vec<serde_json::Value>>,
+}
+
+/// The inner representation of the unsigned data of an event, which includes
+/// both the fields that are persisted in the database and the fields that are
+/// only used in memory.
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+pub struct UnsignedInner {
+    #[serde(flatten)]
+    persisted_fields: PersistedUnsignedFields,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    prev_content: Option<Box<serde_json::Value>>, // We use Box to minimise stack space
+    #[serde(skip_serializing_if = "Option::is_none")]
+    prev_sender: Option<String>,
+}
+
+/// The fields that exist on the unsigned data of an event.
+///
+/// This is used when converting from python to rust, to ensure that if we add a
+/// new field we don't forget to add it to all the necessary places.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum UnsignedField {
+    AgeTs,
+    ReplacesState,
+    InviteRoomState,
+    KnockRoomState,
+    PrevContent,
+    PrevSender,
+}
+
+impl std::str::FromStr for UnsignedField {
+    type Err = ();
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "age_ts" => Ok(Self::AgeTs),
+            "replaces_state" => Ok(Self::ReplacesState),
+            "invite_room_state" => Ok(Self::InviteRoomState),
+            "knock_room_state" => Ok(Self::KnockRoomState),
+            "prev_content" => Ok(Self::PrevContent),
+            "prev_sender" => Ok(Self::PrevSender),
+            _ => Err(()),
+        }
+    }
+}
+
+impl Unsigned {
+    fn py_read(&self) -> PyResult<RwLockReadGuard<'_, UnsignedInner>> {
+        self.inner
+            .read()
+            .map_err(|_| PyRuntimeError::new_err("Unsigned lock poisoned"))
+    }
+
+    fn py_write(&self) -> PyResult<RwLockWriteGuard<'_, UnsignedInner>> {
+        self.inner
+            .write()
+            .map_err(|_| PyRuntimeError::new_err("Unsigned lock poisoned"))
+    }
+}
+
+#[pymethods]
+impl Unsigned {
+    #[new]
+    fn py_new(unsigned: Bound<'_, PyMapping>) -> PyResult<Self> {
+        let inner = depythonize(&unsigned)?;
+
+        Ok(Self {
+            inner: Arc::new(RwLock::new(inner)),
+        })
+    }
+
+    fn __getitem__<'py>(
+        &self,
+        py: Python<'py>,
+        key: Bound<'_, PyAny>,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        let key = key
+            .extract::<&str>()
+            .map_err(|_| PyTypeError::new_err("Unsigned keys must be strings"))?;
+
+        let field: UnsignedField = key
+            .parse()
+            .map_err(|_| PyKeyError::new_err(format!("Unsigned has no key '{key}'")))?;
+
+        let unsigned = self.py_read()?;
+
+        match field {
+            UnsignedField::AgeTs => Ok(unsigned
+                .persisted_fields
+                .age_ts
+                .ok_or_else(|| PyKeyError::new_err("age_ts"))?
+                .into_bound_py_any(py)?),
+            UnsignedField::ReplacesState => Ok((unsigned.persisted_fields.replaces_state)
+                .as_ref()
+                .ok_or_else(|| PyKeyError::new_err("replaces_state"))?
+                .into_bound_py_any(py)?),
+            UnsignedField::InviteRoomState => Ok(room_state_to_py(
+                py,
+                unsigned
+                    .persisted_fields
+                    .invite_room_state
+                    .as_ref()
+                    .ok_or_else(|| PyKeyError::new_err("invite_room_state"))?,
+            )?),
+            UnsignedField::KnockRoomState => Ok(room_state_to_py(
+                py,
+                unsigned
+                    .persisted_fields
+                    .knock_room_state
+                    .as_ref()
+                    .ok_or_else(|| PyKeyError::new_err("knock_room_state"))?,
+            )?),
+            UnsignedField::PrevContent => Ok(pythonize(
+                py,
+                unsigned
+                    .prev_content
+                    .as_ref()
+                    .ok_or_else(|| PyKeyError::new_err("prev_content"))?,
+            )?),
+            UnsignedField::PrevSender => Ok((unsigned.prev_sender)
+                .as_ref()
+                .ok_or_else(|| PyKeyError::new_err("prev_sender"))?
+                .into_bound_py_any(py)?),
+        }
+    }
+
+    fn __contains__(&self, key: Bound<'_, PyAny>) -> PyResult<bool> {
+        let Ok(key) = key.extract::<&str>() else {
+            return Ok(false);
+        };
+
+        let Ok(field) = key.parse::<UnsignedField>() else {
+            return Ok(false);
+        };
+
+        let unsigned = self.py_read()?;
+
+        let exists = match field {
+            UnsignedField::AgeTs => unsigned.persisted_fields.age_ts.is_some(),
+            UnsignedField::ReplacesState => unsigned.persisted_fields.replaces_state.is_some(),
+            UnsignedField::InviteRoomState => unsigned.persisted_fields.invite_room_state.is_some(),
+            UnsignedField::KnockRoomState => unsigned.persisted_fields.knock_room_state.is_some(),
+            UnsignedField::PrevContent => unsigned.prev_content.is_some(),
+            UnsignedField::PrevSender => unsigned.prev_sender.is_some(),
+        };
+
+        Ok(exists)
+    }
+
+    fn __setitem__(&self, key: Bound<'_, PyAny>, value: Bound<'_, PyAny>) -> PyResult<()> {
+        let key = key
+            .extract::<&str>()
+            .map_err(|_| PyTypeError::new_err("Unsigned keys must be strings"))?;
+
+        let field: UnsignedField = key
+            .parse()
+            .map_err(|_| PyKeyError::new_err(format!("Unsigned has no key '{key}'")))?;
+
+        let mut unsigned = self.py_write()?;
+
+        match field {
+            UnsignedField::AgeTs => unsigned.persisted_fields.age_ts = Some(value.extract()?),
+            UnsignedField::ReplacesState => {
+                unsigned.persisted_fields.replaces_state = Some(value.extract()?)
+            }
+            UnsignedField::InviteRoomState => {
+                unsigned.persisted_fields.invite_room_state = Some(room_state_from_py(value)?)
+            }
+            UnsignedField::KnockRoomState => {
+                unsigned.persisted_fields.knock_room_state = Some(room_state_from_py(value)?)
+            }
+            UnsignedField::PrevContent => {
+                unsigned.prev_content = Some(Box::new(depythonize(&value)?))
+            }
+            UnsignedField::PrevSender => unsigned.prev_sender = Some(value.extract()?),
+        }
+
+        Ok(())
+    }
+
+    fn __delitem__(&self, key: Bound<'_, PyAny>) -> PyResult<()> {
+        let key = key
+            .extract::<&str>()
+            .map_err(|_| PyTypeError::new_err("Unsigned keys must be strings"))?;
+
+        let field: UnsignedField = key
+            .parse()
+            .map_err(|_| PyKeyError::new_err(format!("Unsigned has no key '{key}'")))?;
+
+        let mut unsigned = self.py_write()?;
+
+        match field {
+            UnsignedField::AgeTs => unsigned.persisted_fields.age_ts = None,
+            UnsignedField::ReplacesState => unsigned.persisted_fields.replaces_state = None,
+            UnsignedField::InviteRoomState => unsigned.persisted_fields.invite_room_state = None,
+            UnsignedField::KnockRoomState => unsigned.persisted_fields.knock_room_state = None,
+            UnsignedField::PrevContent => unsigned.prev_content = None,
+            UnsignedField::PrevSender => unsigned.prev_sender = None,
+        }
+
+        Ok(())
+    }
+
+    #[pyo3(signature = (key, default=None))]
+    fn get<'py>(
+        &self,
+        py: Python<'py>,
+        key: Bound<'py, PyAny>,
+        default: Option<Bound<'py, PyAny>>,
+    ) -> PyResult<Option<Bound<'py, PyAny>>> {
+        match self.__getitem__(py, key) {
+            Ok(value) => Ok(Some(value)),
+            Err(err) => {
+                if err.is_instance_of::<PyKeyError>(py) {
+                    Ok(default)
+                } else {
+                    Err(err)
+                }
+            }
+        }
+    }
+
+    fn for_persistence<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        Ok(pythonize(py, &self.py_read()?.persisted_fields)?)
+    }
+
+    fn for_event<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        Ok(pythonize(py, &*self.py_read()?)?)
+    }
+}
+
+fn room_state_to_py<'py>(
+    py: Python<'py>,
+    state: &[serde_json::Value],
+) -> PyResult<Bound<'py, PyAny>> {
+    let py_list = PyList::empty(py);
+
+    for item in state {
+        py_list.append(pythonize(py, item)?)?;
+    }
+
+    py_list.into_bound_py_any(py)
+}
+
+fn room_state_from_py(value: Bound<'_, PyAny>) -> PyResult<Vec<serde_json::Value>> {
+    let py_list = value.cast::<PyList>()?;
+
+    let mut state = Vec::with_capacity(py_list.len());
+    for item in py_list.iter() {
+        state.push(pythonize::depythonize(&item)?);
+    }
+
+    Ok(state)
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn test_unsigned_field_from_str_valid() {
+        assert_eq!("age_ts".parse(), Ok(UnsignedField::AgeTs));
+        assert_eq!("replaces_state".parse(), Ok(UnsignedField::ReplacesState));
+        assert_eq!(
+            "invite_room_state".parse(),
+            Ok(UnsignedField::InviteRoomState)
+        );
+        assert_eq!(
+            "knock_room_state".parse(),
+            Ok(UnsignedField::KnockRoomState)
+        );
+        assert_eq!("prev_content".parse(), Ok(UnsignedField::PrevContent));
+        assert_eq!("prev_sender".parse(), Ok(UnsignedField::PrevSender));
+    }
+
+    #[test]
+    fn test_unsigned_field_from_str_invalid() {
+        assert_eq!("".parse::<UnsignedField>(), Err(()));
+        assert_eq!("unknown".parse::<UnsignedField>(), Err(()));
+        // Case-sensitive: upper-case should not match.
+        assert_eq!("AGE_TS".parse::<UnsignedField>(), Err(()));
+        // Must be an exact match, no whitespace.
+        assert_eq!(" age_ts".parse::<UnsignedField>(), Err(()));
+    }
+
+    #[test]
+    fn test_persisted_fields_serialize_empty_is_empty_object() {
+        let fields = PersistedUnsignedFields::default();
+        let json = serde_json::to_value(&fields).unwrap();
+        assert_eq!(json, json!({}));
+    }
+
+    #[test]
+    fn test_persisted_fields_serialize_populated() {
+        let fields = PersistedUnsignedFields {
+            age_ts: Some(1234),
+            replaces_state: Some("$prev:example.com".to_string()),
+            invite_room_state: Some(vec![json!({"type": "m.room.name"})]),
+            knock_room_state: Some(vec![json!({"type": "m.room.topic"})]),
+        };
+        let json = serde_json::to_value(&fields).unwrap();
+        assert_eq!(
+            json,
+            json!({
+                "age_ts": 1234,
+                "replaces_state": "$prev:example.com",
+                "invite_room_state": [{"type": "m.room.name"}],
+                "knock_room_state": [{"type": "m.room.topic"}],
+            })
+        );
+    }
+
+    #[test]
+    fn test_unsigned_inner_flattens_persisted_fields() {
+        let inner = UnsignedInner {
+            persisted_fields: PersistedUnsignedFields {
+                age_ts: Some(99),
+                ..Default::default()
+            },
+            prev_content: Some(Box::new(json!({"body": "hi"}))),
+            prev_sender: Some("@alice:example.com".to_string()),
+        };
+
+        let json = serde_json::to_value(&inner).unwrap();
+        assert_eq!(
+            json,
+            json!({
+                "age_ts": 99,
+                "prev_content": {"body": "hi"},
+                "prev_sender": "@alice:example.com",
+            })
+        );
+    }
+
+    #[test]
+    fn test_unsigned_inner_roundtrip() {
+        let original = UnsignedInner {
+            persisted_fields: PersistedUnsignedFields {
+                age_ts: Some(10),
+                replaces_state: Some("$state:example.com".to_string()),
+                invite_room_state: None,
+                knock_room_state: None,
+            },
+            prev_content: Some(Box::new(json!({"membership": "join"}))),
+            prev_sender: None,
+        };
+
+        let json = serde_json::to_string(&original).unwrap();
+        let roundtripped: UnsignedInner = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(roundtripped.persisted_fields.age_ts, Some(10));
+        assert_eq!(
+            roundtripped.persisted_fields.replaces_state.as_deref(),
+            Some("$state:example.com")
+        );
+        assert_eq!(
+            roundtripped.prev_content.as_deref(),
+            Some(&json!({"membership": "join"}))
+        );
+        assert_eq!(roundtripped.prev_sender, None);
+    }
+
+    #[test]
+    fn test_unsigned_serializes_transparently() {
+        //
`Unsigned` is `#[serde(transparent)]` over its inner, so serializing + // an empty default should yield an empty object rather than a wrapper. + let unsigned = Unsigned::default(); + let json = serde_json::to_value(&unsigned).unwrap(); + assert_eq!(json, json!({})); + } + + #[test] + fn test_unsigned_deserialize_from_flat_object() { + let json = json!({ + "age_ts": 5, + "prev_sender": "@bob:example.com", + }); + let unsigned: Unsigned = serde_json::from_value(json).unwrap(); + let inner = unsigned.inner.read().unwrap(); + assert_eq!(inner.persisted_fields.age_ts, Some(5)); + assert_eq!(inner.prev_sender.as_deref(), Some("@bob:example.com")); + } +} diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index d789c06a9c..823b6288e8 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -236,9 +236,7 @@ def event_needs_resigning( if sender.domain != server_name: return False want_key_id = verify_key.alg + ":" + verify_key.version - signed_with_current_key_id = ev.signatures.get(server_name, {}).get( - want_key_id, None - ) + signed_with_current_key_id = ev.signatures.get_signature(server_name, want_key_id) if signed_with_current_key_id: return False diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 0d4d5e0e17..36736b4559 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -120,8 +120,18 @@ class VerifyJsonRequest: ) -> "VerifyJsonRequest": """Create a VerifyJsonRequest to verify all signatures on an event object for the given server. + + Raises immediately if the event doesn't have any signatures from the + given server. """ - key_ids = list(event.signatures.get(server_name, [])) + if server_name not in event.signatures: + raise SynapseError( + 400, + f"Not signed by {server_name}", + Codes.UNAUTHORIZED, + ) + + key_ids = list(event.signatures[server_name]) return VerifyJsonRequest( server_name, # We defer creating the redacted json object, as it uses a lot more diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ca528ae235..fd35da8ba0 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -128,7 +128,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: ) # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): + if sender_domain not in event.signatures: # We allow invites via 3pid to have a sender from a different # HS, as the sender must match the sender of the original # 3pid invite. 
This is checked further down with the @@ -141,7 +141,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: event_id_domain = get_domain_from_id(event.event_id) # Check the origin domain has signed the event - if not event.signatures.get(event_id_domain): + if event_id_domain not in event.signatures: raise AuthError(403, "Event not signed by sending server") is_invite_via_allow_rule = ( @@ -154,7 +154,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: authoriser_domain = get_domain_from_id( event.content[EventContentFields.AUTHORISING_USER] ) - if not event.signatures.get(authoriser_domain): + if authoriser_domain not in event.signatures: raise AuthError(403, "Event not signed by authorising server") diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index ac9b31cfcf..0f850d19b1 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -44,8 +44,12 @@ from synapse.api.constants import ( StickyEvent, ) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.synapse_rust.events import EventInternalMetadata -from synapse.types import JsonDict, StateKey, StrCollection +from synapse.synapse_rust.events import EventInternalMetadata, Signatures, Unsigned +from synapse.types import ( + JsonDict, + StateKey, + StrCollection, +) from synapse.util.caches import intern_dict from synapse.util.duration import Duration from synapse.util.frozenutils import freeze @@ -203,8 +207,8 @@ class EventBase(metaclass=abc.ABCMeta): assert room_version.event_format == self.format_version self.room_version = room_version - self.signatures = signatures - self.unsigned = unsigned + self.signatures = Signatures(signatures) + self.unsigned = Unsigned(unsigned) self.rejected_reason = rejected_reason self._dict = event_dict @@ -254,8 +258,26 @@ class EventBase(metaclass=abc.ABCMeta): return self._dict.get("state_key") def get_dict(self) -> JsonDict: + """Convert the event to a dictionary suitable for serialisation.""" d = dict(self._dict) - d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) + d.update( + { + "signatures": self.signatures.as_dict(), + "unsigned": self.unsigned.for_event(), + } + ) + + return d + + def get_dict_for_persistence(self) -> JsonDict: + """Convert the event to a dictionary suitable for persistence.""" + d = dict(self._dict) + d.update( + { + "signatures": self.signatures.as_dict(), + "unsigned": self.unsigned.for_persistence(), + } + ) return d @@ -395,7 +417,7 @@ class FrozenEvent(EventBase): for name, sigs in event_dict.pop("signatures", {}).items() } - unsigned = dict(event_dict.pop("unsigned", {})) + unsigned = event_dict.pop("unsigned", {}) # We intern these strings because they turn up a lot (especially when # caching). @@ -449,7 +471,7 @@ class FrozenEventV2(EventBase): assert "event_id" not in event_dict - unsigned = dict(event_dict.pop("unsigned", {})) + unsigned = event_dict.pop("unsigned", {}) # We intern these strings because they turn up a lot (especially when # caching). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f038fb5578..926c81b83d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -47,6 +47,7 @@ from synapse.api.constants import ( from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.logging.opentracing import SynapseTags, set_tag, trace +from synapse.synapse_rust.events import Unsigned from synapse.types import JsonDict, Requester from . 
import EventBase, FrozenEventV2, StrippedStateEvent, make_event_from_dict @@ -987,7 +988,7 @@ def validate_canonicaljson(value: Any) -> None: def maybe_upsert_event_field( - event: EventBase, container: JsonDict, key: str, value: object + event: EventBase, container: Unsigned, key: str, value: object ) -> bool: """Upsert an event field, but only if this doesn't make the event too large. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fe0710a0bf..1631f021ca 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -307,20 +307,27 @@ def _is_invite_via_3pid(event: EventBase) -> bool: def parse_events_from_pdu_json( - pdus_json: Sequence[JsonDict], room_version: RoomVersion + pdus_json: Sequence[JsonDict], + room_version: RoomVersion, + received_time: int | None = None, ) -> list[EventBase]: return [ - event_from_pdu_json(pdu_json, room_version) + event_from_pdu_json(pdu_json, room_version, received_time=received_time) for pdu_json in filter_pdus_for_valid_depth(pdus_json) ] -def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase: +def event_from_pdu_json( + pdu_json: JsonDict, room_version: RoomVersion, received_time: int | None = None +) -> EventBase: """Construct an EventBase from an event json received over federation Args: pdu_json: pdu as received over federation room_version: The version of the room this event belongs to + received_time: timestamp in ms that the event was received at. If + `None` then any `age` field in the `unsigned` block will be + dropped. Raises: SynapseError: if the pdu is missing required fields or is otherwise @@ -333,6 +340,25 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB if "unsigned" in pdu_json: _strip_unsigned_values(pdu_json) + # Handle the `age` field, which is sent by some servers as part of the + # `unsigned` block. We convert this into an `age_ts` field, which is + # what Synapse uses internally. We also remove the `age` field to avoid + # confusion. + # + # c.f. https://github.com/matrix-org/synapse/issues/8429 + unsigned = pdu_json["unsigned"] + age = unsigned.pop("age", None) + + # We check that the `age` is actually an int before using it below. We + # don't error here as the `age` a) doesn't affect the validity of the + # event, and b) is best effort anyway. 
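+    # Worked example (values borrowed from the updated federation tests
+    # below, purely illustrative): a PDU carrying unsigned["age"] = 14 that
+    # we received at received_time = 20 ends up with
+    # unsigned["age_ts"] = 20 - 14 = 6, i.e. the origin server's send
+    # timestamp expressed on our own clock.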
+ if not isinstance(age, int): + age = None + + unsigned.pop("age_ts", None) + if received_time is not None and age is not None: + unsigned["age_ts"] = received_time - int(age) + depth = pdu_json["depth"] if type(depth) is not int: # noqa: E721 raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 78a1900c73..2b5ef5fbac 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1574,10 +1574,15 @@ class FederationClient(FederationBase): min_depth=min_depth, timeout=timeout, ) + received_time = self._clock.time_msec() room_version = await self.store.get_room_version(room_id) - events = parse_events_from_pdu_json(content.get("events", []), room_version) + events = parse_events_from_pdu_json( + content.get("events", []), + room_version, + received_time=received_time, + ) signed_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch( destination, events, room_version=room_version diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 1bbe144422..6069287975 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -451,16 +451,6 @@ class FederationServer(FederationBase): newest_pdu_ts = 0 for p in transaction.pdus: - # FIXME (richardv): I don't think this works: - # https://github.com/matrix-org/synapse/issues/8429 - if "unsigned" in p: - unsigned = p["unsigned"] - if "age" in unsigned: - p["age"] = unsigned["age"] - if "age" in p: - p["age_ts"] = request_time - int(p["age"]) - del p["age"] - # We try and pull out an event ID so that if later checks fail we # can log something sensible. We don't mandate an event ID here in # case future event formats get rid of the key. @@ -488,10 +478,15 @@ class FederationServer(FederationBase): continue try: - event = event_from_pdu_json(p, room_version) + event = event_from_pdu_json(p, room_version, received_time=request_time) except SynapseError as e: logger.info("Ignoring PDU for failing to deserialize: %s", e) continue + except Exception as e: + # We catch all exceptions here as we don't want a single bad + # event to cause us to fail the whole transaction. + logger.exception("Error deserializing PDU: %s", e) + continue pdus_by_room.setdefault(room_id, []).append(event) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4032c7eca9..0687c9fa79 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -2089,10 +2089,9 @@ class EventCreationHandler: returned_invite = await federation_handler.send_invite( invitee.domain, event ) - event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. 
-            event.signatures.update(returned_invite.signatures)
+            event.signatures.update(returned_invite.signatures.as_dict())
 
         if event.content["membership"] == Membership.KNOCK:
             maybe_upsert_event_field(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8cbe4b63c8..2bc7efeb5e 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -566,7 +566,7 @@ class PaginationHandler:
             (
                 events,
                 next_key,
-                _,
+                limited,
             ) = await self.store.paginate_room_events_by_topological_ordering(
                 room_id=room_id,
                 from_key=from_token.room_key,
@@ -645,7 +645,7 @@ class PaginationHandler:
             (
                 events,
                 next_key,
-                _,
+                limited,
             ) = await self.store.paginate_room_events_by_topological_ordering(
                 room_id=room_id,
                 from_key=from_token.room_key,
@@ -668,11 +668,12 @@ class PaginationHandler:
 
         next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
 
-        # if no events are returned from pagination, that implies
-        # we have reached the end of the available events.
+        # if no events are returned from pagination (this page is empty)
+        # and there aren't any more pages (not limited),
+        # that implies we have reached the end of the available events.
         # In that case we do not return end, to tell the client
         # there is no need for further queries.
-        if not events:
+        if not limited and not events:
             return GetMessagesResult(
                 messages_chunk=[],
                 bundled_aggregations={},
diff --git a/synapse/handlers/room_policy.py b/synapse/handlers/room_policy.py
index 01943e1991..e46e6dc2ef 100644
--- a/synapse/handlers/room_policy.py
+++ b/synapse/handlers/room_policy.py
@@ -181,9 +181,10 @@ class RoomPolicyHandler:
     async def _verify_policy_server_signature(
         self, event: EventBase, policy_server: str, public_key: str
     ) -> bool:
-        # check the event is signed with this (via, public_key).
-        verify_json_req = VerifyJsonRequest.from_event(policy_server, event, 0)
         try:
+            # check the event is signed with this (via, public_key).
+            verify_json_req = VerifyJsonRequest.from_event(policy_server, event, 0)
+
             key_bytes = decode_base64(public_key)
             verify_key = decode_verify_key_bytes(POLICY_SERVER_KEY_ID, key_bytes)
             # We would normally use KeyRing.verify_event_for_server but we can't here as we don't
@@ -260,9 +261,7 @@ class RoomPolicyHandler:
             # servers need to manually fetch signatures for. This is the code that allows
             # those events to continue working (because they're legally sent, even if missing
             # the policy server signature).
-            event.signatures.setdefault(policy_server.server_name, {}).update(
-                signature.get(policy_server.server_name, {})
-            )
+            event.signatures.update(signature)
         except HttpResponseException as ex:
             # re-wrap HTTP errors as `SynapseError` so they can be proxied to clients directly
             raise ex.to_synapse_error() from ex
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 1537a18cc0..57792ea53c 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -54,6 +54,20 @@ logger = logging.getLogger(__name__)
 # will not disappear under our feet as long as we don't delete the room.
 NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"
 
+WORKER_LOCK_MAX_RETRY_INTERVAL = Duration(seconds=5)
+"""
+The maximum wait time before retrying to acquire the lock.
+
+It is better to retry more quickly than to have workers sit idle. 5 seconds is
+still a long enough gap to avoid overwhelming the CPU/database.
+
+This matters most in cross-worker scenarios.
When locks are on the same worker and the
+lock holder releases, we signal to the other locks (with the same name/key) that they
+should try to reacquire the lock immediately. But locks on other workers only re-check
+based on their retry `_timeout_interval`.
+"""
+WORKER_LOCK_EXCESSIVE_WAITING_WARN_DURATION = Duration(minutes=10)
+
 
 class WorkerLocksHandler:
     """A class for waiting on taking out locks, rather than using the storage
@@ -206,9 +220,10 @@ class WaitingLock:
     lock_name: str
     lock_key: str
     write: bool | None
+    start_ts_ms: int = 0
     deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
 
     _inner_lock: Lock | None = None
-    _retry_interval: float = 0.1
+    _timeout_interval: float = 0.1
    _lock_span: "opentracing.Scope" = attr.Factory(
         lambda: start_active_span("WaitingLock.lock")
     )
@@ -220,6 +235,7 @@ class WaitingLock:
             self.deferred.callback(None)
 
     async def __aenter__(self) -> None:
+        self.start_ts_ms = self.clock.time_msec()
         self._lock_span.__enter__()
 
         with start_active_span("WaitingLock.waiting_for_lock"):
@@ -240,19 +256,44 @@ class WaitingLock:
                     break
 
                 try:
-                    # Wait until the we get notified the lock might have been
+                    # Wait until the notification that the lock might have been
                     # released (by the deferred being resolved). We also
-                    # periodically wake up in case the lock was released but we
+                    # periodically wake up in case the lock was released, but we
                     # weren't notified.
                     with PreserveLoggingContext():
-                        timeout = self._get_next_retry_interval()
                         await timeout_deferred(
                             deferred=self.deferred,
-                            timeout=timeout,
+                            timeout=self._timeout_interval,
                             clock=self.clock,
                         )
-                except Exception:
-                    pass
+                except defer.TimeoutError:
+                    # Only increment the timeout value if this was an actual timeout
+                    # (defer.TimeoutError)
+                    self._increment_timeout_interval()
+
+                    now_ms = self.clock.time_msec()
+                    time_spent_trying_to_lock = Duration(
+                        milliseconds=now_ms - self.start_ts_ms
+                    )
+                    if (
+                        time_spent_trying_to_lock.as_millis()
+                        > WORKER_LOCK_EXCESSIVE_WAITING_WARN_DURATION.as_millis()
+                    ):
+                        logger.warning(
+                            "(WaitingLock (%s, %s)) Time spent waiting to acquire lock "
+                            "is getting excessive: %ss. There may be a deadlock.",
+                            self.lock_name,
+                            self.lock_key,
+                            time_spent_trying_to_lock.as_secs(),
+                        )
+
+                except Exception as e:
+                    logger.warning(
+                        "Caught an exception while waiting on WaitingLock(lock_name=%s, lock_key=%s): %r",
+                        self.lock_name,
+                        self.lock_key,
+                        e,
+                    )
 
         return await self._inner_lock.__aenter__()
 
@@ -273,15 +314,14 @@ class WaitingLock:
 
         return r
 
-    def _get_next_retry_interval(self) -> float:
-        next = self._retry_interval
-        self._retry_interval = max(5, next * 2)
-        if self._retry_interval > Duration(minutes=10).as_secs():  # >7 iterations
-            logger.warning(
-                "Lock timeout is getting excessive: %ss. There may be a deadlock.",
-                self._retry_interval,
-            )
-        return next * random.uniform(0.9, 1.1)
+    def _increment_timeout_interval(self) -> float:
+        next_interval = self._timeout_interval
+        next_interval = min(WORKER_LOCK_MAX_RETRY_INTERVAL.as_secs(), next_interval * 2)
+
+        # Jitter is applied to the timeout to help avoid a "thundering herd"
+        # situation where many locks would otherwise time out at the same time.
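+        # Illustrative progression, assuming the 0.1s initial value and no
+        # early wake-ups: roughly 0.2, 0.4, 0.8, 1.6, 3.2, 5, 5, ... seconds
+        # (each scaled by the jitter factor), rather than doubling without
+        # bound as the old code did.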
+        self._timeout_interval = next_interval * random.uniform(0.9, 1.1)
+        return self._timeout_interval
 
 
 @attr.s(auto_attribs=True, eq=False)
@@ -294,10 +334,11 @@ class WaitingMultiLock:
 
     store: LockStore
     handler: WorkerLocksHandler
+    start_ts_ms: int = 0
     deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
 
     _inner_lock_cm: AsyncContextManager | None = None
-    _retry_interval: float = 0.1
+    _timeout_interval: float = 0.1
     _lock_span: "opentracing.Scope" = attr.Factory(
         lambda: start_active_span("WaitingLock.lock")
     )
@@ -309,6 +350,7 @@ class WaitingMultiLock:
             self.deferred.callback(None)
 
     async def __aenter__(self) -> None:
+        self.start_ts_ms = self.clock.time_msec()
         self._lock_span.__enter__()
 
         with start_active_span("WaitingLock.waiting_for_lock"):
@@ -324,19 +366,42 @@ class WaitingMultiLock:
                     break
 
                 try:
-                    # Wait until the we get notified the lock might have been
+                    # Wait until the notification that the lock might have been
                    # released (by the deferred being resolved). We also
-                    # periodically wake up in case the lock was released but we
+                    # periodically wake up in case the lock was released, but we
                     # weren't notified.
                     with PreserveLoggingContext():
-                        timeout = self._get_next_retry_interval()
                         await timeout_deferred(
                             deferred=self.deferred,
-                            timeout=timeout,
+                            timeout=self._timeout_interval,
                             clock=self.clock,
                         )
-                except Exception:
-                    pass
+                except defer.TimeoutError:
+                    # Only increment the timeout value if this was an actual timeout
+                    # (defer.TimeoutError)
+                    self._increment_timeout_interval()
+
+                    now_ms = self.clock.time_msec()
+                    time_spent_trying_to_lock = Duration(
+                        milliseconds=now_ms - self.start_ts_ms
+                    )
+                    if (
+                        time_spent_trying_to_lock.as_millis()
+                        > WORKER_LOCK_EXCESSIVE_WAITING_WARN_DURATION.as_millis()
+                    ):
+                        logger.warning(
+                            "(WaitingMultiLock (%r)) Time spent waiting to acquire lock "
+                            "is getting excessive: %ss. There may be a deadlock.",
+                            self.lock_names,
+                            time_spent_trying_to_lock.as_secs(),
+                        )
+
+                except Exception as e:
+                    logger.warning(
+                        "Caught an exception while waiting on WaitingMultiLock(lock_names=%r): %r",
+                        self.lock_names,
+                        e,
+                    )
 
         assert self._inner_lock_cm
         await self._inner_lock_cm.__aenter__()
@@ -360,12 +425,11 @@ class WaitingMultiLock:
 
         return r
 
-    def _get_next_retry_interval(self) -> float:
-        next = self._retry_interval
-        self._retry_interval = max(5, next * 2)
-        if self._retry_interval > Duration(minutes=10).as_secs():  # >7 iterations
-            logger.warning(
-                "Lock timeout is getting excessive: %ss. There may be a deadlock.",
-                self._retry_interval,
-            )
-        return next * random.uniform(0.9, 1.1)
+    def _increment_timeout_interval(self) -> float:
+        next_interval = self._timeout_interval
+        next_interval = min(WORKER_LOCK_MAX_RETRY_INTERVAL.as_secs(), next_interval * 2)
+
+        # Jitter is applied to the timeout to help avoid a "thundering herd"
+        # situation where many locks would otherwise time out at the same time.
+ self._timeout_interval = next_interval * random.uniform(0.9, 1.1) + return self._timeout_interval diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 12c918eca6..5f6b03a988 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2711,7 +2711,7 @@ class PersistEventsStore: return def event_dict(event: EventBase) -> JsonDict: - d = event.get_dict() + d = event.get_dict_for_persistence() d.pop("redacted", None) d.pop("redacted_because", None) return d diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index d2623f0760..c0d218398d 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -2824,8 +2824,8 @@ class EventsBackgroundUpdatesStore( # with the provided old key. if old_verify_key is not None: old_key_id = f"{old_verify_key.alg}:{old_verify_key.version}" - server_sigs = event.signatures.get(self.hs.hostname, {}) - if old_key_id not in server_sigs: + old_sig = event.signatures.get_signature(self.hs.hostname, old_key_id) + if old_sig is None: # Event wasn't signed with this key ID at all, skip. continue diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cc79b8042b..6f26bd17ce 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -779,14 +779,26 @@ class EventsWorkerStore(SQLBaseStore): events.append(event) if get_prev_content: - if "replaces_state" in event.unsigned: + # The `event` here might be in the cache, and so might have + # already had the `prev_content` and `prev_sender` fields added + # to its unsigned. + # + # We check if a) we should add the previous content, and b) if + # we have already added it. + replaces_state = "replaces_state" in event.unsigned + has_prev = ( + "prev_content" in event.unsigned and "prev_sender" in event.unsigned + ) + if replaces_state and not has_prev: prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, ) if prev: - event.unsigned = dict(event.unsigned) + # This mutates the cached event, but that's fine as the + # previous content/sender will be the same for all + # requests for this event. event.unsigned["prev_content"] = prev.content event.unsigned["prev_sender"] = prev.sender diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 8fa1e2e5a9..7d14f9f4d8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -2425,12 +2425,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter: If provided filters the events to those that match the filter. Returns: - The results as a list of events, a token that points to the end of - the result set, and a boolean to indicate if there were more events - but we hit the limit. If no events are returned then the end of the + - The results as a list of events; + - a token that points to the end of the result set; and + - a boolean to indicate if there were more events + but we hit the limit (`limited`) + + If no events are returned and `limited` is false, then the end of the stream has been reached (i.e. there are no events between `from_key` and `to_key`). + When `limited` is true, that means that more pagination can be attempted. 
+ Note that `limited` can be true even if no events are returned, + because rejected events are filtered out after the limit check. + When Direction.FORWARDS: from_key < x <= to_key, (ascending order) When Direction.BACKWARDS: from_key >= x > to_key, (descending order) """ diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index fe0ca04420..5ae2bb880a 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -10,9 +10,9 @@ # See the GNU Affero General Public License for more details: # . -from typing import Mapping +from typing import Any, Mapping -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping class EventInternalMetadata: def __init__(self, internal_metadata_dict: JsonDict): ... @@ -154,3 +154,62 @@ def event_visible_to_server( Returns: Whether the server is allowed to see the unredacted event. """ + +class Signatures: + """A class representing the signatures on an event.""" + + def __init__(self, signatures: Mapping[str, Mapping[str, str]] | None = None): ... + def get_signature(self, server_name: str, key_id: str) -> str | None: ... + """Get the signature for the given server name and key ID, if it exists.""" + + def __getitem__(self, server_name: str) -> Mapping[str, str]: ... + """Get the signatures for the given server name. Raises KeyError if there + are no signatures for that server.""" + + def __contains__(self, server_name: Any) -> bool: ... + """Check if there are signatures for the given server name.""" + + def __len__(self) -> int: ... + """Return the number of servers that have signatures.""" + + def add_signature(self, server_name: str, key_id: str, signature: str) -> None: ... + """Add a signature for the given server name and key ID.""" + + def update(self, signatures: Mapping[str, Mapping[str, str]]) -> None: ... + """Update the signatures with the given signatures. + + Will overwrite all existing signatures for the server names provided. + """ + + def as_dict(self) -> dict[str, dict[str, str]]: ... + """Return a copy of the signatures as a dictionary.""" + +class Unsigned: + """A class representing the unsigned data of an event.""" + + def __init__(self, unsigned_dict: JsonMapping): ... + def __getitem__(self, key: str) -> Any: ... + """Get the value for the given key. + + Raises KeyError if the key is unset or not recognised.""" + + def __setitem__(self, key: str, value: Any) -> None: ... + """Set the value for the given key. + + Raises KeyError if the key is not recognised.""" + + def __delitem__(self, key: str) -> None: ... + """Delete the value for the given key. + + Raises KeyError if the key is unset or not recognised.""" + + def __contains__(self, key: Any) -> bool: ... + def get(self, key: str, default: Any = None) -> Any: ... + """Get the value for the given key, or ``default`` if the key is unset.""" + + def for_persistence(self) -> JsonDict: ... + """Return a dict of the fields that should be persisted to the database.""" + + def for_event(self) -> JsonDict: ... 
+ """Return a dict of all unsigned fields, including those only kept in + memory, suitable for inclusion in an event.""" diff --git a/synapse/util/duration.py b/synapse/util/duration.py index 135b980852..a1abe944b5 100644 --- a/synapse/util/duration.py +++ b/synapse/util/duration.py @@ -32,6 +32,33 @@ class Duration(timedelta): ``` """ + # Using `__new__` (instead of `__init__`) because that's what `timedelta` uses + def __new__( + cls, + # The whole goal of overriding `__new__` is to require keyword-only arguments. + # Without this, `Duration(5)` would create a duration represnting 5 *days* + # (timedelta's default), but callers almost certainly want to specify which unit + # like seconds or hours. + *, + days: float = 0, + seconds: float = 0, + microseconds: float = 0, + milliseconds: float = 0, + minutes: float = 0, + hours: float = 0, + weeks: float = 0, + ) -> "Duration": + return super().__new__( + cls, + days=days, + seconds=seconds, + microseconds=microseconds, + milliseconds=milliseconds, + minutes=minutes, + hours=hours, + weeks=weeks, + ) + def as_millis(self) -> int: """Returns the duration in milliseconds.""" return int(self / _ONE_MILLISECOND) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 12ef42866d..a402dd70d1 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -67,25 +67,31 @@ def MockEvent(**kwargs: Any) -> EventBase: class TestMaybeUpsertEventField(stdlib_unittest.TestCase): def test_update_okay(self) -> None: event = make_event_from_dict({"event_id": "$1234"}) - success = maybe_upsert_event_field(event, event.unsigned, "key", "value") + success = maybe_upsert_event_field( + event, event.unsigned, "replaces_state", "value" + ) self.assertTrue(success) - self.assertEqual(event.unsigned["key"], "value") + self.assertEqual(event.unsigned["replaces_state"], "value") def test_update_not_okay(self) -> None: event = make_event_from_dict({"event_id": "$1234"}) LARGE_STRING = "a" * 100_000 - success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + success = maybe_upsert_event_field( + event, event.unsigned, "replaces_state", LARGE_STRING + ) self.assertFalse(success) - self.assertNotIn("key", event.unsigned) + self.assertNotIn("replaces_state", event.unsigned) def test_update_not_okay_leaves_original_value(self) -> None: event = make_event_from_dict( - {"event_id": "$1234", "unsigned": {"key": "value"}} + {"event_id": "$1234", "unsigned": {"replaces_state": "value"}} ) LARGE_STRING = "a" * 100_000 - success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + success = maybe_upsert_event_field( + event, event.unsigned, "replaces_state", LARGE_STRING + ) self.assertFalse(success) - self.assertEqual(event.unsigned["key"], "value") + self.assertEqual(event.unsigned["replaces_state"], "value") class PruneEventTestCase(stdlib_unittest.TestCase): @@ -623,7 +629,7 @@ class CloneEventTestCase(stdlib_unittest.TestCase): { "type": "A", "event_id": "$test:domain", - "unsigned": {"a": 1, "b": 2}, + "unsigned": {"age_ts": 1, "replaces_state": "2"}, }, RoomVersions.V1, {"txn_id": "txn"}, @@ -634,10 +640,14 @@ class CloneEventTestCase(stdlib_unittest.TestCase): self.assertEqual(original.internal_metadata.instance_name, "worker1") cloned = clone_event(original) - cloned.unsigned["b"] = 3 + cloned.unsigned["age_ts"] = 3 - self.assertEqual(original.unsigned, {"a": 1, "b": 2}) - self.assertEqual(cloned.unsigned, {"a": 1, "b": 3}) + self.assertEqual( + original.unsigned.for_event(), {"age_ts": 1, 
"replaces_state": "2"} + ) + self.assertEqual( + cloned.unsigned.for_event(), {"age_ts": 3, "replaces_state": "2"} + ) self.assertEqual(cloned.internal_metadata.stream_ordering, 1234) self.assertEqual(cloned.internal_metadata.instance_name, "worker1") self.assertEqual(cloned.internal_metadata.txn_id, "txn") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index a40e0b0680..7b1a2a5adc 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -754,9 +754,9 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): }, } - filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1) - self.assertIn("age", filtered_event2.unsigned) - self.assertEqual(14, filtered_event2.unsigned["age"]) + filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1, received_time=20) + self.assertIn("age_ts", filtered_event2.unsigned) + self.assertEqual(6, filtered_event2.unsigned["age_ts"]) self.assertNotIn("more warez", filtered_event2.unsigned) # Invite_room_state is allowed in events of type m.room.member self.assertIn("invite_room_state", filtered_event2.unsigned) @@ -779,8 +779,8 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): "invite_room_state": [], }, } - filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1) - self.assertIn("age", filtered_event3.unsigned) + filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1, received_time=20) + self.assertIn("age_ts", filtered_event3.unsigned) # Invite_room_state field is only permitted in event type m.room.member self.assertNotIn("invite_room_state", filtered_event3.unsigned) self.assertNotIn("more warez", filtered_event3.unsigned) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index dde1785854..20ffed68f4 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -368,7 +368,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) # the auth code requires that a signature exists, but doesn't check that # signature... go figure. - join_event.signatures[other_server] = {"x": "y"} + join_event.signatures.update({other_server: {"x": "y"}}) self.get_success( self.hs.get_federation_event_handler().on_send_membership_event( diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py index 61ff51ff92..74201f4151 100644 --- a/tests/handlers/test_worker_lock.py +++ b/tests/handlers/test_worker_lock.py @@ -26,7 +26,9 @@ from twisted.internet import defer from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer +from synapse.storage.databases.main.lock import _RENEWAL_INTERVAL from synapse.util.clock import Clock +from synapse.util.duration import Duration from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -40,6 +42,7 @@ class WorkerLockTestCase(unittest.HomeserverTestCase): self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.worker_lock_handler = self.hs.get_worker_locks_handler() + self.store = self.hs.get_datastores().main def test_wait_for_lock_locally(self) -> None: """Test waiting for a lock on a single worker""" @@ -56,6 +59,66 @@ class WorkerLockTestCase(unittest.HomeserverTestCase): self.get_success(d2) self.get_success(lock2.__aexit__(None, None, None)) + def test_timeouts_for_lock_locally(self) -> None: + """ + Test that we regularly retry to reacquire locks. 
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 61ff51ff92..74201f4151 100644
--- a/tests/handlers/test_worker_lock.py
+++ b/tests/handlers/test_worker_lock.py
@@ -26,7 +26,9 @@ from twisted.internet import defer
 from twisted.internet.testing import MemoryReactor
 
 from synapse.server import HomeServer
+from synapse.storage.databases.main.lock import _RENEWAL_INTERVAL
 from synapse.util.clock import Clock
+from synapse.util.duration import Duration
 
 from tests import unittest
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -40,6 +42,7 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
         self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
     ) -> None:
         self.worker_lock_handler = self.hs.get_worker_locks_handler()
+        self.store = self.hs.get_datastores().main
 
     def test_wait_for_lock_locally(self) -> None:
         """Test waiting for a lock on a single worker"""
@@ -56,6 +59,66 @@
         self.get_success(d2)
         self.get_success(lock2.__aexit__(None, None, None))
 
+    def test_timeouts_for_lock_locally(self) -> None:
+        """
+        Test that we regularly retry to reacquire locks.
+
+        This is a regression test to make sure the lock retry time doesn't balloon
+        to a value so large that it can no longer even be printed reliably.
+        """
+
+        # Create and acquire the first lock
+        lock1 = self.worker_lock_handler.acquire_lock("name", "key")
+        self.get_success(lock1.__aenter__())
+
+        # Create and try to acquire the second lock
+        lock2 = self.worker_lock_handler.acquire_lock("name", "key")
+        d2 = defer.ensureDeferred(lock2.__aenter__())
+        # Make sure we haven't acquired the lock yet (`lock1` still holds it)
+        self.assertNoResult(d2)
+
+        # Advance time by an hour (a duration that would previously cause the
+        # timeout to balloon if it weren't capped), so the retry back-off
+        # saturates at its maximum.
+        #
+        # Note: We use `_pump_by` instead of `pump`/`advance` as the `Lock` has an
+        # internal background looping call that runs every 30 seconds
+        # (`_RENEWAL_INTERVAL`) to renew the `Lock` and push its "drop timeout"
+        # value further out by 2 minutes (`_LOCK_TIMEOUT_MS`). The `Lock` would be
+        # dropped prematurely if this renewal were not allowed to run, which would
+        # spoil the test.
+        # self.pump(amount=Duration(hours=1))
+        self._pump_by(amount=Duration(hours=1), by=_RENEWAL_INTERVAL)
+
+        # Make sure we haven't acquired the `lock2` yet (`lock1` still holds it)
+        self.assertNoResult(d2)
+
+        # Release the first lock (`lock1`). The second lock (`lock2`) should be
+        # automatically acquired by the `pump()` inside `get_success()`
+        self.get_success(lock1.__aexit__(None, None, None))
+
+        # We should now have the lock
+        self.successResultOf(d2)
+
+    def _pump_by(
+        self,
+        *,
+        amount: Duration = Duration(seconds=0),
+        by: Duration = Duration(seconds=0.1),
+    ) -> None:
+        """
+        Like `self.pump()`, but advances the clock in fixed increments of `by`
+        until the full `amount` of time has passed.
+
+        Unlike `self.pump()`, this doesn't scale the increments at all.
+
+        Args:
+            amount: The total amount of time to advance
+            by: The time increment to advance by on each step until we reach `amount`
+        """
+        end_time_s = self.reactor.seconds() + amount.as_secs()
+
+        while self.reactor.seconds() < end_time_s:
+            self.reactor.advance(by.as_secs())
+
     def test_lock_contention(self) -> None:
         """Test lock contention when a lot of locks wait on a single worker"""
         nb_locks_to_test = 500
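For intuition, a rough sketch of the saturating retry back-off this regression
test pins down; the name `_MAX_RETRY_INTERVAL`, the cap value, and the growth
factor below are illustrative assumptions, not the actual handler code:

    from synapse.util.duration import Duration

    _MAX_RETRY_INTERVAL = Duration(seconds=5)  # assumed cap, for illustration

    def next_retry_interval(current: Duration) -> Duration:
        # Grow the wait after each failed attempt, but saturate at the cap so
        # the interval can never balloon without bound (over an hour of
        # contention it stays at the cap instead of doubling indefinitely).
        doubled = Duration(seconds=current.as_secs() * 2)
        return min(doubled, _MAX_RETRY_INTERVAL)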
@@ -124,3 +187,70 @@
         self.get_success(d2)
         self.get_success(lock2.__aexit__(None, None, None))
+
+    def test_timeouts_for_lock_worker(self) -> None:
+        """
+        Test that we regularly retry to reacquire locks.
+
+        This is a regression test to make sure the lock retry time doesn't balloon
+        to a value so large that it can no longer even be printed reliably.
+        """
+        worker = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={
+                "redis": {"enabled": True},
+            },
+        )
+        worker_lock_handler = worker.get_worker_locks_handler()
+
+        # Create and acquire the first lock on the main process
+        lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
+        self.get_success(lock1.__aenter__())
+
+        # Create and try to acquire the second lock on the worker
+        lock2 = worker_lock_handler.acquire_lock("name", "key")
+        d2 = defer.ensureDeferred(lock2.__aenter__())
+        # Make sure we haven't acquired the lock yet (`lock1` still holds it)
+        self.assertNoResult(d2)
+
+        # Advance time by an hour (a duration that would previously cause the
+        # timeout to balloon if it weren't capped), so the retry back-off
+        # saturates at its maximum.
+        #
+        # Note: We use `_pump_by` instead of `pump`/`advance` as the `Lock` has an
+        # internal background looping call that runs every 30 seconds
+        # (`_RENEWAL_INTERVAL`) to renew the `Lock` and push its "drop timeout"
+        # value further out by 2 minutes (`_LOCK_TIMEOUT_MS`). The `Lock` would be
+        # dropped prematurely if this renewal were not allowed to run, which would
+        # spoil the test.
+        # self.pump(amount=Duration(hours=1))
+        self._pump_by(amount=Duration(hours=1), by=_RENEWAL_INTERVAL)
+
+        # Make sure we haven't acquired the `lock2` yet (`lock1` still holds it)
+        self.assertNoResult(d2)
+
+        # Release the first lock (`lock1`). The second lock (`lock2`) should be
+        # automatically acquired by the `pump()` inside `get_success()`
+        self.get_success(lock1.__aexit__(None, None, None))
+
+        # We should now have the lock
+        self.successResultOf(d2)
+
+    def _pump_by(
+        self,
+        *,
+        amount: Duration = Duration(seconds=0),
+        by: Duration = Duration(seconds=0.1),
+    ) -> None:
+        """
+        Like `self.pump()`, but advances the clock in fixed increments of `by`
+        until the full `amount` of time has passed.
+
+        Unlike `self.pump()`, this doesn't scale the increments at all.
+
+        Args:
+            amount: The total amount of time to advance
+            by: The time increment to advance by on each step until we reach `amount`
+        """
+        end_time_s = self.reactor.seconds() + amount.as_secs()
+
+        while self.reactor.seconds() < end_time_s:
+            self.reactor.advance(by.as_secs())
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 221121007d..61e7e87f62 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -66,7 +66,10 @@ from synapse.util.stringutils import random_string
 
 from tests import unittest
 from tests.http.server._base import make_request_with_cancellation_test
 from tests.storage.test_stream import PaginationTestCase
-from tests.test_utils.event_injection import create_event
+from tests.test_utils.event_injection import (
+    create_event,
+    inject_event,
+)
 from tests.unittest import override_config
 from tests.utils import default_config
 
@@ -2371,6 +2374,87 @@ class RoomMessageListTestCase(RoomBase):
             channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
         )
 
+    def test_room_messages_paginate_through_rejected_events(
+        self,
+    ) -> None:
+        """Test that pagination continues past a batch of rejected events.
+
+        Regression test for https://github.com/element-hq/synapse/security/advisories/GHSA-6qf2-7x63-mm6v
+
+        Synapse before 1.152.1 had a bug where a batch containing only
+        rejected events caused `/messages` to stop returning pagination
+        tokens, falsely signalling the end of backpagination.
+        """
+        # Send an early message that should not be filtered.
+        early_event_id = self.helper.send(self.room_id, "early message")["event_id"]
+
+        # Inject a batch of events and mark them as rejected in the database.
+        # We create more events than a single pagination request would fetch,
+        # so that one page of the backward pagination contains only rejected
+        # events.
+        for _ in range(3):
+            event = self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room_id,
+                    sender=self.user_id,
+                    type=EventTypes.Message,
+                    content={"body": "filtered event", "msgtype": "m.text"},
+                )
+            )
+            self.get_success(
+                self.hs.get_datastores().main.db_pool.runInteraction(
+                    "mark_rejected",
+                    self.hs.get_datastores().main.mark_event_rejected_txn,
+                    event.event_id,
+                    "testing",
+                )
+            )
+
+        # Send a message after all the rejected events.
+        latest_event_id = self.helper.send(self.room_id, "latest message")["event_id"]
+
+        # Start backpaginating.
+        channel = self.make_request(
+            "GET", f"/rooms/{self.room_id}/messages?dir=b&limit=2"
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        events_in_page = [e["event_id"] for e in channel.json_body["chunk"]]
+        end_token: str | None = channel.json_body["end"]
+
+        self.assertEqual(
+            events_in_page,
+            [latest_event_id],
+            "The latest event should be included in the first page we see whilst backpaginating",
+        )
+
+        event_ids_in_pages: list[list[str]] = [events_in_page]
+
+        # Bound the number of backpagination attempts to 2
+        for _ in range(2):
+            channel = self.make_request(
+                "GET", f"/rooms/{self.room_id}/messages?from={end_token}&dir=b&limit=2"
+            )
+            self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+            events_in_page = [e["event_id"] for e in channel.json_body["chunk"]]
+            event_ids_in_pages.append(events_in_page)
+
+            if early_event_id in events_in_page:
+                # We have found the event we were looking for
+                return
+
+            self.assertIn(
+                "end",
+                channel.json_body,
+                f"No `end` token received. Did not find {early_event_id} whilst backpaginating ({latest_event_id = }, {event_ids_in_pages = })",
+            )
+            # Use the end_token in the next iteration
+            end_token = channel.json_body["end"]
+
+        self.fail(
+            f"Exhausted backpagination attempts. Did not find {early_event_id} whilst backpaginating ({latest_event_id = }, {event_ids_in_pages = })"
+        )
+
 
 class RoomMessageFilterTestCase(RoomBase):
     """Tests /rooms/$room_id/messages REST events."""
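Finally, a rough sketch of the invariant the pagination test above enforces;
the helper below is hypothetical and far simpler than Synapse's actual
`/messages` handler:

    def build_messages_response(page_events, next_token):
        # Rejected events are filtered out of the response...
        visible = [ev for ev in page_events if not ev.rejected]
        # ...but the `end` token must still advance past the fetched page.
        # Before 1.152.1, an all-rejected page could yield a response without
        # a usable `end` token, falsely signalling the end of backpagination.
        return {"chunk": visible, "end": next_token}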