Mirror of https://github.com/element-hq/synapse.git (synced 2026-05-14 21:15:12 +00:00)
Merge branch 'develop' into madlittlemods/wait_for_multi_writer_stream_token
@@ -1,3 +1,11 @@
# Synapse 1.152.1 (2026-05-07)

## Security Fixes

- Prevent CPU starvation (Denial of Service) under worker lock contention, additionally capping the `WorkerLock` timeout interval to a maximum of 60 seconds. Contributed by Famedly. ([\#19394](https://github.com/element-hq/synapse/issues/19394), ELEMENTSEC-2026-1706, [GHSA-8q93-326v-3m7g](https://github.com/element-hq/synapse/security/advisories/GHSA-8q93-326v-3m7g), CVE pending)
- Prevent pagination ending when a page is full of rejected events. (ELEMENTSEC-2025-1636, [GHSA-6qf2-7x63-mm6v](https://github.com/element-hq/synapse/security/advisories/GHSA-6qf2-7x63-mm6v), CVE pending)

# Synapse 1.152.0 (2026-04-28)

No significant changes since 1.152.0rc1.
@@ -0,0 +1 @@
Capped the `WorkerLock` timeout interval to a maximum of 60 seconds to prevent the retry interval from growing excessively large. Contributed by Famedly.
@@ -0,0 +1 @@
Port `Event.signatures` field to Rust.
@@ -0,0 +1 @@
Port `Event.unsigned` field to Rust.
@@ -0,0 +1 @@
Reduce `WORKER_LOCK_MAX_RETRY_INTERVAL` to 5 seconds to reduce idle time after a lock is released.
@@ -0,0 +1 @@
Force keyword-only args for `Duration` (prevent footgun) so people have to specify which time unit they want to use.
@@ -1,3 +1,9 @@
matrix-synapse-py3 (1.152.1) stable; urgency=medium

  * New Synapse release 1.152.1.

 -- Synapse Packaging team <packages@matrix.org>  Thu, 07 May 2026 13:29:05 +0100

matrix-synapse-py3 (1.152.0) stable; urgency=medium

  * New Synapse release 1.152.0.
@@ -1,6 +1,6 @@
[project]
name = "matrix-synapse"
version = "1.152.0"
version = "1.152.1"
description = "Homeserver for the Matrix decentralised comms protocol"
readme = "README.rst"
authors = [
@@ -43,7 +43,7 @@ pyo3-log = "0.13.1"
|
||||
pythonize = "0.27.0"
|
||||
regex = "1.6.0"
|
||||
sha2 = "0.10.8"
|
||||
serde = { version = "1.0.144", features = ["derive"] }
|
||||
serde = { version = "1.0.144", features = ["derive", "rc"] }
|
||||
serde_json = { version = "1.0.85", features = ["raw_value"] }
|
||||
ulid = "1.1.2"
|
||||
icu_segmenter = "2.0.0"
|
||||
@@ -58,10 +58,6 @@ tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] }
|
||||
once_cell = "1.18.0"
|
||||
itertools = "0.14.0"
|
||||
|
||||
[features]
|
||||
extension-module = ["pyo3/extension-module"]
|
||||
default = ["extension-module"]
|
||||
|
||||
[build-dependencies]
|
||||
blake2 = "0.10.4"
|
||||
hex = "0.4.3"
|
||||
|
||||
@@ -27,11 +27,15 @@ use pyo3::{

pub mod filter;
mod internal_metadata;
pub mod signatures;
pub mod unsigned;

/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let child_module = PyModule::new(py, "events")?;
    child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
    child_module.add_class::<signatures::Signatures>()?;
    child_module.add_class::<unsigned::Unsigned>()?;
    child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;

    m.add_submodule(&child_module)?;
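With `Signatures` and `Unsigned` registered on the `events` child module above, they become importable from Python under the same path the later Python hunks in this diff use:

```python
# Import path as used elsewhere in this diff (e.g. synapse/events/__init__.py).
from synapse.synapse_rust.events import EventInternalMetadata, Signatures, Unsigned
```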
@@ -0,0 +1,348 @@
/*
 * This file is licensed under the Affero General Public License (AGPL) version 3.
 *
 * Copyright (C) 2026 Element Creations Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * See the GNU Affero General Public License for more details:
 * <https://www.gnu.org/licenses/agpl-3.0.html>.
 *
 */

//! Class for representing event signatures

use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
};

use pyo3::{
    exceptions::{PyKeyError, PyRuntimeError},
    pyclass, pymethods,
    types::{PyAnyMethods, PyDict, PyMapping, PyMappingMethods},
    Bound, IntoPyObject, PyAny, PyResult, Python,
};
use serde::{Deserialize, Serialize};

/// A class representing the signatures on an event.
#[pyclass(frozen, skip_from_py_object)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Signatures {
    inner: Arc<RwLock<HashMap<String, HashMap<String, String>>>>,
}

#[pymethods]
impl Signatures {
    #[new]
    #[pyo3(signature = (signatures = None))]
    fn py_new(signatures: Option<HashMap<String, HashMap<String, String>>>) -> Self {
        let mut signatures = signatures.unwrap_or_default();

        // Prune any entries that have no signatures.
        signatures.retain(|_, server_sigs| !server_sigs.is_empty());

        Self {
            inner: Arc::new(RwLock::new(signatures)),
        }
    }

    /// Check if the signatures contain a signature for the given server name.
    fn __contains__(&self, key: Bound<'_, PyAny>) -> PyResult<bool> {
        let Ok(key) = key.extract::<&str>() else {
            return Ok(false);
        };

        let signatures = self
            .inner
            .read()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
        Ok(signatures.contains_key(key))
    }

    /// Get the number of servers that have signatures.
    fn __len__(&self) -> PyResult<usize> {
        let signatures = self
            .inner
            .read()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;
        Ok(signatures.len())
    }

    /// Get the signature for the given server name and key ID, if it exists.
    fn get_signature(&self, server_name: &str, key_id: &str) -> PyResult<Option<String>> {
        let signatures = self
            .inner
            .read()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;

        Ok(signatures
            .get(server_name)
            .and_then(|server_sigs| server_sigs.get(key_id).cloned()))
    }

    /// Get the signatures for the given server name.
    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult<HashMap<String, String>> {
        let Some(server_name) = key.extract::<&str>().ok() else {
            return Err(PyKeyError::new_err(key.to_string()));
        };

        let signatures = self
            .inner
            .read()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;

        if let Some(server_sigs) = signatures.get(server_name) {
            Ok(server_sigs.clone())
        } else {
            Err(PyKeyError::new_err(server_name.to_string()))
        }
    }

    /// Add a signature for the given server name and key ID.
    fn add_signature(
        &self,
        server_name: String,
        key_id: String,
        signature: String,
    ) -> PyResult<()> {
        let mut signatures = self
            .inner
            .write()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;

        signatures
            .entry(server_name)
            .or_default()
            .insert(key_id, signature);

        Ok(())
    }

    /// Update the signatures with the given signatures.
    ///
    /// Will overwrite all existing signatures for the server names provided.
    fn update(&self, other: &Bound<'_, PyMapping>) -> PyResult<()> {
        let mut signatures = self
            .inner
            .write()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;

        for list_entry in other.items()? {
            let (server_name, server_sigs) = list_entry.extract::<(String, Bound<PyMapping>)>()?;

            let mut entry = HashMap::new();
            for list_entry in server_sigs.items()? {
                let (key, value) = list_entry.extract::<(String, String)>()?;
                entry.insert(key, value);
            }

            // Only insert the entry if it has at least one signature.
            if !entry.is_empty() {
                signatures.insert(server_name, entry);
            } else {
                signatures.remove(&server_name);
            }
        }

        Ok(())
    }

    /// Return a copy of the signatures as a dictionary.
    fn as_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let signatures = self
            .inner
            .read()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;

        (&*signatures).into_pyobject(py)
    }

    fn __repr__(&self) -> PyResult<String> {
        let signatures = self
            .inner
            .read()
            .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?;

        Ok(format!("Signatures({signatures:?})"))
    }
}

#[cfg(test)]
mod tests {
    use pythonize::pythonize;

    use super::*;

    /// Helper that reads the inner map directly.
    fn read_inner(sigs: &Signatures) -> HashMap<String, HashMap<String, String>> {
        sigs.inner.read().expect("lock poisoned").clone()
    }

    /// Helper to create a server signatures map from a list of (key_id, sig)
    /// pairs.
    fn make_server_sigs(data: &[(&str, &str)]) -> HashMap<String, String> {
        let mut map = HashMap::new();
        for (key_id, sig) in data {
            map.insert((*key_id).to_owned(), (*sig).to_owned());
        }
        map
    }

    /// Helper to create a `Signatures` object from a list of (server_name,
    /// key_id, sig) tuples.
    fn create_signatures(data: &[(&str, &str, &str)]) -> Signatures {
        let mut map: HashMap<String, HashMap<String, String>> = HashMap::new();
        for (server_name, key_id, sig) in data {
            map.entry((*server_name).to_owned())
                .or_default()
                .insert((*key_id).to_owned(), (*sig).to_owned());
        }
        Signatures::py_new(Some(map))
    }

    #[test]
    fn test_new_empty() {
        let sigs = Signatures::py_new(None);
        assert!(read_inner(&sigs).is_empty());
        assert_eq!(sigs.__len__().unwrap(), 0);
    }

    #[test]
    fn test_new_with_data() {
        let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]);
        assert_eq!(sigs.__len__().unwrap(), 1);
        assert_eq!(
            sigs.get_signature("example.com", "ed25519:key1").unwrap(),
            Some("sig1".to_string())
        );
    }

    #[test]
    fn test_new_prunes_servers_with_no_signatures() {
        let mut data = HashMap::new();
        data.insert("empty.example.com".to_string(), HashMap::new());
        data.insert(
            "example.com".to_string(),
            make_server_sigs(&[("ed25519:key1", "sig1")]),
        );

        let sigs = Signatures::py_new(Some(data));

        let inner = read_inner(&sigs);
        assert_eq!(inner.len(), 1);
        assert!(inner.contains_key("example.com"));
        assert!(!inner.contains_key("empty.example.com"));
    }

    #[test]
    fn test_add_signature() {
        let sigs = Signatures::py_new(None);
        sigs.add_signature(
            "example.com".to_string(),
            "ed25519:key1".to_string(),
            "sig1".to_string(),
        )
        .unwrap();

        let inner = read_inner(&sigs);
        assert_eq!(inner.len(), 1);
        assert_eq!(
            inner.get("example.com").and_then(|m| m.get("ed25519:key1")),
            Some(&"sig1".to_string())
        );
    }

    #[test]
    fn test_add_signature_to_existing_server() {
        let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]);
        sigs.add_signature(
            "example.com".to_string(),
            "ed25519:key2".to_string(),
            "sig2".to_string(),
        )
        .unwrap();

        let inner = read_inner(&sigs);
        assert_eq!(inner.len(), 1);
        assert_eq!(
            inner.get("example.com").and_then(|m| m.get("ed25519:key1")),
            Some(&"sig1".to_string())
        );
        assert_eq!(
            inner.get("example.com").and_then(|m| m.get("ed25519:key2")),
            Some(&"sig2".to_string())
        );
    }

    #[test]
    fn test_update_signatures_clobbers_existing() {
        let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]);

        // Create a new signatures map with a different signature for the same
        // server.
        let mut other = HashMap::new();
        other.insert(
            "example.com".to_string(),
            make_server_sigs(&[("ed25519:key2", "sig2")]),
        );

        // Update the signatures with the new map.
        Python::initialize();
        Python::attach(|py| {
            let value = pythonize(py, &other).unwrap();
            let value = value.cast::<PyMapping>().unwrap();

            sigs.update(value).unwrap();
        });

        // Check that the old signature has been replaced with the new one.
        let inner = read_inner(&sigs);
        assert_eq!(inner.len(), 1);
        assert_eq!(inner["example.com"].len(), 1);
        assert_eq!(inner["example.com"]["ed25519:key2"], "sig2");
    }

    #[test]
    fn test_serialize() {
        let mut data = HashMap::new();
        data.insert(
            "example.com".to_string(),
            make_server_sigs(&[("ed25519:key1", "sig1")]),
        );
        let sigs = Signatures::py_new(Some(data));

        let json = serde_json::to_string(&sigs).unwrap();
        assert_eq!(json, r#"{"example.com":{"ed25519:key1":"sig1"}}"#);
    }

    #[test]
    fn test_serialize_empty() {
        let sigs = Signatures::py_new(None);
        let json = serde_json::to_string(&sigs).unwrap();
        assert_eq!(json, "{}");
    }

    #[test]
    fn test_deserialize() {
        let json = r#"{"example.com":{"ed25519:key1":"sig1"}}"#;
        let sigs: Signatures = serde_json::from_str(json).unwrap();

        let inner = read_inner(&sigs);
        assert_eq!(inner.len(), 1);
        assert_eq!(
            inner.get("example.com").and_then(|m| m.get("ed25519:key1")),
            Some(&"sig1".to_string())
        );
    }

    #[test]
    fn test_deserialize_empty() {
        let sigs: Signatures = serde_json::from_str("{}").unwrap();
        assert!(read_inner(&sigs).is_empty());
    }
}
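A quick sketch of how the class behaves from Python, matching the methods and tests above (server names and signature values are illustrative):

```python
from synapse.synapse_rust.events import Signatures

sigs = Signatures({"example.com": {"ed25519:key1": "sig1"}, "empty.example.com": {}})

len(sigs)                                          # 1 -- empty entries are pruned on construction
"example.com" in sigs                              # True
sigs.get_signature("example.com", "ed25519:key1")  # "sig1"
sigs["example.com"]                                # {"ed25519:key1": "sig1"} (a copy)

sigs.add_signature("example.com", "ed25519:key2", "sig2")
sigs.update({"example.com": {"ed25519:key3": "sig3"}})  # clobbers both earlier sigs
sigs.as_dict()                                     # {"example.com": {"ed25519:key3": "sig3"}}
```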
@@ -0,0 +1,429 @@
/*
 * This file is licensed under the Affero General Public License (AGPL) version 3.
 *
 * Copyright (C) 2026 Element Creations Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * See the GNU Affero General Public License for more details:
 * <https://www.gnu.org/licenses/agpl-3.0.html>.
 *
 */

use std::sync::{Arc, RwLock, RwLockReadGuard};

use pyo3::{
    exceptions::{PyKeyError, PyRuntimeError, PyTypeError},
    pyclass, pymethods,
    types::{PyAnyMethods, PyList, PyListMethods, PyMapping},
    Bound, IntoPyObjectExt, PyAny, PyResult, Python,
};
use pythonize::{depythonize, pythonize};
use serde::{Deserialize, Serialize};

#[pyclass(frozen, skip_from_py_object)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct Unsigned {
    inner: Arc<RwLock<UnsignedInner>>,
}

/// The fields in the unsigned data of an event that are persisted in the
/// database.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct PersistedUnsignedFields {
    #[serde(skip_serializing_if = "Option::is_none")]
    age_ts: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    replaces_state: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    invite_room_state: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    knock_room_state: Option<Vec<serde_json::Value>>,
}

/// The inner representation of the unsigned data of an event, which includes
/// both the fields that are persisted in the database and the fields that are
/// only used in memory.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UnsignedInner {
    #[serde(flatten)]
    persisted_fields: PersistedUnsignedFields,
    #[serde(skip_serializing_if = "Option::is_none")]
    prev_content: Option<Box<serde_json::Value>>, // We use Box to minimise stack space
    #[serde(skip_serializing_if = "Option::is_none")]
    prev_sender: Option<String>,
}

/// The fields that exist on the unsigned data of an event.
///
/// This is used when converting from python to rust, to ensure that if we add a
/// new field we don't forget to add it to all the necessary places.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UnsignedField {
    AgeTs,
    ReplacesState,
    InviteRoomState,
    KnockRoomState,
    PrevContent,
    PrevSender,
}

impl std::str::FromStr for UnsignedField {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "age_ts" => Ok(Self::AgeTs),
            "replaces_state" => Ok(Self::ReplacesState),
            "invite_room_state" => Ok(Self::InviteRoomState),
            "knock_room_state" => Ok(Self::KnockRoomState),
            "prev_content" => Ok(Self::PrevContent),
            "prev_sender" => Ok(Self::PrevSender),
            _ => Err(()),
        }
    }
}

impl Unsigned {
    fn py_read(&self) -> PyResult<RwLockReadGuard<'_, UnsignedInner>> {
        self.inner
            .read()
            .map_err(|_| PyRuntimeError::new_err("Unsigned lock poisoned"))
    }

    fn py_write(&self) -> PyResult<std::sync::RwLockWriteGuard<'_, UnsignedInner>> {
        self.inner
            .write()
            .map_err(|_| PyRuntimeError::new_err("Unsigned lock poisoned"))
    }
}

#[pymethods]
impl Unsigned {
    #[new]
    fn py_new(unsigned: Bound<'_, PyMapping>) -> PyResult<Self> {
        let inner = depythonize(&unsigned)?;

        Ok(Self {
            inner: Arc::new(RwLock::new(inner)),
        })
    }

    fn __getitem__<'py>(
        &self,
        py: Python<'py>,
        key: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let key = key
            .extract::<&str>()
            .map_err(|_| PyTypeError::new_err("Unsigned keys must be strings"))?;

        let field: UnsignedField = key
            .parse()
            .map_err(|_| PyKeyError::new_err(format!("Unsigned has no key '{key}'")))?;

        let unsigned = self.py_read()?;

        match field {
            UnsignedField::AgeTs => Ok(unsigned
                .persisted_fields
                .age_ts
                .ok_or_else(|| PyKeyError::new_err("age_ts"))?
                .into_bound_py_any(py)?),
            UnsignedField::ReplacesState => Ok((unsigned.persisted_fields.replaces_state)
                .as_ref()
                .ok_or_else(|| PyKeyError::new_err("replaces_state"))?
                .into_bound_py_any(py)?),
            UnsignedField::InviteRoomState => Ok(room_state_to_py(
                py,
                unsigned
                    .persisted_fields
                    .invite_room_state
                    .as_ref()
                    .ok_or_else(|| PyKeyError::new_err("invite_room_state"))?,
            )?),
            UnsignedField::KnockRoomState => Ok(room_state_to_py(
                py,
                unsigned
                    .persisted_fields
                    .knock_room_state
                    .as_ref()
                    .ok_or_else(|| PyKeyError::new_err("knock_room_state"))?,
            )?),
            UnsignedField::PrevContent => Ok(pythonize(
                py,
                unsigned
                    .prev_content
                    .as_ref()
                    .ok_or_else(|| PyKeyError::new_err("prev_content"))?,
            )?),
            UnsignedField::PrevSender => Ok((unsigned.prev_sender)
                .as_ref()
                .ok_or_else(|| PyKeyError::new_err("prev_sender"))?
                .into_bound_py_any(py)?),
        }
    }

    fn __contains__(&self, key: Bound<'_, PyAny>) -> PyResult<bool> {
        let Ok(key) = key.extract::<&str>() else {
            return Ok(false);
        };

        let Ok(field) = key.parse::<UnsignedField>() else {
            return Ok(false);
        };

        let unsigned = self.py_read()?;

        let exists = match field {
            UnsignedField::AgeTs => unsigned.persisted_fields.age_ts.is_some(),
            UnsignedField::ReplacesState => unsigned.persisted_fields.replaces_state.is_some(),
            UnsignedField::InviteRoomState => unsigned.persisted_fields.invite_room_state.is_some(),
            UnsignedField::KnockRoomState => unsigned.persisted_fields.knock_room_state.is_some(),
            UnsignedField::PrevContent => unsigned.prev_content.is_some(),
            UnsignedField::PrevSender => unsigned.prev_sender.is_some(),
        };

        Ok(exists)
    }

    fn __setitem__(&self, key: Bound<'_, PyAny>, value: Bound<'_, PyAny>) -> PyResult<()> {
        let key = key
            .extract::<&str>()
            .map_err(|_| PyTypeError::new_err("Unsigned keys must be strings"))?;

        let field: UnsignedField = key
            .parse()
            .map_err(|_| PyKeyError::new_err(format!("Unsigned has no key '{key}'")))?;

        let mut unsigned = self.py_write()?;

        match field {
            UnsignedField::AgeTs => unsigned.persisted_fields.age_ts = Some(value.extract()?),
            UnsignedField::ReplacesState => {
                unsigned.persisted_fields.replaces_state = Some(value.extract()?)
            }
            UnsignedField::InviteRoomState => {
                unsigned.persisted_fields.invite_room_state = Some(room_state_from_py(value)?)
            }
            UnsignedField::KnockRoomState => {
                unsigned.persisted_fields.knock_room_state = Some(room_state_from_py(value)?)
            }
            UnsignedField::PrevContent => {
                unsigned.prev_content = Some(Box::new(depythonize(&value)?))
            }
            UnsignedField::PrevSender => unsigned.prev_sender = Some(value.extract()?),
        }

        Ok(())
    }

    fn __delitem__(&self, key: Bound<'_, PyAny>) -> PyResult<()> {
        let key = key
            .extract::<&str>()
            .map_err(|_| PyTypeError::new_err("Unsigned keys must be strings"))?;

        let field: UnsignedField = key
            .parse()
            .map_err(|_| PyKeyError::new_err(format!("Unsigned has no key '{key}'")))?;

        let mut unsigned = self.py_write()?;

        match field {
            UnsignedField::AgeTs => unsigned.persisted_fields.age_ts = None,
            UnsignedField::ReplacesState => unsigned.persisted_fields.replaces_state = None,
            UnsignedField::InviteRoomState => unsigned.persisted_fields.invite_room_state = None,
            UnsignedField::KnockRoomState => unsigned.persisted_fields.knock_room_state = None,
            UnsignedField::PrevContent => unsigned.prev_content = None,
            UnsignedField::PrevSender => unsigned.prev_sender = None,
        }

        Ok(())
    }

    #[pyo3(signature = (key, default=None))]
    fn get<'py>(
        &self,
        py: Python<'py>,
        key: Bound<'py, PyAny>,
        default: Option<Bound<'py, PyAny>>,
    ) -> PyResult<Option<Bound<'py, PyAny>>> {
        match self.__getitem__(py, key) {
            Ok(value) => Ok(Some(value)),
            Err(err) => {
                if err.is_instance_of::<PyKeyError>(py) {
                    Ok(default)
                } else {
                    Err(err)
                }
            }
        }
    }

    fn for_persistence<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        Ok(pythonize(py, &self.py_read()?.persisted_fields)?)
    }

    fn for_event<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        Ok(pythonize(py, &*self.py_read()?)?)
    }
}

fn room_state_to_py<'py>(
    py: Python<'py>,
    state: &[serde_json::Value],
) -> PyResult<Bound<'py, PyAny>> {
    let py_list = PyList::empty(py);

    for item in state {
        py_list.append(pythonize(py, item)?)?;
    }

    py_list.into_bound_py_any(py)
}

fn room_state_from_py(value: Bound<'_, PyAny>) -> PyResult<Vec<serde_json::Value>> {
    let py_list = value.cast::<PyList>()?;

    let mut state = Vec::with_capacity(py_list.len());
    for item in py_list.iter() {
        state.push(pythonize::depythonize(&item)?);
    }

    Ok(state)
}

#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn test_unsigned_field_from_str_valid() {
        assert_eq!("age_ts".parse(), Ok(UnsignedField::AgeTs));
        assert_eq!("replaces_state".parse(), Ok(UnsignedField::ReplacesState));
        assert_eq!(
            "invite_room_state".parse(),
            Ok(UnsignedField::InviteRoomState)
        );
        assert_eq!(
            "knock_room_state".parse(),
            Ok(UnsignedField::KnockRoomState)
        );
        assert_eq!("prev_content".parse(), Ok(UnsignedField::PrevContent));
        assert_eq!("prev_sender".parse(), Ok(UnsignedField::PrevSender));
    }

    #[test]
    fn test_unsigned_field_from_str_invalid() {
        assert_eq!("".parse::<UnsignedField>(), Err(()));
        assert_eq!("unknown".parse::<UnsignedField>(), Err(()));
        // Case-sensitive: upper-case should not match.
        assert_eq!("AGE_TS".parse::<UnsignedField>(), Err(()));
        // Must be an exact match, no whitespace.
        assert_eq!(" age_ts".parse::<UnsignedField>(), Err(()));
    }

    #[test]
    fn test_persisted_fields_serialize_empty_is_empty_object() {
        let fields = PersistedUnsignedFields::default();
        let json = serde_json::to_value(&fields).unwrap();
        assert_eq!(json, json!({}));
    }

    #[test]
    fn test_persisted_fields_serialize_populated() {
        let fields = PersistedUnsignedFields {
            age_ts: Some(1234),
            replaces_state: Some("$prev:example.com".to_string()),
            invite_room_state: Some(vec![json!({"type": "m.room.name"})]),
            knock_room_state: Some(vec![json!({"type": "m.room.topic"})]),
        };
        let json = serde_json::to_value(&fields).unwrap();
        assert_eq!(
            json,
            json!({
                "age_ts": 1234,
                "replaces_state": "$prev:example.com",
                "invite_room_state": [{"type": "m.room.name"}],
                "knock_room_state": [{"type": "m.room.topic"}],
            })
        );
    }

    #[test]
    fn test_unsigned_inner_flattens_persisted_fields() {
        let inner = UnsignedInner {
            persisted_fields: PersistedUnsignedFields {
                age_ts: Some(99),
                ..Default::default()
            },
            prev_content: Some(Box::new(json!({"body": "hi"}))),
            prev_sender: Some("@alice:example.com".to_string()),
        };

        let json = serde_json::to_value(&inner).unwrap();
        assert_eq!(
            json,
            json!({
                "age_ts": 99,
                "prev_content": {"body": "hi"},
                "prev_sender": "@alice:example.com",
            })
        );
    }

    #[test]
    fn test_unsigned_inner_roundtrip() {
        let original = UnsignedInner {
            persisted_fields: PersistedUnsignedFields {
                age_ts: Some(10),
                replaces_state: Some("$state:example.com".to_string()),
                invite_room_state: None,
                knock_room_state: None,
            },
            prev_content: Some(Box::new(json!({"membership": "join"}))),
            prev_sender: None,
        };

        let json = serde_json::to_string(&original).unwrap();
        let roundtripped: UnsignedInner = serde_json::from_str(&json).unwrap();

        assert_eq!(roundtripped.persisted_fields.age_ts, Some(10));
        assert_eq!(
            roundtripped.persisted_fields.replaces_state.as_deref(),
            Some("$state:example.com")
        );
        assert_eq!(
            roundtripped.prev_content.as_deref(),
            Some(&json!({"membership": "join"}))
        );
        assert_eq!(roundtripped.prev_sender, None);
    }

    #[test]
    fn test_unsigned_serializes_transparently() {
        // `Unsigned` is `#[serde(transparent)]` over its inner, so serializing
        // an empty default should yield an empty object rather than a wrapper.
        let unsigned = Unsigned::default();
        let json = serde_json::to_value(&unsigned).unwrap();
        assert_eq!(json, json!({}));
    }

    #[test]
    fn test_unsigned_deserialize_from_flat_object() {
        let json = json!({
            "age_ts": 5,
            "prev_sender": "@bob:example.com",
        });
        let unsigned: Unsigned = serde_json::from_value(json).unwrap();
        let inner = unsigned.inner.read().unwrap();
        assert_eq!(inner.persisted_fields.age_ts, Some(5));
        assert_eq!(inner.prev_sender.as_deref(), Some("@bob:example.com"));
    }
}
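From Python, `Unsigned` behaves like a mapping restricted to the six known keys; a sketch based on the methods above (field values are illustrative):

```python
from synapse.synapse_rust.events import Unsigned

unsigned = Unsigned({"age_ts": 100, "replaces_state": "$old:example.com"})

"replaces_state" in unsigned                 # True
unsigned.get("prev_sender")                  # None -- unset keys fall back to the default
unsigned["prev_content"] = {"body": "old"}   # in-memory-only field

unsigned.for_persistence()  # {"age_ts": 100, "replaces_state": "$old:example.com"}
unsigned.for_event()        # the same dict plus "prev_content"
unsigned["unknown"]         # raises KeyError: only the six known fields exist
```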
@@ -236,9 +236,7 @@ def event_needs_resigning(
    if sender.domain != server_name:
        return False
    want_key_id = verify_key.alg + ":" + verify_key.version
    signed_with_current_key_id = ev.signatures.get(server_name, {}).get(
        want_key_id, None
    )
    signed_with_current_key_id = ev.signatures.get_signature(server_name, want_key_id)
    if signed_with_current_key_id:
        return False
@@ -120,8 +120,18 @@ class VerifyJsonRequest:
    ) -> "VerifyJsonRequest":
        """Create a VerifyJsonRequest to verify all signatures on an event
        object for the given server.

        Raises immediately if the event doesn't have any signatures from the
        given server.
        """
        key_ids = list(event.signatures.get(server_name, []))
        if server_name not in event.signatures:
            raise SynapseError(
                400,
                f"Not signed by {server_name}",
                Codes.UNAUTHORIZED,
            )

        key_ids = list(event.signatures[server_name])
        return VerifyJsonRequest(
            server_name,
            # We defer creating the redacted json object, as it uses a lot more
@@ -128,7 +128,7 @@ def validate_event_for_room_version(event: "EventBase") -> None:
|
||||
)
|
||||
|
||||
# Check the sender's domain has signed the event
|
||||
if not event.signatures.get(sender_domain):
|
||||
if sender_domain not in event.signatures:
|
||||
# We allow invites via 3pid to have a sender from a different
|
||||
# HS, as the sender must match the sender of the original
|
||||
# 3pid invite. This is checked further down with the
|
||||
@@ -141,7 +141,7 @@ def validate_event_for_room_version(event: "EventBase") -> None:
|
||||
event_id_domain = get_domain_from_id(event.event_id)
|
||||
|
||||
# Check the origin domain has signed the event
|
||||
if not event.signatures.get(event_id_domain):
|
||||
if event_id_domain not in event.signatures:
|
||||
raise AuthError(403, "Event not signed by sending server")
|
||||
|
||||
is_invite_via_allow_rule = (
|
||||
@@ -154,7 +154,7 @@ def validate_event_for_room_version(event: "EventBase") -> None:
|
||||
authoriser_domain = get_domain_from_id(
|
||||
event.content[EventContentFields.AUTHORISING_USER]
|
||||
)
|
||||
if not event.signatures.get(authoriser_domain):
|
||||
if authoriser_domain not in event.signatures:
|
||||
raise AuthError(403, "Event not signed by authorising server")
|
||||
|
||||
|
||||
|
||||
@@ -44,8 +44,12 @@ from synapse.api.constants import (
    StickyEvent,
)
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.synapse_rust.events import EventInternalMetadata
from synapse.types import JsonDict, StateKey, StrCollection
from synapse.synapse_rust.events import EventInternalMetadata, Signatures, Unsigned
from synapse.types import (
    JsonDict,
    StateKey,
    StrCollection,
)
from synapse.util.caches import intern_dict
from synapse.util.duration import Duration
from synapse.util.frozenutils import freeze
@@ -203,8 +207,8 @@ class EventBase(metaclass=abc.ABCMeta):
        assert room_version.event_format == self.format_version

        self.room_version = room_version
        self.signatures = signatures
        self.unsigned = unsigned
        self.signatures = Signatures(signatures)
        self.unsigned = Unsigned(unsigned)
        self.rejected_reason = rejected_reason

        self._dict = event_dict
@@ -254,8 +258,26 @@ class EventBase(metaclass=abc.ABCMeta):
        return self._dict.get("state_key")

    def get_dict(self) -> JsonDict:
        """Convert the event to a dictionary suitable for serialisation."""
        d = dict(self._dict)
        d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
        d.update(
            {
                "signatures": self.signatures.as_dict(),
                "unsigned": self.unsigned.for_event(),
            }
        )

        return d

    def get_dict_for_persistence(self) -> JsonDict:
        """Convert the event to a dictionary suitable for persistence."""
        d = dict(self._dict)
        d.update(
            {
                "signatures": self.signatures.as_dict(),
                "unsigned": self.unsigned.for_persistence(),
            }
        )

        return d
@@ -395,7 +417,7 @@ class FrozenEvent(EventBase):
            for name, sigs in event_dict.pop("signatures", {}).items()
        }

        unsigned = dict(event_dict.pop("unsigned", {}))
        unsigned = event_dict.pop("unsigned", {})

        # We intern these strings because they turn up a lot (especially when
        # caching).
@@ -449,7 +471,7 @@ class FrozenEventV2(EventBase):

        assert "event_id" not in event_dict

        unsigned = dict(event_dict.pop("unsigned", {}))
        unsigned = event_dict.pop("unsigned", {})

        # We intern these strings because they turn up a lot (especially when
        # caching).
@@ -47,6 +47,7 @@ from synapse.api.constants import (
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.logging.opentracing import SynapseTags, set_tag, trace
from synapse.synapse_rust.events import Unsigned
from synapse.types import JsonDict, Requester

from . import EventBase, FrozenEventV2, StrippedStateEvent, make_event_from_dict
@@ -987,7 +988,7 @@ def validate_canonicaljson(value: Any) -> None:


def maybe_upsert_event_field(
    event: EventBase, container: JsonDict, key: str, value: object
    event: EventBase, container: Unsigned, key: str, value: object
) -> bool:
    """Upsert an event field, but only if this doesn't make the event too large.
@@ -307,20 +307,27 @@ def _is_invite_via_3pid(event: EventBase) -> bool:
|
||||
|
||||
|
||||
def parse_events_from_pdu_json(
|
||||
pdus_json: Sequence[JsonDict], room_version: RoomVersion
|
||||
pdus_json: Sequence[JsonDict],
|
||||
room_version: RoomVersion,
|
||||
received_time: int | None = None,
|
||||
) -> list[EventBase]:
|
||||
return [
|
||||
event_from_pdu_json(pdu_json, room_version)
|
||||
event_from_pdu_json(pdu_json, room_version, received_time=received_time)
|
||||
for pdu_json in filter_pdus_for_valid_depth(pdus_json)
|
||||
]
|
||||
|
||||
|
||||
def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase:
|
||||
def event_from_pdu_json(
|
||||
pdu_json: JsonDict, room_version: RoomVersion, received_time: int | None = None
|
||||
) -> EventBase:
|
||||
"""Construct an EventBase from an event json received over federation
|
||||
|
||||
Args:
|
||||
pdu_json: pdu as received over federation
|
||||
room_version: The version of the room this event belongs to
|
||||
received_time: timestamp in ms that the event was received at. If
|
||||
`None` then any `age` field in the `unsigned` block will be
|
||||
dropped.
|
||||
|
||||
Raises:
|
||||
SynapseError: if the pdu is missing required fields or is otherwise
|
||||
@@ -333,6 +340,25 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB
|
||||
if "unsigned" in pdu_json:
|
||||
_strip_unsigned_values(pdu_json)
|
||||
|
||||
# Handle the `age` field, which is sent by some servers as part of the
|
||||
# `unsigned` block. We convert this into an `age_ts` field, which is
|
||||
# what Synapse uses internally. We also remove the `age` field to avoid
|
||||
# confusion.
|
||||
#
|
||||
# c.f. https://github.com/matrix-org/synapse/issues/8429
|
||||
unsigned = pdu_json["unsigned"]
|
||||
age = unsigned.pop("age", None)
|
||||
|
||||
# We check that the `age` is actually an int before using it below. We
|
||||
# don't error here as the `age` a) doesn't affect the validity of the
|
||||
# event, and b) is best effort anyway.
|
||||
if not isinstance(age, int):
|
||||
age = None
|
||||
|
||||
unsigned.pop("age_ts", None)
|
||||
if received_time is not None and age is not None:
|
||||
unsigned["age_ts"] = received_time - int(age)
|
||||
|
||||
depth = pdu_json["depth"]
|
||||
if type(depth) is not int: # noqa: E721
|
||||
raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
|
||||
|
||||
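The `age` to `age_ts` conversion above is simple clock arithmetic; a worked sketch with illustrative values:

```python
# Remote server says the event was 30s old when it serialised the PDU;
# we receive it at local time 1_700_000_060_000 ms.
received_time = 1_700_000_060_000
age = 30_000  # ms, popped from unsigned["age"]

# age_ts is the absolute origin timestamp implied by the relative age,
# measured against *our* clock, so it stays meaningful once persisted.
age_ts = received_time - age
assert age_ts == 1_700_000_030_000
```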
@@ -1574,10 +1574,15 @@ class FederationClient(FederationBase):
            min_depth=min_depth,
            timeout=timeout,
        )
        received_time = self._clock.time_msec()

        room_version = await self.store.get_room_version(room_id)

        events = parse_events_from_pdu_json(content.get("events", []), room_version)
        events = parse_events_from_pdu_json(
            content.get("events", []),
            room_version,
            received_time=received_time,
        )

        signed_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
            destination, events, room_version=room_version
@@ -451,16 +451,6 @@ class FederationServer(FederationBase):
        newest_pdu_ts = 0

        for p in transaction.pdus:
            # FIXME (richardv): I don't think this works:
            #  https://github.com/matrix-org/synapse/issues/8429
            if "unsigned" in p:
                unsigned = p["unsigned"]
                if "age" in unsigned:
                    p["age"] = unsigned["age"]
            if "age" in p:
                p["age_ts"] = request_time - int(p["age"])
                del p["age"]

            # We try and pull out an event ID so that if later checks fail we
            # can log something sensible. We don't mandate an event ID here in
            # case future event formats get rid of the key.
@@ -488,10 +478,15 @@ class FederationServer(FederationBase):
                continue

            try:
                event = event_from_pdu_json(p, room_version)
                event = event_from_pdu_json(p, room_version, received_time=request_time)
            except SynapseError as e:
                logger.info("Ignoring PDU for failing to deserialize: %s", e)
                continue
            except Exception as e:
                # We catch all exceptions here as we don't want a single bad
                # event to cause us to fail the whole transaction.
                logger.exception("Error deserializing PDU: %s", e)
                continue

            pdus_by_room.setdefault(room_id, []).append(event)
@@ -2089,10 +2089,9 @@ class EventCreationHandler:
                returned_invite = await federation_handler.send_invite(
                    invitee.domain, event
                )
                event.unsigned.pop("room_state", None)

                # TODO: Make sure the signatures actually are correct.
                event.signatures.update(returned_invite.signatures)
                event.signatures.update(returned_invite.signatures.as_dict())

            if event.content["membership"] == Membership.KNOCK:
                maybe_upsert_event_field(
@@ -566,7 +566,7 @@ class PaginationHandler:
                (
                    events,
                    next_key,
                    _,
                    limited,
                ) = await self.store.paginate_room_events_by_topological_ordering(
                    room_id=room_id,
                    from_key=from_token.room_key,
@@ -645,7 +645,7 @@ class PaginationHandler:
            (
                events,
                next_key,
                _,
                limited,
            ) = await self.store.paginate_room_events_by_topological_ordering(
                room_id=room_id,
                from_key=from_token.room_key,
@@ -668,11 +668,12 @@ class PaginationHandler:

        next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)

        # if no events are returned from pagination, that implies
        # we have reached the end of the available events.
        # if no events are returned from pagination (this page is empty)
        # and there aren't any more pages (not limited),
        # that implies we have reached the end of the available events.
        # In that case we do not return end, to tell the client
        # there is no need for further queries.
        if not events:
        if not limited and not events:
            return GetMessagesResult(
                messages_chunk=[],
                bundled_aggregations={},
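A self-contained sketch of why the extra `limited` check above matters (function and token names are hypothetical; the real logic lives in `PaginationHandler`): rejected events are filtered out *after* the limit check, so a page can come back empty while more rows still exist.

```python
def next_batch_token(page: list, limited: bool) -> str | None:
    """Return None only when the stream is truly exhausted."""
    # `page` has already had rejected events filtered out, *after* the
    # limit check, so it can be empty while more rows still exist.
    if not limited and not page:
        return None      # genuine end of stream: omit `end`, clients stop
    return "next_token"  # more pages may exist: keep paginating

assert next_batch_token([], limited=True) == "next_token"  # page of rejected events
assert next_batch_token([], limited=False) is None         # actual end of stream
```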
@@ -181,9 +181,10 @@ class RoomPolicyHandler:
    async def _verify_policy_server_signature(
        self, event: EventBase, policy_server: str, public_key: str
    ) -> bool:
        # check the event is signed with this (via, public_key).
        verify_json_req = VerifyJsonRequest.from_event(policy_server, event, 0)
        try:
            # check the event is signed with this (via, public_key).
            verify_json_req = VerifyJsonRequest.from_event(policy_server, event, 0)

            key_bytes = decode_base64(public_key)
            verify_key = decode_verify_key_bytes(POLICY_SERVER_KEY_ID, key_bytes)
            # We would normally use KeyRing.verify_event_for_server but we can't here as we don't
@@ -260,9 +261,7 @@ class RoomPolicyHandler:
            # servers need to manually fetch signatures for. This is the code that allows
            # those events to continue working (because they're legally sent, even if missing
            # the policy server signature).
            event.signatures.setdefault(policy_server.server_name, {}).update(
                signature.get(policy_server.server_name, {})
            )
            event.signatures.update(signature)
        except HttpResponseException as ex:
            # re-wrap HTTP errors as `SynapseError` so they can be proxied to clients directly
            raise ex.to_synapse_error() from ex
@@ -54,6 +54,20 @@ logger = logging.getLogger(__name__)
# will not disappear under our feet as long as we don't delete the room.
NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"

WORKER_LOCK_MAX_RETRY_INTERVAL = Duration(seconds=5)
"""
The maximum wait time before retrying to acquire the lock.

Better to retry more quickly than have workers wait around. 5 seconds is still a
reasonable gap in time to not overwhelm the CPU/database.

This matters most in cross-worker scenarios. When locks are on the same worker and
the lock holder releases, we signal to other locks (with the same name/key) that they
should try reacquiring the lock immediately. But locks on other workers only re-check
based on their retry `_timeout_interval`.
"""
WORKER_LOCK_EXCESSIVE_WAITING_WARN_DURATION = Duration(minutes=10)


class WorkerLocksHandler:
    """A class for waiting on taking out locks, rather than using the storage
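A standalone sketch of the retry behaviour the docstring above describes (constants inlined for illustration; the real implementation is `_increment_timeout_interval` further down this diff):

```python
import random

MAX_TIMEOUT_INTERVAL_SECS = 5.0  # mirrors WORKER_LOCK_MAX_RETRY_INTERVAL

def next_timeout_interval(current: float) -> float:
    """Double the wake-up interval, cap it, and apply +/-10% jitter so
    contending workers don't all re-check the lock at the same instant."""
    return min(MAX_TIMEOUT_INTERVAL_SECS, current * 2) * random.uniform(0.9, 1.1)

interval = 0.1
for _ in range(6):
    interval = next_timeout_interval(interval)
    # Roughly: 0.2, 0.4, 0.8, 1.6, 3.2, ~5.0 (jittered), then it stays ~5s.
```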
@@ -206,9 +220,10 @@ class WaitingLock:
    lock_name: str
    lock_key: str
    write: bool | None
    start_ts_ms: int = 0
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
    _inner_lock: Lock | None = None
    _retry_interval: float = 0.1
    _timeout_interval: float = 0.1
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )
@@ -220,6 +235,7 @@ class WaitingLock:
        self.deferred.callback(None)

    async def __aenter__(self) -> None:
        self.start_ts_ms = self.clock.time_msec()
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
@@ -240,19 +256,44 @@ class WaitingLock:
                    break

                try:
                    # Wait until the we get notified the lock might have been
                    # Wait until the notification that the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # periodically wake up in case the lock was released, but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        timeout = self._get_next_retry_interval()
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=timeout,
                            timeout=self._timeout_interval,
                            clock=self.clock,
                        )
                except Exception:
                    pass
                except defer.TimeoutError:
                    # Only increment the timeout value if this was an actual timeout
                    # (defer.TimeoutError)
                    self._increment_timeout_interval()

                    now_ms = self.clock.time_msec()
                    time_spent_trying_to_lock = Duration(
                        milliseconds=now_ms - self.start_ts_ms
                    )
                    if (
                        time_spent_trying_to_lock.as_millis()
                        > WORKER_LOCK_EXCESSIVE_WAITING_WARN_DURATION.as_millis()
                    ):
                        logger.warning(
                            "(WaitingLock (%s, %s)) Time spent waiting to acquire lock "
                            "is getting excessive: %ss. There may be a deadlock.",
                            self.lock_name,
                            self.lock_key,
                            time_spent_trying_to_lock.as_secs(),
                        )

                except Exception as e:
                    logger.warning(
                        "Caught an exception while waiting on WaitingLock(lock_name=%s, lock_key=%s): %r",
                        self.lock_name,
                        self.lock_key,
                        e,
                    )

        return await self._inner_lock.__aenter__()

@@ -273,15 +314,14 @@ class WaitingLock:

        return r

    def _get_next_retry_interval(self) -> float:
        next = self._retry_interval
        self._retry_interval = max(5, next * 2)
        if self._retry_interval > Duration(minutes=10).as_secs():  # >7 iterations
            logger.warning(
                "Lock timeout is getting excessive: %ss. There may be a deadlock.",
                self._retry_interval,
            )
        return next * random.uniform(0.9, 1.1)
    def _increment_timeout_interval(self) -> float:
        next_interval = self._timeout_interval
        next_interval = min(WORKER_LOCK_MAX_RETRY_INTERVAL.as_secs(), next_interval * 2)

        # The jitter value is maintained for the timeout, to help avoid a "thundering
        # herd" situation when all locks may time out at the same time.
        self._timeout_interval = next_interval * random.uniform(0.9, 1.1)
        return self._timeout_interval


@attr.s(auto_attribs=True, eq=False)
@@ -294,10 +334,11 @@ class WaitingMultiLock:
    store: LockStore
    handler: WorkerLocksHandler

    start_ts_ms: int = 0
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)

    _inner_lock_cm: AsyncContextManager | None = None
    _retry_interval: float = 0.1
    _timeout_interval: float = 0.1
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )
@@ -309,6 +350,7 @@ class WaitingMultiLock:
        self.deferred.callback(None)

    async def __aenter__(self) -> None:
        self.start_ts_ms = self.clock.time_msec()
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
@@ -324,19 +366,42 @@ class WaitingMultiLock:
                    break

                try:
                    # Wait until the we get notified the lock might have been
                    # Wait until the notification that the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # periodically wake up in case the lock was released, but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        timeout = self._get_next_retry_interval()
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=timeout,
                            timeout=self._timeout_interval,
                            clock=self.clock,
                        )
                except Exception:
                    pass
                except defer.TimeoutError:
                    # Only increment the timeout value if this was an actual timeout
                    # (defer.TimeoutError)
                    self._increment_timeout_interval()

                    now_ms = self.clock.time_msec()
                    time_spent_trying_to_lock = Duration(
                        milliseconds=now_ms - self.start_ts_ms
                    )
                    if (
                        time_spent_trying_to_lock.as_millis()
                        > WORKER_LOCK_EXCESSIVE_WAITING_WARN_DURATION.as_millis()
                    ):
                        logger.warning(
                            "(WaitingMultiLock (%r)) Time spent waiting to acquire lock "
                            "is getting excessive: %ss. There may be a deadlock.",
                            self.lock_names,
                            time_spent_trying_to_lock.as_secs(),
                        )

                except Exception as e:
                    logger.warning(
                        "Caught an exception while waiting on WaitingMultiLock(lock_names=%r): %r",
                        self.lock_names,
                        e,
                    )

        assert self._inner_lock_cm
        await self._inner_lock_cm.__aenter__()
@@ -360,12 +425,11 @@ class WaitingMultiLock:

        return r

    def _get_next_retry_interval(self) -> float:
        next = self._retry_interval
        self._retry_interval = max(5, next * 2)
        if self._retry_interval > Duration(minutes=10).as_secs():  # >7 iterations
            logger.warning(
                "Lock timeout is getting excessive: %ss. There may be a deadlock.",
                self._retry_interval,
            )
        return next * random.uniform(0.9, 1.1)
    def _increment_timeout_interval(self) -> float:
        next_interval = self._timeout_interval
        next_interval = min(WORKER_LOCK_MAX_RETRY_INTERVAL.as_secs(), next_interval * 2)

        # The jitter value is maintained for the timeout, to help avoid a "thundering
        # herd" situation when all locks may time out at the same time.
        self._timeout_interval = next_interval * random.uniform(0.9, 1.1)
        return self._timeout_interval
@@ -2711,7 +2711,7 @@ class PersistEventsStore:
            return

        def event_dict(event: EventBase) -> JsonDict:
            d = event.get_dict()
            d = event.get_dict_for_persistence()
            d.pop("redacted", None)
            d.pop("redacted_because", None)
            return d
@@ -2824,8 +2824,8 @@ class EventsBackgroundUpdatesStore(
            # with the provided old key.
            if old_verify_key is not None:
                old_key_id = f"{old_verify_key.alg}:{old_verify_key.version}"
                server_sigs = event.signatures.get(self.hs.hostname, {})
                if old_key_id not in server_sigs:
                old_sig = event.signatures.get_signature(self.hs.hostname, old_key_id)
                if old_sig is None:
                    # Event wasn't signed with this key ID at all, skip.
                    continue
@@ -779,14 +779,26 @@ class EventsWorkerStore(SQLBaseStore):
            events.append(event)

            if get_prev_content:
                if "replaces_state" in event.unsigned:
                # The `event` here might be in the cache, and so might have
                # already had the `prev_content` and `prev_sender` fields added
                # to its unsigned.
                #
                # We check if a) we should add the previous content, and b) if
                # we have already added it.
                replaces_state = "replaces_state" in event.unsigned
                has_prev = (
                    "prev_content" in event.unsigned and "prev_sender" in event.unsigned
                )
                if replaces_state and not has_prev:
                    prev = await self.get_event(
                        event.unsigned["replaces_state"],
                        get_prev_content=False,
                        allow_none=True,
                    )
                    if prev:
                        event.unsigned = dict(event.unsigned)
                        # This mutates the cached event, but that's fine as the
                        # previous content/sender will be the same for all
                        # requests for this event.
                        event.unsigned["prev_content"] = prev.content
                        event.unsigned["prev_sender"] = prev.sender
@@ -2425,12 +2425,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
            event_filter: If provided filters the events to those that match the filter.

        Returns:
            The results as a list of events, a token that points to the end of
            the result set, and a boolean to indicate if there were more events
            but we hit the limit. If no events are returned then the end of the
            - The results as a list of events;
            - a token that points to the end of the result set; and
            - a boolean to indicate if there were more events
              but we hit the limit (`limited`)

            If no events are returned and `limited` is false, then the end of the
            stream has been reached (i.e. there are no events between `from_key`
            and `to_key`).

            When `limited` is true, that means that more pagination can be attempted.
            Note that `limited` can be true even if no events are returned,
            because rejected events are filtered out after the limit check.

            When Direction.FORWARDS: from_key < x <= to_key, (ascending order)
            When Direction.BACKWARDS: from_key >= x > to_key, (descending order)
        """
@@ -10,9 +10,9 @@
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
from typing import Mapping
|
||||
from typing import Any, Mapping
|
||||
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
|
||||
class EventInternalMetadata:
|
||||
def __init__(self, internal_metadata_dict: JsonDict): ...
|
||||
@@ -154,3 +154,62 @@ def event_visible_to_server(
|
||||
Returns:
|
||||
Whether the server is allowed to see the unredacted event.
|
||||
"""
|
||||
|
||||
class Signatures:
|
||||
"""A class representing the signatures on an event."""
|
||||
|
||||
def __init__(self, signatures: Mapping[str, Mapping[str, str]] | None = None): ...
|
||||
def get_signature(self, server_name: str, key_id: str) -> str | None: ...
|
||||
"""Get the signature for the given server name and key ID, if it exists."""
|
||||
|
||||
def __getitem__(self, server_name: str) -> Mapping[str, str]: ...
|
||||
"""Get the signatures for the given server name. Raises KeyError if there
|
||||
are no signatures for that server."""
|
||||
|
||||
def __contains__(self, server_name: Any) -> bool: ...
|
||||
"""Check if there are signatures for the given server name."""
|
||||
|
||||
def __len__(self) -> int: ...
|
||||
"""Return the number of servers that have signatures."""
|
||||
|
||||
def add_signature(self, server_name: str, key_id: str, signature: str) -> None: ...
|
||||
"""Add a signature for the given server name and key ID."""
|
||||
|
||||
def update(self, signatures: Mapping[str, Mapping[str, str]]) -> None: ...
|
||||
"""Update the signatures with the given signatures.
|
||||
|
||||
Will overwrite all existing signatures for the server names provided.
|
||||
"""
|
||||
|
||||
def as_dict(self) -> dict[str, dict[str, str]]: ...
|
||||
"""Return a copy of the signatures as a dictionary."""
|
||||
|
||||
class Unsigned:
    """A class representing the unsigned data of an event."""

    def __init__(self, unsigned_dict: JsonMapping): ...
    def __getitem__(self, key: str) -> Any: ...
    """Get the value for the given key.

    Raises KeyError if the key is unset or not recognised."""

    def __setitem__(self, key: str, value: Any) -> None: ...
    """Set the value for the given key.

    Raises KeyError if the key is not recognised."""

    def __delitem__(self, key: str) -> None: ...
    """Delete the value for the given key.

    Raises KeyError if the key is unset or not recognised."""

    def __contains__(self, key: Any) -> bool: ...
    def get(self, key: str, default: Any = None) -> Any: ...
    """Get the value for the given key, or ``default`` if the key is unset."""

    def for_persistence(self) -> JsonDict: ...
    """Return a dict of the fields that should be persisted to the database."""

    def for_event(self) -> JsonDict: ...
    """Return a dict of all unsigned fields, including those only kept in
    memory, suitable for inclusion in an event."""
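A usage sketch of `Unsigned`, based on the stub above and the tests later in this commit: unlike a plain dict, it accepts only recognised keys (the tests use `age_ts` and `replaces_state`) and raises `KeyError` otherwise:

```python
unsigned = Unsigned({"age_ts": 1})

unsigned["replaces_state"] = "$some_event:example.org"  # recognised key: accepted
assert unsigned.get("age_ts") == 1

try:
    unsigned["key"] = "value"  # arbitrary keys are rejected
except KeyError:
    pass

as_event = unsigned.for_event()        # all fields, including in-memory-only ones
to_store = unsigned.for_persistence()  # only the fields persisted to the database
```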
@@ -32,6 +32,33 @@ class Duration(timedelta):
    ```
    """

    # Using `__new__` (instead of `__init__`) because that's what `timedelta` uses
    def __new__(
        cls,
        # The whole goal of overriding `__new__` is to require keyword-only arguments.
        # Without this, `Duration(5)` would create a duration representing 5 *days*
        # (timedelta's default), but callers almost certainly want to specify which
        # unit they mean, like seconds or hours.
        *,
        days: float = 0,
        seconds: float = 0,
        microseconds: float = 0,
        milliseconds: float = 0,
        minutes: float = 0,
        hours: float = 0,
        weeks: float = 0,
    ) -> "Duration":
        return super().__new__(
            cls,
            days=days,
            seconds=seconds,
            microseconds=microseconds,
            milliseconds=milliseconds,
            minutes=minutes,
            hours=hours,
            weeks=weeks,
        )

    def as_millis(self) -> int:
        """Returns the duration in milliseconds."""
        return int(self / _ONE_MILLISECOND)
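A short illustration of the keyword-only constraint, assuming the `Duration` class exactly as shown above (`as_secs` is assumed to mirror `as_millis`, as used by the tests below):

```python
d = Duration(seconds=5)
assert d.as_millis() == 5_000
assert Duration(hours=1).as_secs() == 3_600

try:
    Duration(5)  # positional args are rejected: which unit did you mean?
except TypeError:
    pass
```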
+21
-11
@@ -67,25 +67,31 @@ def MockEvent(**kwargs: Any) -> EventBase:
class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
    def test_update_okay(self) -> None:
        event = make_event_from_dict({"event_id": "$1234"})
        success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
        success = maybe_upsert_event_field(
            event, event.unsigned, "replaces_state", "value"
        )
        self.assertTrue(success)
        self.assertEqual(event.unsigned["key"], "value")
        self.assertEqual(event.unsigned["replaces_state"], "value")

    def test_update_not_okay(self) -> None:
        event = make_event_from_dict({"event_id": "$1234"})
        LARGE_STRING = "a" * 100_000
        success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
        success = maybe_upsert_event_field(
            event, event.unsigned, "replaces_state", LARGE_STRING
        )
        self.assertFalse(success)
        self.assertNotIn("key", event.unsigned)
        self.assertNotIn("replaces_state", event.unsigned)

    def test_update_not_okay_leaves_original_value(self) -> None:
        event = make_event_from_dict(
            {"event_id": "$1234", "unsigned": {"key": "value"}}
            {"event_id": "$1234", "unsigned": {"replaces_state": "value"}}
        )
        LARGE_STRING = "a" * 100_000
        success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
        success = maybe_upsert_event_field(
            event, event.unsigned, "replaces_state", LARGE_STRING
        )
        self.assertFalse(success)
        self.assertEqual(event.unsigned["key"], "value")
        self.assertEqual(event.unsigned["replaces_state"], "value")


class PruneEventTestCase(stdlib_unittest.TestCase):
@@ -623,7 +629,7 @@ class CloneEventTestCase(stdlib_unittest.TestCase):
            {
                "type": "A",
                "event_id": "$test:domain",
                "unsigned": {"a": 1, "b": 2},
                "unsigned": {"age_ts": 1, "replaces_state": "2"},
            },
            RoomVersions.V1,
            {"txn_id": "txn"},
@@ -634,10 +640,14 @@ class CloneEventTestCase(stdlib_unittest.TestCase):
        self.assertEqual(original.internal_metadata.instance_name, "worker1")

        cloned = clone_event(original)
        cloned.unsigned["b"] = 3
        cloned.unsigned["age_ts"] = 3

        self.assertEqual(original.unsigned, {"a": 1, "b": 2})
        self.assertEqual(cloned.unsigned, {"a": 1, "b": 3})
        self.assertEqual(
            original.unsigned.for_event(), {"age_ts": 1, "replaces_state": "2"}
        )
        self.assertEqual(
            cloned.unsigned.for_event(), {"age_ts": 3, "replaces_state": "2"}
        )
        self.assertEqual(cloned.internal_metadata.stream_ordering, 1234)
        self.assertEqual(cloned.internal_metadata.instance_name, "worker1")
        self.assertEqual(cloned.internal_metadata.txn_id, "txn")
@@ -754,9 +754,9 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
            },
        }

        filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
        self.assertIn("age", filtered_event2.unsigned)
        self.assertEqual(14, filtered_event2.unsigned["age"])
        filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1, received_time=20)
        self.assertIn("age_ts", filtered_event2.unsigned)
        self.assertEqual(6, filtered_event2.unsigned["age_ts"])
        self.assertNotIn("more warez", filtered_event2.unsigned)
        # Invite_room_state is allowed in events of type m.room.member
        self.assertIn("invite_room_state", filtered_event2.unsigned)
@@ -779,8 +779,8 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
                "invite_room_state": [],
            },
        }
        filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
        self.assertIn("age", filtered_event3.unsigned)
        filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1, received_time=20)
        self.assertIn("age_ts", filtered_event3.unsigned)
        # Invite_room_state field is only permitted in event type m.room.member
        self.assertNotIn("invite_room_state", filtered_event3.unsigned)
        self.assertNotIn("more warez", filtered_event3.unsigned)
@@ -368,7 +368,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
        )
        # the auth code requires that a signature exists, but doesn't check that
        # signature... go figure.
        join_event.signatures[other_server] = {"x": "y"}
        join_event.signatures.update({other_server: {"x": "y"}})

        self.get_success(
            self.hs.get_federation_event_handler().on_send_membership_event(
@@ -26,7 +26,9 @@ from twisted.internet import defer
from twisted.internet.testing import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _RENEWAL_INTERVAL
from synapse.util.clock import Clock
from synapse.util.duration import Duration

from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -40,6 +42,7 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
    ) -> None:
        self.worker_lock_handler = self.hs.get_worker_locks_handler()
        self.store = self.hs.get_datastores().main

    def test_wait_for_lock_locally(self) -> None:
        """Test waiting for a lock on a single worker"""
@@ -56,6 +59,66 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
        self.get_success(d2)
        self.get_success(lock2.__aexit__(None, None, None))
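For context, `acquire_lock` returns an async context manager; the tests below drive `__aenter__`/`__aexit__` by hand so they can observe the pending deferred, but ordinary calling code would look more like this sketch:

```python
async def do_exclusive_work(worker_lock_handler) -> None:
    # Waits (retrying with a capped back-off) until the lock is free,
    # then holds it for the duration of the block.
    async with worker_lock_handler.acquire_lock("name", "key"):
        ...  # critical section
```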
    def test_timeouts_for_lock_locally(self) -> None:
        """
        Test that we regularly retry to reacquire locks.

        This is a regression test to make sure the lock retry time doesn't balloon
        to a value so large it can't even be printed reliably anymore.
        """

        # Create and acquire the first lock
        lock1 = self.worker_lock_handler.acquire_lock("name", "key")
        self.get_success(lock1.__aenter__())

        # Create and try to acquire the second lock
        lock2 = self.worker_lock_handler.acquire_lock("name", "key")
        d2 = defer.ensureDeferred(lock2.__aenter__())
        # Make sure we haven't acquired the lock yet (`lock1` still holds it)
        self.assertNoResult(d2)

        # Advance time by an hour (some duration that would previously cause our
        # timeout to balloon if it weren't constrained), enough for the back-off to
        # saturate at its maximum.
        #
        # Note: We use `_pump_by` instead of `pump`/`advance` as the `Lock` has an
        # internal background looping call that runs every 30 seconds
        # (`_RENEWAL_INTERVAL`) to renew the `Lock` and push its "drop timeout" value
        # further out by 2 minutes (`_LOCK_TIMEOUT_MS`). The `Lock` would prematurely
        # drop if this renewal were not allowed to run, which would spoil the test.
        # self.pump(amount=Duration(hours=1))
        self._pump_by(amount=Duration(hours=1), by=_RENEWAL_INTERVAL)

        # Make sure we haven't acquired `lock2` yet (`lock1` still holds it)
        self.assertNoResult(d2)

        # Release the first lock (`lock1`). The second lock (`lock2`) should be
        # automatically acquired by the `pump()` inside `get_success()`
        self.get_success(lock1.__aexit__(None, None, None))

        # We should now have the lock
        self.successResultOf(d2)

    def _pump_by(
        self,
        *,
        amount: Duration = Duration(seconds=0),
        by: Duration = Duration(seconds=0.1),
    ) -> None:
        """
        Like `self.pump()`, but you can specify the time increment to advance by
        until you reach the given amount of time.

        Unlike `self.pump()`, this doesn't multiply the time at all.

        Args:
            amount: The total amount of time to advance
            by: The time increment to advance by until we reach `amount`
        """
        end_time_s = self.reactor.seconds() + amount.as_secs()

        while self.reactor.seconds() < end_time_s:
            self.reactor.advance(by.as_secs())
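The behaviour these regression tests guard is a retry interval that backs off but saturates instead of growing without bound. A minimal sketch of the idea, not Synapse's actual implementation (the 60-second cap matches the changelog entry; the doubling is illustrative):

```python
MAX_RETRY_INTERVAL = Duration(seconds=60)  # cap from the changelog entry

def next_retry_interval(current: Duration) -> Duration:
    """Grow the retry interval, saturating at the cap."""
    doubled = Duration(seconds=current.as_secs() * 2)
    # Duration subclasses timedelta, so instances compare directly.
    return min(doubled, MAX_RETRY_INTERVAL)
```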
    def test_lock_contention(self) -> None:
        """Test lock contention when a lot of locks wait on a single worker"""
        nb_locks_to_test = 500
@@ -124,3 +187,70 @@ class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
        self.get_success(d2)
        self.get_success(lock2.__aexit__(None, None, None))

    def test_timeouts_for_lock_worker(self) -> None:
        """
        Test that we regularly retry to reacquire locks.

        This is a regression test to make sure the lock retry time doesn't balloon
        to a value so large it can't even be printed reliably anymore.
        """
        worker = self.make_worker_hs(
            "synapse.app.generic_worker",
            extra_config={
                "redis": {"enabled": True},
            },
        )
        worker_lock_handler = worker.get_worker_locks_handler()

        # Create and acquire the first lock on the main process
        lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
        self.get_success(lock1.__aenter__())

        # Create and try to acquire the second lock on the worker
        lock2 = worker_lock_handler.acquire_lock("name", "key")
        d2 = defer.ensureDeferred(lock2.__aenter__())
        # Make sure we haven't acquired the lock yet (`lock1` still holds it)
        self.assertNoResult(d2)

        # Advance time by an hour (some duration that would previously cause our
        # timeout to balloon if it weren't constrained), enough for the back-off to
        # saturate at its maximum.
        #
        # Note: We use `_pump_by` instead of `pump`/`advance` as the `Lock` has an
        # internal background looping call that runs every 30 seconds
        # (`_RENEWAL_INTERVAL`) to renew the `Lock` and push its "drop timeout" value
        # further out by 2 minutes (`_LOCK_TIMEOUT_MS`). The `Lock` would prematurely
        # drop if this renewal were not allowed to run, which would spoil the test.
        # self.pump(amount=Duration(hours=1))
        self._pump_by(amount=Duration(hours=1), by=_RENEWAL_INTERVAL)

        # Make sure we haven't acquired `lock2` yet (`lock1` still holds it)
        self.assertNoResult(d2)

        # Release the first lock (`lock1`). The second lock (`lock2`) should be
        # automatically acquired by the `pump()` inside `get_success()`
        self.get_success(lock1.__aexit__(None, None, None))

        # We should now have the lock
        self.successResultOf(d2)

    def _pump_by(
        self,
        *,
        amount: Duration = Duration(seconds=0),
        by: Duration = Duration(seconds=0.1),
    ) -> None:
        """
        Like `self.pump()`, but you can specify the time increment to advance by
        until you reach the given amount of time.

        Unlike `self.pump()`, this doesn't multiply the time at all.

        Args:
            amount: The total amount of time to advance
            by: The time increment to advance by until we reach `amount`
        """
        end_time_s = self.reactor.seconds() + amount.as_secs()

        while self.reactor.seconds() < end_time_s:
            self.reactor.advance(by.as_secs())
@@ -66,7 +66,10 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
from tests.storage.test_stream import PaginationTestCase
from tests.test_utils.event_injection import create_event
from tests.test_utils.event_injection import (
    create_event,
    inject_event,
)
from tests.unittest import override_config
from tests.utils import default_config

@@ -2371,6 +2374,87 @@ class RoomMessageListTestCase(RoomBase):
            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
        )
    def test_room_messages_paginate_through_rejected_events(
        self,
    ) -> None:
        """Test that pagination continues past a batch of rejected events.

        Regression test for https://github.com/element-hq/synapse/security/advisories/GHSA-6qf2-7x63-mm6v

        Synapse before 1.152.1 had a bug meaning that a batch full of only
        rejected events would cause `/messages` to stop returning
        pagination tokens, falsely signalling the end of backpagination.
        """
        # Send an early message that should not be filtered.
        early_event_id = self.helper.send(self.room_id, "early message")["event_id"]

        # Inject a batch of events and mark them as rejected in the database.
        # We create more events than a single pagination request would fetch,
        # so that one page of the backward pagination contains only rejected events.
        for _ in range(3):
            event = self.get_success(
                inject_event(
                    self.hs,
                    room_id=self.room_id,
                    sender=self.user_id,
                    type=EventTypes.Message,
                    content={"body": "filtered event", "msgtype": "m.text"},
                )
            )
            self.get_success(
                self.hs.get_datastores().main.db_pool.runInteraction(
                    "mark_rejected",
                    self.hs.get_datastores().main.mark_event_rejected_txn,
                    event.event_id,
                    "testing",
                )
            )

        # Send a message after all the rejected events.
        latest_event_id = self.helper.send(self.room_id, "latest message")["event_id"]

        # Start backpaginating.
        channel = self.make_request(
            "GET", f"/rooms/{self.room_id}/messages?dir=b&limit=2"
        )
        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

        events_in_page = [e["event_id"] for e in channel.json_body["chunk"]]
        end_token: str | None = channel.json_body["end"]

        self.assertEqual(
            events_in_page,
            [latest_event_id],
            "The latest event should be included in the first page we see whilst backpaginating",
        )

        event_ids_in_pages: list[list[str]] = [events_in_page]

        # Bound the number of backpagination attempts to 2
        for _ in range(2):
            channel = self.make_request(
                "GET", f"/rooms/{self.room_id}/messages?from={end_token}&dir=b&limit=2"
            )
            self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
            events_in_page = [e["event_id"] for e in channel.json_body["chunk"]]
            event_ids_in_pages.append(events_in_page)

            if early_event_id in events_in_page:
                # We have found the event we were looking for
                return

            self.assertIn(
                "end",
                channel.json_body,
                f"No `end` token received. Did not find {early_event_id} whilst backpaginating ({latest_event_id = }, {event_ids_in_pages = })",
            )
            # Use the end_token in the next iteration
            end_token = channel.json_body["end"]

        self.fail(
            f"Exhausted backpagination attempts. Did not find {early_event_id} whilst backpaginating ({latest_event_id = }, {event_ids_in_pages = })"
        )
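For comparison, the client-side contract the test relies on: `/messages` keeps returning an `end` token for as long as there may be more history, and omits it only at the true end of the room. A hypothetical client loop honouring that contract (endpoint and fields per the Matrix client-server API; `requests` used for brevity):

```python
import requests  # any HTTP client works; used here for brevity


def backpaginate(base_url: str, room_id: str, access_token: str) -> list[dict]:
    """Fetch the full visible history of a room via backward pagination."""
    events: list[dict] = []
    params: dict[str, str] = {"dir": "b", "limit": "50"}
    while True:
        resp = requests.get(
            f"{base_url}/_matrix/client/v3/rooms/{room_id}/messages",
            params=params,
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=30,
        )
        body = resp.json()
        events.extend(body.get("chunk", []))
        end = body.get("end")
        if end is None:
            return events  # no `end` token: backpagination is complete
        # An empty chunk with an `end` token means: keep going.
        params["from"] = end
```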
class RoomMessageFilterTestCase(RoomBase):
    """Tests /rooms/$room_id/messages REST events."""