This commit is contained in:
Erik Johnston
2026-03-04 09:50:23 +00:00
parent 853734e30f
commit aeb99c65cc
2 changed files with 117 additions and 1 deletions
+1 -1
View File
@@ -43,7 +43,7 @@ pyo3-log = "0.13.1"
pythonize = "0.27.0"
regex = "1.6.0"
sha2 = "0.10.8"
serde = { version = "1.0.144", features = ["derive"] }
serde = { version = "1.0.144", features = ["derive", "rc"] }
serde_json = "1.0.85"
ulid = "1.1.2"
icu_segmenter = "2.0.0"
+116
View File
@@ -29,6 +29,7 @@ use pyo3::{
wrap_pyfunction, Bound, IntoPyObject, PyAny, PyResult, Python,
};
use pythonize::{depythonize, pythonize};
use serde::{Deserialize, Serialize};
pub mod filter;
mod internal_metadata;
@@ -38,6 +39,7 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
let child_module = PyModule::new(py, "events")?;
child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
child_module.add_class::<JsonObject>()?;
child_module.add_class::<Event>()?;
child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;
m.add_submodule(&child_module)?;
@@ -71,8 +73,10 @@ struct EventInner {
signatures: HashMap<Box<str>, HashMap<Box<str>, Box<str>>>,
}
#[derive(Serialize, Deserialize)]
#[pyclass(mapping)]
#[derive(Clone)]
#[serde(transparent)]
struct JsonObject {
object: Arc<HashMap<Box<str>, serde_json::Value>>,
}
@@ -101,3 +105,115 @@ impl JsonObject {
Ok(Some(pythonize(py, value)?))
}
}
#[derive(Serialize, Deserialize)]
#[pyclass]
struct EventCommonFields {
#[pyo3(get)]
content: JsonObject,
#[pyo3(get)]
depth: i64,
hashes: HashMap<String, String>,
#[pyo3(get)]
origin_server_ts: i64,
#[pyo3(get)]
sender: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[pyo3(get)]
state_key: Option<String>,
#[serde(rename = "type")]
#[pyo3(get, name = "type")]
type_: String,
unsigned: JsonObject,
signatures: HashMap<Box<str>, HashMap<Box<str>, Box<str>>>,
#[serde(flatten)]
other_fields: HashMap<String, serde_json::Value>,
}
#[pyclass]
struct Event {
inner: EventFormatEnum,
}
#[pymethods]
impl Event {
#[new]
fn new<'a, 'py>(format: u8, event_dict: &'a Bound<'py, PyAny>) -> PyResult<Self> {
if format != 3 {
return Err(PyKeyError::new_err(format!(
"Unsupported event format version: {}",
format
)));
}
let event_format_v3: EventFormatV3Container = depythonize(event_dict)?;
Ok(Self {
inner: EventFormatEnum::V3(event_format_v3),
})
}
#[getter]
fn room_id(&self) -> Option<&str> {
match &self.inner {
EventFormatEnum::V3(format) => format.specific_fields.room_id.as_deref(),
// ...
}
}
fn get_pdu_json<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
match &self.inner {
EventFormatEnum::V3(format) => Ok(pythonize(py, format)?),
// ...
}
}
}
enum EventFormatEnum {
V3(EventFormatV3Container),
// ...
}
#[derive(Serialize, Deserialize)]
struct EventFormatV3 {
auth_events: Vec<String>,
prev_events: Vec<String>,
room_id: Option<String>,
}
#[derive(Serialize, Deserialize)]
struct EventFormatV3Container {
#[serde(flatten)]
specific_fields: EventFormatV3,
#[serde(flatten)]
common_fields: EventCommonFields,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_v3_roundtrip() {
let json = r#"{"auth_events":[],"prev_events":[],"type":"m.room.create","sender":"@anon-20260225_142731-20:localhost:8800","content":{"room_version":"10","creator":"@anon-20260225_142731-20:localhost:8800"},"depth":1,"room_id":"!qVoJSympOqdUQRUfiC:localhost:8800","state_key":"","origin_server_ts":1772029657149,"hashes":{"sha256":"RIDkn4CrExGMOfRZlHl//1weAro5QC/q2D76YcyAUqk"},"signatures":{"localhost:8800":{"ed25519:a_GMSl":"GU7WmvI2Kd5kLrXKrWpRbUfEiVKGgH0sxQNEpBMMvgF3QhHN25AubVMmIClht5r/c+Iihb1xsq1j5Sw+RGfiDg"}},"unsigned":{"age_ts":1772029657149}}"#;
let event_value: serde_json::Value = serde_json::from_str(json).unwrap();
let event: EventFormatV3Container = serde_json::from_str(json).unwrap();
let parsed_value = serde_json::to_value(&event).unwrap();
assert_eq!(event.common_fields.type_, "m.room.create".to_string());
assert_eq!(
event.specific_fields.room_id,
Some("!qVoJSympOqdUQRUfiC:localhost:8800".to_string())
);
assert_eq!(
event.common_fields.other_fields.get("auth_events").unwrap(),
&serde_json::Value::Array(vec![])
);
assert_eq!(event_value, parsed_value);
}
}