From ee733ae2f6b82da48a2e63e5bb4156698ea8b008 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 May 2026 15:43:05 +0100 Subject: [PATCH] Correctly handle room_id/state_key being null These should either be absent or strings. We cannot use `Option` as that does not differentiate between `null` and an absent field. --- rust/src/events/formats/mod.rs | 15 +- rust/src/events/formats/v4.rs | 19 ++- rust/src/events/formats/vmsc4242.rs | 13 +- rust/src/events/mod.rs | 14 +- rust/src/json.rs | 210 ++++++++++++++++++++++++++++ rust/src/lib.rs | 1 + 6 files changed, 251 insertions(+), 21 deletions(-) create mode 100644 rust/src/json.rs diff --git a/rust/src/events/formats/mod.rs b/rust/src/events/formats/mod.rs index 9215bf51c5..ee37d9a3a6 100644 --- a/rust/src/events/formats/mod.rs +++ b/rust/src/events/formats/mod.rs @@ -71,7 +71,10 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::Error; use serde::{Deserialize, Serialize}; -use crate::events::{json_object::JsonObject, signatures::Signatures, unsigned::Unsigned}; +use crate::{ + events::{json_object::JsonObject, signatures::Signatures, unsigned::Unsigned}, + json::AllowMissing, +}; mod v1; mod v2v3; @@ -179,8 +182,12 @@ pub struct EventCommonFields { pub hashes: HashMap, Box>, pub origin_server_ts: i64, pub sender: Box, - #[serde(skip_serializing_if = "Option::is_none")] - pub state_key: Option>, + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + pub state_key: AllowMissing>, /// The `type` field of the event (we use `type_` in Rust to avoid the /// reserved keyword). @@ -195,7 +202,7 @@ impl EventCommonFields { /// Helper method to check if the event is a state event and return the /// tuple of `(type, state_key)` if so. 
fn type_state_key_tuple(&self) -> Option<(&str, &str)> { - if let Some(state_key) = &self.state_key { + if let AllowMissing::Some(state_key) = &self.state_key { Some((&self.type_, state_key)) } else { None diff --git a/rust/src/events/formats/v4.rs b/rust/src/events/formats/v4.rs index 94c367a01a..f1db6a36b8 100644 --- a/rust/src/events/formats/v4.rs +++ b/rust/src/events/formats/v4.rs @@ -32,20 +32,27 @@ use std::borrow::Cow; use anyhow::{bail, ensure, Error}; use serde::{Deserialize, Serialize}; -use crate::events::{constants::event_type::M_ROOM_CREATE, formats::EventCommonFields}; +use crate::{ + events::{constants::event_type::M_ROOM_CREATE, formats::EventCommonFields}, + json::AllowMissing, +}; /// Version-specific fields for room version 11. #[derive(Serialize, Deserialize)] pub struct EventFormatV4 { - #[serde(skip_serializing_if = "Option::is_none")] - pub room_id: Option>, + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + pub room_id: AllowMissing>, pub auth_events: Vec, pub prev_events: Vec, } impl EventFormatV4 { pub fn validate(&self, common_fields: &EventCommonFields) -> Result<(), Error> { - validate_optional_room_id(self.room_id.as_deref(), common_fields)?; + validate_optional_room_id(self.room_id.as_deref_opt(), common_fields)?; // Ensure that we don't have an event_id set. if common_fields.other_fields.contains_key("event_id") { @@ -60,7 +67,7 @@ impl EventFormatV4 { event_id: &str, common_fields: &EventCommonFields, ) -> Result, Error> { - get_room_id_for_optional_room_id(self.room_id.as_deref(), event_id, common_fields) + get_room_id_for_optional_room_id(self.room_id.as_deref_opt(), event_id, common_fields) } pub fn auth_event_ids(&self, common_fields: &EventCommonFields) -> Result, Error> { @@ -76,7 +83,7 @@ impl EventFormatV4 { // replacing the leading '!' with '$'. 
let room_id = self .room_id - .as_deref() + .as_deref_opt() .ok_or_else(|| anyhow::anyhow!("non-create event has no room_id"))?; let mut create_event_id = String::with_capacity(room_id.len()); diff --git a/rust/src/events/formats/vmsc4242.rs b/rust/src/events/formats/vmsc4242.rs index 7aa7c76f8d..7b54b06a2e 100644 --- a/rust/src/events/formats/vmsc4242.rs +++ b/rust/src/events/formats/vmsc4242.rs @@ -37,19 +37,24 @@ use crate::events::formats::v4::get_room_id_for_optional_room_id; use crate::events::formats::v4::validate_optional_room_id; use crate::events::formats::EventCommonFields; use crate::events::Event; +use crate::json::AllowMissing; /// Version-specific fields for the MSC4242 event format. #[derive(Serialize, Deserialize)] pub struct EventFormatVMSC4242 { pub prev_state_events: Vec, pub prev_events: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub room_id: Option>, + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + pub room_id: AllowMissing>, } impl EventFormatVMSC4242 { pub fn validate(&self, common_fields: &EventCommonFields) -> Result<(), Error> { - validate_optional_room_id(self.room_id.as_deref(), common_fields)?; + validate_optional_room_id(self.room_id.as_deref_opt(), common_fields)?; // Ensure that we don't have any `auth_events` or `event_id` fields // set. 
@@ -68,7 +73,7 @@ impl EventFormatVMSC4242 { event_id: &str, common_fields: &EventCommonFields, ) -> Result, Error> { - get_room_id_for_optional_room_id(self.room_id.as_deref(), event_id, common_fields) + get_room_id_for_optional_room_id(self.room_id.as_deref_opt(), event_id, common_fields) } pub fn auth_event_ids(&self, event: &Event) -> PyResult> { diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index d567c7424d..0a1cc8314b 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -313,7 +313,7 @@ impl Event { } fn get_state_key(&self) -> Option<&str> { - self.parsed_event.common_fields.state_key.as_deref() + self.parsed_event.common_fields.state_key.as_deref_opt() } #[getter] @@ -482,7 +482,7 @@ impl Event { // We can't call this `state_key` because that would generate a // `get_state_key` method which already exists. fn state_key_attr(&self) -> PyResult<&str> { - let Some(state_key) = self.parsed_event.common_fields.state_key.as_deref() else { + let Some(state_key) = self.parsed_event.common_fields.state_key.as_deref_opt() else { return Err(PyAttributeError::new_err("state_key")); }; Ok(state_key) @@ -614,7 +614,7 @@ mod tests { let parsed_value = serde_json::to_value(&event).unwrap(); assert_eq!(&*event.common_fields.type_, "m.room.message"); - assert_eq!(event.common_fields.state_key, None); + assert!(event.common_fields.state_key.is_absent()); assert_eq!(&*event.specific_fields.room_id, "!room:localhost"); assert_eq!(&*event.specific_fields.event_id, "$event1:localhost"); @@ -639,7 +639,7 @@ mod tests { assert_eq!(&*event.common_fields.type_, "m.room.message"); assert_eq!( - event.specific_fields.room_id.as_deref(), + event.specific_fields.room_id.as_deref_opt(), Some("!room:localhost") ); assert_eq!( @@ -663,7 +663,7 @@ mod tests { let event: FormattedEvent = serde_json::from_str(json).unwrap(); let parsed_value = serde_json::to_value(&event).unwrap(); - assert!(event.specific_fields.room_id.is_none()); + 
assert!(event.specific_fields.room_id.is_absent()); assert_eq!(&*event.common_fields.type_, M_ROOM_CREATE); // Create events have no implicit auth events. @@ -732,11 +732,11 @@ mod tests { vec!["$pstate1".to_string(), "$pstate2".to_string()] ); assert_eq!( - event.specific_fields.room_id.as_deref(), + event.specific_fields.room_id.as_deref_opt(), Some("!room:localhost") ); assert_eq!( - event.common_fields.state_key.as_deref(), + event.common_fields.state_key.as_deref_opt(), Some("@user:localhost") ); diff --git a/rust/src/json.rs b/rust/src/json.rs new file mode 100644 index 0000000000..696be79a0b --- /dev/null +++ b/rust/src/json.rs @@ -0,0 +1,210 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * <https://www.gnu.org/licenses/agpl-3.0.html>. + * + */ + +/// A wrapper type that represents a value that may be missing. +/// +/// We can't necessarily use `Option<T>` for this, as we want to distinguish +/// between a missing value and a value that is present but null (e.g. +/// `{"field": null}` vs `{}`). Serde by default deserializes both to +/// `None`, so we need a custom type to capture this distinction. +/// +/// A plain `AllowMissing<T>` is used for fields that are either present and of +/// type `T`, or absent. An `AllowMissing<Option<T>>` is used for fields that +/// are of type `T`, null, or absent. 
+/// +/// Note, to use this type correctly, the field **MUST** be annotated with: +/// +/// ```rust,ignore +/// #[serde( +/// default, +/// with = "crate::json::allow_missing", +/// skip_serializing_if = "AllowMissing::is_absent" +/// )] +/// ``` +/// +#[derive(Default, Debug, Clone)] +pub enum AllowMissing { + Some(T), + #[default] + Absent, +} + +impl AllowMissing { + /// Returns `true` if the value is present, even if it is null. + pub fn is_some(&self) -> bool { + matches!(self, AllowMissing::Some(_)) + } + + /// Returns `true` if the value is absent. + pub fn is_absent(&self) -> bool { + matches!(self, AllowMissing::Absent) + } + + /// Converts to `Option<&T::Target>`. + /// + /// Useful for converting e.g. `AllowMissing<Box<str>>` to `Option<&str>`. + pub fn as_deref_opt(&self) -> Option<&T::Target> + where + T: std::ops::Deref, + { + match self { + AllowMissing::Some(inner) => Some(inner.deref()), + AllowMissing::Absent => None, + } + } +} + +/// A module that provides the serialization and deserialization logic for +/// `AllowMissing`. +pub mod allow_missing { + use serde::ser::Error as _; + + use super::AllowMissing; + + pub fn deserialize<'de, T, D>(deserializer: D) -> Result, D::Error> + where + T: serde::Deserialize<'de>, + D: serde::Deserializer<'de>, + { + Ok(AllowMissing::Some(T::deserialize(deserializer)?)) + } + + pub fn serialize(value: &AllowMissing, serializer: S) -> Result + where + T: serde::Serialize, + S: serde::Serializer, + { + match value { + AllowMissing::Some(inner) => inner.serialize(serializer), + // We should never attempt to serialize an `AllowMissing::Absent`, as we + // should have skipped it with `skip_serializing_if`. 
+ AllowMissing::Absent => Err(S::Error::custom("cannot serialize AllowMissing::Absent")), + } + } +} + +#[cfg(test)] +mod tests { + use std::assert_matches; + + use serde::{Deserialize, Serialize}; + + use super::*; + + #[derive(Serialize, Deserialize)] + struct TestStruct { + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + value: AllowMissing, + } + + #[test] + fn test_deserialize() { + let json = r#"{"value":42}"#; + let deserialized: TestStruct = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_some()); + assert_matches!(deserialized.value, AllowMissing::Some(42)); + + let json = r#"{}"#; + let deserialized: TestStruct = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_absent()); + assert_matches!(deserialized.value, AllowMissing::Absent); + } + + #[test] + fn test_serialize() { + let value = TestStruct { + value: AllowMissing::Some(42), + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{"value":42}"#); + + let value = TestStruct { + value: AllowMissing::Absent, + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{}"#); + } + + /// Test that we get an error if we attempt to serialize an + /// `AllowMissing::Absent` without the skip_serializing_if annotation. 
+ #[test] + fn test_serialize_absent_error() { + #[derive(Serialize)] + struct TestStructWithoutSkip { + #[serde(default, with = "crate::json::allow_missing")] + value: AllowMissing, + } + + let value = TestStructWithoutSkip { + value: AllowMissing::Absent, + }; + + let err = serde_json::to_string(&value).unwrap_err(); + assert_eq!(err.to_string(), "cannot serialize AllowMissing::Absent"); + } + + #[derive(Serialize, Deserialize)] + struct TestStructOption { + #[serde( + default, + with = "crate::json::allow_missing", + skip_serializing_if = "AllowMissing::is_absent" + )] + value: AllowMissing>, + } + + #[test] + fn test_serialize_option() { + let value = TestStructOption { + value: AllowMissing::Some(Some(42)), + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{"value":42}"#); + + let value = TestStructOption { + value: AllowMissing::Some(None), + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{"value":null}"#); + + let value = TestStructOption { + value: AllowMissing::Absent, + }; + let serialized = serde_json::to_string(&value).unwrap(); + assert_eq!(serialized, r#"{}"#); + } + + #[test] + fn test_deserialize_option() { + let json = r#"{"value":42}"#; + let deserialized: TestStructOption = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_some()); + assert_matches!(deserialized.value, AllowMissing::Some(Some(42))); + + let json = r#"{"value":null}"#; + let deserialized: TestStructOption = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_some()); + assert_matches!(deserialized.value, AllowMissing::Some(None)); + + let json = r#"{}"#; + let deserialized: TestStructOption = serde_json::from_str(json).unwrap(); + assert!(deserialized.value.is_absent()); + assert_matches!(deserialized.value, AllowMissing::Absent); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index cc89862e4e..8ed4e24b81 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs 
@@ -12,6 +12,7 @@ pub mod events; pub mod http; pub mod http_client; pub mod identifier; +pub mod json; pub mod matrix_const; pub mod msc4388_rendezvous; pub mod push;