From 0276c52ff4b1697607d3d8974e33bc3ddd8ee067 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Mar 2026 13:46:07 +0000 Subject: [PATCH] WIP --- rust/src/events/mod.rs | 345 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 305 insertions(+), 40 deletions(-) diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index fdc9784201..06665bdd06 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -20,12 +20,15 @@ //! Classes for representing Events. -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex, RwLock}, +}; use pyo3::{ - exceptions::PyKeyError, + exceptions::{PyKeyError, PyTypeError}, pyclass, pymethods, - types::{PyAnyMethods, PyModule, PyModuleMethods}, + types::{PyAnyMethods, PyIterator, PyMapping, PyMappingMethods, PyModule, PyModuleMethods}, wrap_pyfunction, Bound, IntoPyObject, PyAny, PyResult, Python, }; use pythonize::{depythonize, pythonize}; @@ -40,6 +43,8 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?; m.add_submodule(&child_module)?; @@ -53,26 +58,6 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> Ok(()) } -struct Hashes { - sha256: Option<[u8; 32]>, - others: std::collections::HashMap, Box>, -} - -#[pyclass] -struct EventInner { - #[pyo3(get)] - content: JsonObject, - depth: i64, - hashes: Hashes, - origin_server_ts: i64, - sender: Box, - state_key: Option>, - type_: Box, - - unsigned: JsonObject, - signatures: HashMap, HashMap, Box>>, -} - #[derive(Serialize, Deserialize)] #[pyclass(mapping)] #[derive(Clone)] @@ -104,29 +89,283 @@ impl JsonObject { }; Ok(Some(pythonize(py, value)?)) } + + fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { + PyIterator::from_object( + &self + .object + .keys() + .map(|k| &**k) + .collect::>() + .into_pyobject(py)?, + ) + } +} + +#[derive(Serialize, Deserialize, Clone)] +#[pyclass(mapping)] +#[serde(transparent)] +struct Signatures { + signatures: Arc, DomainSignatures>>>, +} + +#[pymethods] +impl Signatures { + fn __getitem__(&self, key: &str) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + let Some(value) = signatures.get(key) else { + return Err(PyKeyError::new_err(key.to_string())); + }; + Ok(Some(value.clone())) + } + + fn __len__(&self) -> usize { + let signatures = self.signatures.read().unwrap(); + signatures.len() + } + + fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + PyIterator::from_object( + &signatures + .keys() + .map(|k| &**k) + .collect::>() + .into_pyobject(py)?, + ) + } + + fn __contains__(&self, key: &str) -> bool { + let signatures = self.signatures.read().unwrap(); + signatures.contains_key(key) + } + + fn __setitem__(&mut self, key: String, value: DomainSignatures) -> PyResult<()> { + let mut signatures = self.signatures.write().unwrap(); + signatures.insert(key.into_boxed_str(), value); + Ok(()) + } + + fn __delitem__(&mut self, key: &str) -> PyResult<()> { + let mut signatures = self.signatures.write().unwrap(); + if signatures.remove(key).is_none() { + return Err(PyKeyError::new_err(key.to_string())); + } + Ok(()) + } + + fn clear(&mut self) -> PyResult<()> { + let mut signatures = self.signatures.write().unwrap(); + signatures.clear(); + Ok(()) + } + + fn pop<'py>( + &mut self, + py: Python<'py>, + key: &str, + default: Option>, + ) -> PyResult> { + let mut signatures = self.signatures.write().unwrap(); + match signatures.remove(key) { + Some(value) => Ok(Some(value).into_pyobject(py)?), + None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))), + } + } + + fn keys<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + Ok(signatures + .keys() + .map(|k| &**k) + .collect::>() + .into_pyobject(py)?) + } + + fn values<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + Ok(signatures + .values() + .cloned() + .collect::>() + .into_pyobject(py)?) + } + + fn items<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + let items: Vec<_> = signatures.iter().map(|(k, v)| (&**k, v.clone())).collect(); + Ok(items.into_pyobject(py)?) + } + + #[pyo3(signature = (key, default=None))] + fn get<'py>( + &self, + py: Python<'py>, + key: &str, + default: Option>, + ) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + match signatures.get(key) { + Some(value) => Ok(Some(value.clone()).into_pyobject(py)?), + None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))), + } + } + + fn update(&mut self, other: &Bound<'_, PyMapping>) -> PyResult<()> { + let mut signatures = self.signatures.write().unwrap(); + for key in other.keys()? { + let key_str = key.extract::()?; + let value: HashMap = other.get_item(&key)?.extract()?; + let value = DomainSignatures { + signatures: Arc::new(RwLock::new( + value + .into_iter() + .map(|(k, v)| (k.into_boxed_str(), v.into_boxed_str())) + .collect(), + )), + }; + signatures.insert(key_str.into_boxed_str(), value); + } + Ok(()) + } +} + +#[derive(Serialize, Deserialize, Clone)] +#[pyclass(mapping)] +#[serde(transparent)] +struct DomainSignatures { + signatures: Arc, Box>>>, +} + +#[pymethods] +impl DomainSignatures { + fn __getitem__<'py>(&self, py: Python<'py>, key: &str) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + let Some(value) = signatures.get(key) else { + return Err(PyKeyError::new_err(key.to_string())); + }; + Ok(Some(&**value).into_pyobject(py)?) + } + + fn __len__(&self) -> usize { + let signatures = self.signatures.read().unwrap(); + signatures.len() + } + + fn __iter__<'py>(&self, py: Python<'py>) -> PyResult> { + // This is a bit inefficient, but it avoids having to implement a custom + // iterator type. + let signatures = self.signatures.read().unwrap(); + + Ok(signatures + .keys() + .map(|k| &**k) + .collect::>() + .into_pyobject(py)?) + } + + fn __contains__(&self, key: &str) -> bool { + let signatures = self.signatures.read().unwrap(); + signatures.contains_key(key) + } + + fn __setitem__(&mut self, key: &str, value: &str) { + let mut signatures = self.signatures.write().unwrap(); + signatures.insert(Box::from(key), Box::from(value)); + } + + fn __delitem__(&mut self, key: &str) -> PyResult<()> { + let mut signatures = self.signatures.write().unwrap(); + if signatures.remove(key).is_none() { + return Err(PyKeyError::new_err(key.to_string())); + } + Ok(()) + } + + fn keys<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + Ok(signatures + .keys() + .map(|k| &**k) + .collect::>() + .into_pyobject(py)?) + } + + fn values<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + Ok(signatures + .values() + .map(|v| &**v) + .collect::>() + .into_pyobject(py)?) + } + + fn items<'py>(&self, py: Python<'py>) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + let items: Vec<_> = signatures.iter().map(|(k, v)| (&**k, &**v)).collect(); + Ok(items.into_pyobject(py)?) + } + + #[pyo3(signature = (key, default=None))] + fn get<'py>( + &self, + py: Python<'py>, + key: &str, + default: Option>, + ) -> PyResult> { + let signatures = self.signatures.read().unwrap(); + match signatures.get(key) { + Some(value) => Ok(Some(&**value).into_pyobject(py)?), + None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))), + } + } + + fn clear(&mut self) -> PyResult<()> { + let mut signatures = self.signatures.write().unwrap(); + signatures.clear(); + Ok(()) + } + + fn pop<'py>( + &mut self, + py: Python<'py>, + key: &str, + default: Option>, + ) -> PyResult> { + let mut signatures = self.signatures.write().unwrap(); + match signatures.remove(key) { + Some(value) => Ok(Some(&*value).into_pyobject(py)?), + None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))), + } + } + + fn update(&mut self, other: &Bound<'_, PyMapping>) -> PyResult<()> { + let mut signatures: std::sync::RwLockWriteGuard<'_, HashMap, Box>> = + self.signatures.write().unwrap(); + for key in other.keys()? { + let key_str = key.extract::()?; + let value: String = other.get_item(&key)?.extract()?; + signatures.insert(key_str.into_boxed_str(), value.into_boxed_str()); + } + Ok(()) + } } #[derive(Serialize, Deserialize)] -#[pyclass] struct EventCommonFields { - #[pyo3(get)] content: JsonObject, - #[pyo3(get)] - depth: i64, hashes: HashMap, - #[pyo3(get)] origin_server_ts: i64, - #[pyo3(get)] sender: String, #[serde(skip_serializing_if = "Option::is_none")] - #[pyo3(get)] state_key: Option, #[serde(rename = "type")] - #[pyo3(get, name = "type")] type_: String, + room_id: Option, + unsigned: JsonObject, - signatures: HashMap, HashMap, Box>>, + signatures: Signatures, #[serde(flatten)] other_fields: HashMap, @@ -157,7 +396,15 @@ impl Event { #[getter] fn room_id(&self) -> Option<&str> { match &self.inner { - EventFormatEnum::V3(format) => format.specific_fields.room_id.as_deref(), + EventFormatEnum::V3(format) => format.common_fields.room_id.as_deref(), + // ... + } + } + + #[getter] + fn signatures<'py>(&self, py: Python<'py>) -> PyResult { + match &self.inner { + EventFormatEnum::V3(format) => Ok(format.common_fields.signatures.clone()), // ... } } @@ -179,7 +426,6 @@ enum EventFormatEnum { struct EventFormatV3 { auth_events: Vec, prev_events: Vec, - room_id: Option, } #[derive(Serialize, Deserialize)] @@ -205,15 +451,34 @@ mod tests { assert_eq!(event.common_fields.type_, "m.room.create".to_string()); assert_eq!( - event.specific_fields.room_id, + event.common_fields.room_id, Some("!qVoJSympOqdUQRUfiC:localhost:8800".to_string()) ); - assert_eq!( - event.common_fields.other_fields.get("auth_events").unwrap(), - &serde_json::Value::Array(vec![]) - ); - assert_eq!(event_value, parsed_value); } + + #[test] + fn test_signatures_serde() { + let json = r#"{"localhost:8800":{"ed25519:a_GMSl":"GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg"}}"#; + let signatures: Signatures = serde_json::from_str(json).unwrap(); + + let signatures_inner = signatures.signatures.read().unwrap(); + assert!(signatures_inner.contains_key("localhost:8800")); + let domain_signatures = signatures_inner.get("localhost:8800").unwrap(); + let signatures_map = domain_signatures.signatures.read().unwrap(); + assert!(signatures_map.contains_key("ed25519:a_GMSl")); + assert_eq!( + signatures_map.get("ed25519:a_GMSl").unwrap().as_ref(), + "GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg" + ); + + // Now test serialization + let serialized_json = serde_json::to_string(&signatures).unwrap(); + + assert_eq!( + serde_json::from_str::(&serialized_json).unwrap(), + serde_json::from_str::(json).unwrap() + ); + } }