This commit is contained in:
Erik Johnston
2026-03-04 13:46:07 +00:00
parent aeb99c65cc
commit 0276c52ff4
+305 -40
View File
@@ -20,12 +20,15 @@
//! Classes for representing Events.
use std::{collections::HashMap, sync::Arc};
use std::{
collections::HashMap,
sync::{Arc, Mutex, RwLock},
};
use pyo3::{
exceptions::PyKeyError,
exceptions::{PyKeyError, PyTypeError},
pyclass, pymethods,
types::{PyAnyMethods, PyModule, PyModuleMethods},
types::{PyAnyMethods, PyIterator, PyMapping, PyMappingMethods, PyModule, PyModuleMethods},
wrap_pyfunction, Bound, IntoPyObject, PyAny, PyResult, Python,
};
use pythonize::{depythonize, pythonize};
@@ -40,6 +43,8 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
child_module.add_class::<JsonObject>()?;
child_module.add_class::<Event>()?;
child_module.add_class::<Signatures>()?;
child_module.add_class::<DomainSignatures>()?;
child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;
m.add_submodule(&child_module)?;
@@ -53,26 +58,6 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
Ok(())
}
struct Hashes {
sha256: Option<[u8; 32]>,
others: std::collections::HashMap<Box<str>, Box<str>>,
}
#[pyclass]
struct EventInner {
#[pyo3(get)]
content: JsonObject,
depth: i64,
hashes: Hashes,
origin_server_ts: i64,
sender: Box<str>,
state_key: Option<Box<str>>,
type_: Box<str>,
unsigned: JsonObject,
signatures: HashMap<Box<str>, HashMap<Box<str>, Box<str>>>,
}
#[derive(Serialize, Deserialize)]
#[pyclass(mapping)]
#[derive(Clone)]
@@ -104,29 +89,283 @@ impl JsonObject {
};
Ok(Some(pythonize(py, value)?))
}
fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
PyIterator::from_object(
&self
.object
.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?,
)
}
}
#[derive(Serialize, Deserialize, Clone)]
#[pyclass(mapping)]
#[serde(transparent)]
struct Signatures {
signatures: Arc<RwLock<HashMap<Box<str>, DomainSignatures>>>,
}
#[pymethods]
impl Signatures {
fn __getitem__(&self, key: &str) -> PyResult<Option<DomainSignatures>> {
let signatures = self.signatures.read().unwrap();
let Some(value) = signatures.get(key) else {
return Err(PyKeyError::new_err(key.to_string()));
};
Ok(Some(value.clone()))
}
fn __len__(&self) -> usize {
let signatures = self.signatures.read().unwrap();
signatures.len()
}
fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
let signatures = self.signatures.read().unwrap();
PyIterator::from_object(
&signatures
.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?,
)
}
fn __contains__(&self, key: &str) -> bool {
let signatures = self.signatures.read().unwrap();
signatures.contains_key(key)
}
fn __setitem__(&mut self, key: String, value: DomainSignatures) -> PyResult<()> {
let mut signatures = self.signatures.write().unwrap();
signatures.insert(key.into_boxed_str(), value);
Ok(())
}
fn __delitem__(&mut self, key: &str) -> PyResult<()> {
let mut signatures = self.signatures.write().unwrap();
if signatures.remove(key).is_none() {
return Err(PyKeyError::new_err(key.to_string()));
}
Ok(())
}
fn clear(&mut self) -> PyResult<()> {
let mut signatures = self.signatures.write().unwrap();
signatures.clear();
Ok(())
}
fn pop<'py>(
&mut self,
py: Python<'py>,
key: &str,
default: Option<Bound<'py, PyAny>>,
) -> PyResult<Bound<'py, PyAny>> {
let mut signatures = self.signatures.write().unwrap();
match signatures.remove(key) {
Some(value) => Ok(Some(value).into_pyobject(py)?),
None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))),
}
}
fn keys<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
Ok(signatures
.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?)
}
fn values<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
Ok(signatures
.values()
.cloned()
.collect::<Vec<_>>()
.into_pyobject(py)?)
}
fn items<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
let items: Vec<_> = signatures.iter().map(|(k, v)| (&**k, v.clone())).collect();
Ok(items.into_pyobject(py)?)
}
#[pyo3(signature = (key, default=None))]
fn get<'py>(
&self,
py: Python<'py>,
key: &str,
default: Option<Bound<'py, PyAny>>,
) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
match signatures.get(key) {
Some(value) => Ok(Some(value.clone()).into_pyobject(py)?),
None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))),
}
}
fn update(&mut self, other: &Bound<'_, PyMapping>) -> PyResult<()> {
let mut signatures = self.signatures.write().unwrap();
for key in other.keys()? {
let key_str = key.extract::<String>()?;
let value: HashMap<String, String> = other.get_item(&key)?.extract()?;
let value = DomainSignatures {
signatures: Arc::new(RwLock::new(
value
.into_iter()
.map(|(k, v)| (k.into_boxed_str(), v.into_boxed_str()))
.collect(),
)),
};
signatures.insert(key_str.into_boxed_str(), value);
}
Ok(())
}
}
#[derive(Serialize, Deserialize, Clone)]
#[pyclass(mapping)]
#[serde(transparent)]
struct DomainSignatures {
signatures: Arc<RwLock<HashMap<Box<str>, Box<str>>>>,
}
#[pymethods]
impl DomainSignatures {
fn __getitem__<'py>(&self, py: Python<'py>, key: &str) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
let Some(value) = signatures.get(key) else {
return Err(PyKeyError::new_err(key.to_string()));
};
Ok(Some(&**value).into_pyobject(py)?)
}
fn __len__(&self) -> usize {
let signatures = self.signatures.read().unwrap();
signatures.len()
}
fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
// This is a bit inefficient, but it avoids having to implement a custom
// iterator type.
let signatures = self.signatures.read().unwrap();
Ok(signatures
.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?)
}
fn __contains__(&self, key: &str) -> bool {
let signatures = self.signatures.read().unwrap();
signatures.contains_key(key)
}
fn __setitem__(&mut self, key: &str, value: &str) {
let mut signatures = self.signatures.write().unwrap();
signatures.insert(Box::from(key), Box::from(value));
}
fn __delitem__(&mut self, key: &str) -> PyResult<()> {
let mut signatures = self.signatures.write().unwrap();
if signatures.remove(key).is_none() {
return Err(PyKeyError::new_err(key.to_string()));
}
Ok(())
}
fn keys<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
Ok(signatures
.keys()
.map(|k| &**k)
.collect::<Vec<_>>()
.into_pyobject(py)?)
}
fn values<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
Ok(signatures
.values()
.map(|v| &**v)
.collect::<Vec<_>>()
.into_pyobject(py)?)
}
fn items<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
let items: Vec<_> = signatures.iter().map(|(k, v)| (&**k, &**v)).collect();
Ok(items.into_pyobject(py)?)
}
#[pyo3(signature = (key, default=None))]
fn get<'py>(
&self,
py: Python<'py>,
key: &str,
default: Option<Bound<'py, PyAny>>,
) -> PyResult<Bound<'py, PyAny>> {
let signatures = self.signatures.read().unwrap();
match signatures.get(key) {
Some(value) => Ok(Some(&**value).into_pyobject(py)?),
None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))),
}
}
fn clear(&mut self) -> PyResult<()> {
let mut signatures = self.signatures.write().unwrap();
signatures.clear();
Ok(())
}
fn pop<'py>(
&mut self,
py: Python<'py>,
key: &str,
default: Option<Bound<'py, PyAny>>,
) -> PyResult<Bound<'py, PyAny>> {
let mut signatures = self.signatures.write().unwrap();
match signatures.remove(key) {
Some(value) => Ok(Some(&*value).into_pyobject(py)?),
None => Ok(default.unwrap_or_else(|| py.None().into_bound(py))),
}
}
fn update(&mut self, other: &Bound<'_, PyMapping>) -> PyResult<()> {
let mut signatures: std::sync::RwLockWriteGuard<'_, HashMap<Box<str>, Box<str>>> =
self.signatures.write().unwrap();
for key in other.keys()? {
let key_str = key.extract::<String>()?;
let value: String = other.get_item(&key)?.extract()?;
signatures.insert(key_str.into_boxed_str(), value.into_boxed_str());
}
Ok(())
}
}
#[derive(Serialize, Deserialize)]
#[pyclass]
struct EventCommonFields {
#[pyo3(get)]
content: JsonObject,
#[pyo3(get)]
depth: i64,
hashes: HashMap<String, String>,
#[pyo3(get)]
origin_server_ts: i64,
#[pyo3(get)]
sender: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[pyo3(get)]
state_key: Option<String>,
#[serde(rename = "type")]
#[pyo3(get, name = "type")]
type_: String,
room_id: Option<String>,
unsigned: JsonObject,
signatures: HashMap<Box<str>, HashMap<Box<str>, Box<str>>>,
signatures: Signatures,
#[serde(flatten)]
other_fields: HashMap<String, serde_json::Value>,
@@ -157,7 +396,15 @@ impl Event {
#[getter]
fn room_id(&self) -> Option<&str> {
match &self.inner {
EventFormatEnum::V3(format) => format.specific_fields.room_id.as_deref(),
EventFormatEnum::V3(format) => format.common_fields.room_id.as_deref(),
// ...
}
}
#[getter]
fn signatures<'py>(&self, py: Python<'py>) -> PyResult<Signatures> {
match &self.inner {
EventFormatEnum::V3(format) => Ok(format.common_fields.signatures.clone()),
// ...
}
}
@@ -179,7 +426,6 @@ enum EventFormatEnum {
struct EventFormatV3 {
auth_events: Vec<String>,
prev_events: Vec<String>,
room_id: Option<String>,
}
#[derive(Serialize, Deserialize)]
@@ -205,15 +451,34 @@ mod tests {
assert_eq!(event.common_fields.type_, "m.room.create".to_string());
assert_eq!(
event.specific_fields.room_id,
event.common_fields.room_id,
Some("!qVoJSympOqdUQRUfiC:localhost:8800".to_string())
);
assert_eq!(
event.common_fields.other_fields.get("auth_events").unwrap(),
&serde_json::Value::Array(vec![])
);
assert_eq!(event_value, parsed_value);
}
#[test]
fn test_signatures_serde() {
let json = r#"{"localhost:8800":{"ed25519:a_GMSl":"GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg"}}"#;
let signatures: Signatures = serde_json::from_str(json).unwrap();
let signatures_inner = signatures.signatures.read().unwrap();
assert!(signatures_inner.contains_key("localhost:8800"));
let domain_signatures = signatures_inner.get("localhost:8800").unwrap();
let signatures_map = domain_signatures.signatures.read().unwrap();
assert!(signatures_map.contains_key("ed25519:a_GMSl"));
assert_eq!(
signatures_map.get("ed25519:a_GMSl").unwrap().as_ref(),
"GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg"
);
// Now test serialization
let serialized_json = serde_json::to_string(&signatures).unwrap();
assert_eq!(
serde_json::from_str::<serde_json::Value>(&serialized_json).unwrap(),
serde_json::from_str::<serde_json::Value>(json).unwrap()
);
}
}