diff --git a/changelog.d/19706.misc b/changelog.d/19706.misc new file mode 100644 index 0000000000..205abd09d4 --- /dev/null +++ b/changelog.d/19706.misc @@ -0,0 +1 @@ +Port `Event.signatures` field to Rust. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index e6b378a092..5bdd194707 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -43,7 +43,7 @@ pyo3-log = "0.13.1" pythonize = "0.27.0" regex = "1.6.0" sha2 = "0.10.8" -serde = { version = "1.0.144", features = ["derive"] } +serde = { version = "1.0.144", features = ["derive", "rc"] } serde_json = { version = "1.0.85", features = ["raw_value"] } ulid = "1.1.2" icu_segmenter = "2.0.0" @@ -58,10 +58,6 @@ tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] } once_cell = "1.18.0" itertools = "0.14.0" -[features] -extension-module = ["pyo3/extension-module"] -default = ["extension-module"] - [build-dependencies] blake2 = "0.10.4" hex = "0.4.3" diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index 209efb917b..e42eb97739 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -27,11 +27,13 @@ use pyo3::{ pub mod filter; mod internal_metadata; +pub mod signatures; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { let child_module = PyModule::new(py, "events")?; child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?; m.add_submodule(&child_module)?; diff --git a/rust/src/events/signatures.rs b/rust/src/events/signatures.rs new file mode 100644 index 0000000000..0f2acd5c9b --- /dev/null +++ b/rust/src/events/signatures.rs @@ -0,0 +1,348 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + * + */ + +//! Class for representing event signatures + +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use pyo3::{ + exceptions::{PyKeyError, PyRuntimeError}, + pyclass, pymethods, + types::{PyAnyMethods, PyDict, PyMapping, PyMappingMethods}, + Bound, IntoPyObject, PyAny, PyResult, Python, +}; +use serde::{Deserialize, Serialize}; + +/// A class representing the signatures on an event. +#[pyclass(frozen, skip_from_py_object)] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Signatures { + inner: Arc>>>, +} + +#[pymethods] +impl Signatures { + #[new] + #[pyo3(signature = (signatures = None))] + fn py_new(signatures: Option>>) -> Self { + let mut signatures = signatures.unwrap_or_default(); + + // Prune any entries that have no signatures. + signatures.retain(|_, server_sigs| !server_sigs.is_empty()); + + Self { + inner: Arc::new(RwLock::new(signatures)), + } + } + + /// Check if the signatures contain a signature for the given server name. + fn __contains__(&self, key: Bound<'_, PyAny>) -> PyResult { + let Ok(key) = key.extract::<&str>() else { + return Ok(false); + }; + + let signatures = self + .inner + .read() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + Ok(signatures.contains_key(key)) + } + + /// Get the number of servers that have signatures. + fn __len__(&self) -> PyResult { + let signatures = self + .inner + .read() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + Ok(signatures.len()) + } + + /// Get the signature for the given server name and key ID, if it exists. + fn get_signature(&self, server_name: &str, key_id: &str) -> PyResult> { + let signatures = self + .inner + .read() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + + Ok(signatures + .get(server_name) + .and_then(|server_sigs| server_sigs.get(key_id).cloned())) + } + + /// Get the signatures for the given server name. + fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult> { + let Some(server_name) = key.extract::<&str>().ok() else { + return Err(PyKeyError::new_err(key.to_string())); + }; + + let signatures = self + .inner + .read() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + + if let Some(server_sigs) = signatures.get(server_name) { + Ok(server_sigs.clone()) + } else { + Err(PyKeyError::new_err(server_name.to_string())) + } + } + + /// Add a signature for the given server name and key ID. + fn add_signature( + &self, + server_name: String, + key_id: String, + signature: String, + ) -> PyResult<()> { + let mut signatures = self + .inner + .write() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + + signatures + .entry(server_name) + .or_default() + .insert(key_id, signature); + + Ok(()) + } + + /// Update the signatures with the given signatures. + /// + /// Will overwrite all existing signatures for the server names provided. + fn update(&self, other: &Bound<'_, PyMapping>) -> PyResult<()> { + let mut signatures = self + .inner + .write() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + + for list_entry in other.items()? { + let (server_name, server_sigs) = list_entry.extract::<(String, Bound)>()?; + + let mut entry = HashMap::new(); + for list_entry in server_sigs.items()? { + let (key, value) = list_entry.extract::<(String, String)>()?; + entry.insert(key, value); + } + + // Only insert the entry if it has at least one signature. + if !entry.is_empty() { + signatures.insert(server_name, entry); + } else { + signatures.remove(&server_name); + } + } + + Ok(()) + } + + /// Return a copy of the signatures as a dictionary. + fn as_dict<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self + .inner + .read() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + + (&*signatures).into_pyobject(py) + } + + fn __repr__(&self) -> PyResult { + let signatures = self + .inner + .read() + .map_err(|_| PyRuntimeError::new_err("Failed to acquire lock"))?; + + Ok(format!("Signatures({signatures:?})")) + } +} + +#[cfg(test)] +mod tests { + use pythonize::pythonize; + + use super::*; + + /// Helper that reads the inner map directly. + fn read_inner(sigs: &Signatures) -> HashMap> { + sigs.inner.read().expect("lock poisoned").clone() + } + + /// Helper to create a server signatures map from a list of (key_id, sig) + /// pairs. + fn make_server_sigs(data: &[(&str, &str)]) -> HashMap { + let mut map = HashMap::new(); + for (key_id, sig) in data { + map.insert((*key_id).to_owned(), (*sig).to_owned()); + } + map + } + + /// Helper to create a `Signatures` object from a list of (server_name, + /// key_id, sig) tuples. + fn create_signatures(data: &[(&str, &str, &str)]) -> Signatures { + let mut map: HashMap> = HashMap::new(); + for (server_name, key_id, sig) in data { + map.entry((*server_name).to_owned()) + .or_default() + .insert((*key_id).to_owned(), (*sig).to_owned()); + } + Signatures::py_new(Some(map)) + } + + #[test] + fn test_new_empty() { + let sigs = Signatures::py_new(None); + assert!(read_inner(&sigs).is_empty()); + assert_eq!(sigs.__len__().unwrap(), 0); + } + + #[test] + fn test_new_with_data() { + let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]); + assert_eq!(sigs.__len__().unwrap(), 1); + assert_eq!( + sigs.get_signature("example.com", "ed25519:key1").unwrap(), + Some("sig1".to_string()) + ); + } + + #[test] + fn test_new_prunes_servers_with_no_signatures() { + let mut data = HashMap::new(); + data.insert("empty.example.com".to_string(), HashMap::new()); + data.insert( + "example.com".to_string(), + make_server_sigs(&[("ed25519:key1", "sig1")]), + ); + + let sigs = Signatures::py_new(Some(data)); + + let inner = read_inner(&sigs); + assert_eq!(inner.len(), 1); + assert!(inner.contains_key("example.com")); + assert!(!inner.contains_key("empty.example.com")); + } + + #[test] + fn test_add_signature() { + let sigs = Signatures::py_new(None); + sigs.add_signature( + "example.com".to_string(), + "ed25519:key1".to_string(), + "sig1".to_string(), + ) + .unwrap(); + + let inner = read_inner(&sigs); + assert_eq!(inner.len(), 1); + assert_eq!( + inner.get("example.com").and_then(|m| m.get("ed25519:key1")), + Some(&"sig1".to_string()) + ); + } + + #[test] + fn test_add_signature_to_existing_server() { + let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]); + sigs.add_signature( + "example.com".to_string(), + "ed25519:key2".to_string(), + "sig2".to_string(), + ) + .unwrap(); + + let inner = read_inner(&sigs); + assert_eq!(inner.len(), 1); + assert_eq!( + inner.get("example.com").and_then(|m| m.get("ed25519:key1")), + Some(&"sig1".to_string()) + ); + assert_eq!( + inner.get("example.com").and_then(|m| m.get("ed25519:key2")), + Some(&"sig2".to_string()) + ); + } + + #[test] + fn test_update_signatures_clobbers_existing() { + let sigs = create_signatures(&[("example.com", "ed25519:key1", "sig1")]); + + // Create a new signatures map with a different signature for the same + // server. + let mut other = HashMap::new(); + other.insert( + "example.com".to_string(), + make_server_sigs(&[("ed25519:key2", "sig2")]), + ); + + // Update the signatures with the new map. + Python::initialize(); + Python::attach(|py| { + let value = pythonize(py, &other).unwrap(); + let value = value.cast::().unwrap(); + + sigs.update(value).unwrap(); + }); + + // Check that the old signature has been replaced with the new one. + let inner = read_inner(&sigs); + assert_eq!(inner.len(), 1); + assert_eq!(inner["example.com"].len(), 1); + assert_eq!(inner["example.com"]["ed25519:key2"], "sig2"); + } + + #[test] + fn test_serialize() { + let mut data = HashMap::new(); + data.insert( + "example.com".to_string(), + make_server_sigs(&[("ed25519:key1", "sig1")]), + ); + let sigs = Signatures::py_new(Some(data)); + + let json = serde_json::to_string(&sigs).unwrap(); + assert_eq!(json, r#"{"example.com":{"ed25519:key1":"sig1"}}"#); + } + + #[test] + fn test_serialize_empty() { + let sigs = Signatures::py_new(None); + let json = serde_json::to_string(&sigs).unwrap(); + assert_eq!(json, "{}"); + } + + #[test] + fn test_deserialize() { + let json = r#"{"example.com":{"ed25519:key1":"sig1"}}"#; + let sigs: Signatures = serde_json::from_str(json).unwrap(); + + let inner = read_inner(&sigs); + assert_eq!(inner.len(), 1); + assert_eq!( + inner.get("example.com").and_then(|m| m.get("ed25519:key1")), + Some(&"sig1".to_string()) + ); + } + + #[test] + fn test_deserialize_empty() { + let sigs: Signatures = serde_json::from_str("{}").unwrap(); + assert!(read_inner(&sigs).is_empty()); + } +} diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index d789c06a9c..823b6288e8 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -236,9 +236,7 @@ def event_needs_resigning( if sender.domain != server_name: return False want_key_id = verify_key.alg + ":" + verify_key.version - signed_with_current_key_id = ev.signatures.get(server_name, {}).get( - want_key_id, None - ) + signed_with_current_key_id = ev.signatures.get_signature(server_name, want_key_id) if signed_with_current_key_id: return False diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 0d4d5e0e17..36736b4559 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -120,8 +120,18 @@ class VerifyJsonRequest: ) -> "VerifyJsonRequest": """Create a VerifyJsonRequest to verify all signatures on an event object for the given server. + + Raises immediately if the event doesn't have any signatures from the + given server. """ - key_ids = list(event.signatures.get(server_name, [])) + if server_name not in event.signatures: + raise SynapseError( + 400, + f"Not signed by {server_name}", + Codes.UNAUTHORIZED, + ) + + key_ids = list(event.signatures[server_name]) return VerifyJsonRequest( server_name, # We defer creating the redacted json object, as it uses a lot more diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ca528ae235..fd35da8ba0 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -128,7 +128,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: ) # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): + if sender_domain not in event.signatures: # We allow invites via 3pid to have a sender from a different # HS, as the sender must match the sender of the original # 3pid invite. This is checked further down with the @@ -141,7 +141,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: event_id_domain = get_domain_from_id(event.event_id) # Check the origin domain has signed the event - if not event.signatures.get(event_id_domain): + if event_id_domain not in event.signatures: raise AuthError(403, "Event not signed by sending server") is_invite_via_allow_rule = ( @@ -154,7 +154,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: authoriser_domain = get_domain_from_id( event.content[EventContentFields.AUTHORISING_USER] ) - if not event.signatures.get(authoriser_domain): + if authoriser_domain not in event.signatures: raise AuthError(403, "Event not signed by authorising server") diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index ac9b31cfcf..fc0f6aadbd 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -44,8 +44,12 @@ from synapse.api.constants import ( StickyEvent, ) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.synapse_rust.events import EventInternalMetadata -from synapse.types import JsonDict, StateKey, StrCollection +from synapse.synapse_rust.events import EventInternalMetadata, Signatures +from synapse.types import ( + JsonDict, + StateKey, + StrCollection, +) from synapse.util.caches import intern_dict from synapse.util.duration import Duration from synapse.util.frozenutils import freeze @@ -203,7 +207,7 @@ class EventBase(metaclass=abc.ABCMeta): assert room_version.event_format == self.format_version self.room_version = room_version - self.signatures = signatures + self.signatures = Signatures(signatures) self.unsigned = unsigned self.rejected_reason = rejected_reason @@ -255,7 +259,9 @@ class EventBase(metaclass=abc.ABCMeta): def get_dict(self) -> JsonDict: d = dict(self._dict) - d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) + d.update( + {"signatures": self.signatures.as_dict(), "unsigned": dict(self.unsigned)} + ) return d diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4032c7eca9..ed3bce69ab 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -2092,7 +2092,7 @@ class EventCreationHandler: event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. - event.signatures.update(returned_invite.signatures) + event.signatures.update(returned_invite.signatures.as_dict()) if event.content["membership"] == Membership.KNOCK: maybe_upsert_event_field( diff --git a/synapse/handlers/room_policy.py b/synapse/handlers/room_policy.py index 01943e1991..e46e6dc2ef 100644 --- a/synapse/handlers/room_policy.py +++ b/synapse/handlers/room_policy.py @@ -181,9 +181,10 @@ class RoomPolicyHandler: async def _verify_policy_server_signature( self, event: EventBase, policy_server: str, public_key: str ) -> bool: - # check the event is signed with this (via, public_key). - verify_json_req = VerifyJsonRequest.from_event(policy_server, event, 0) try: + # check the event is signed with this (via, public_key). + verify_json_req = VerifyJsonRequest.from_event(policy_server, event, 0) + key_bytes = decode_base64(public_key) verify_key = decode_verify_key_bytes(POLICY_SERVER_KEY_ID, key_bytes) # We would normally use KeyRing.verify_event_for_server but we can't here as we don't @@ -260,9 +261,7 @@ class RoomPolicyHandler: # servers need to manually fetch signatures for. This is the code that allows # those events to continue working (because they're legally sent, even if missing # the policy server signature). - event.signatures.setdefault(policy_server.server_name, {}).update( - signature.get(policy_server.server_name, {}) - ) + event.signatures.update(signature) except HttpResponseException as ex: # re-wrap HTTP errors as `SynapseError` so they can be proxied to clients directly raise ex.to_synapse_error() from ex diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index d2623f0760..c0d218398d 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -2824,8 +2824,8 @@ class EventsBackgroundUpdatesStore( # with the provided old key. if old_verify_key is not None: old_key_id = f"{old_verify_key.alg}:{old_verify_key.version}" - server_sigs = event.signatures.get(self.hs.hostname, {}) - if old_key_id not in server_sigs: + old_sig = event.signatures.get_signature(self.hs.hostname, old_key_id) + if old_sig is None: # Event wasn't signed with this key ID at all, skip. continue diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index fe0ca04420..40cf3f59f2 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -10,7 +10,7 @@ # See the GNU Affero General Public License for more details: # . -from typing import Mapping +from typing import Any, Mapping from synapse.types import JsonDict @@ -154,3 +154,32 @@ def event_visible_to_server( Returns: Whether the server is allowed to see the unredacted event. """ + +class Signatures: + """A class representing the signatures on an event.""" + + def __init__(self, signatures: Mapping[str, Mapping[str, str]] | None = None): ... + def get_signature(self, server_name: str, key_id: str) -> str | None: ... + """Get the signature for the given server name and key ID, if it exists.""" + + def __getitem__(self, server_name: str) -> Mapping[str, str]: ... + """Get the signatures for the given server name. Raises KeyError if there + are no signatures for that server.""" + + def __contains__(self, server_name: Any) -> bool: ... + """Check if there are signatures for the given server name.""" + + def __len__(self) -> int: ... + """Return the number of servers that have signatures.""" + + def add_signature(self, server_name: str, key_id: str, signature: str) -> None: ... + """Add a signature for the given server name and key ID.""" + + def update(self, signatures: Mapping[str, Mapping[str, str]]) -> None: ... + """Update the signatures with the given signatures. + + Will overwrite all existing signatures for the server names provided. + """ + + def as_dict(self) -> dict[str, dict[str, str]]: ... + """Return a copy of the signatures as a dictionary.""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index dde1785854..20ffed68f4 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -368,7 +368,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) # the auth code requires that a signature exists, but doesn't check that # signature... go figure. - join_event.signatures[other_server] = {"x": "y"} + join_event.signatures.update({other_server: {"x": "y"}}) self.get_success( self.hs.get_federation_event_handler().on_send_membership_event(