From c0924fbbd8ba6c3d9c0984a05a3160494abc7fbd Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 16 Mar 2026 12:29:42 -0400 Subject: [PATCH] MSC4140: put delay_id in unsigned data for sender (#19479) Implements https://github.com/matrix-org/matrix-spec-proposals/pull/4140/changes/49b200dcc11de286974925177b1e184cd905e6fa --- changelog.d/19479.feature | 1 + rust/src/events/internal_metadata.rs | 22 +++++ synapse/events/utils.py | 79 +++++++++-------- synapse/handlers/delayed_events.py | 2 + synapse/handlers/message.py | 13 ++- synapse/handlers/room_member.py | 10 +++ synapse/synapse_rust/events.pyi | 2 + tests/rest/client/test_delayed_events.py | 108 ++++++++++++++++++++++- 8 files changed, 198 insertions(+), 39 deletions(-) create mode 100644 changelog.d/19479.feature diff --git a/changelog.d/19479.feature b/changelog.d/19479.feature new file mode 100644 index 0000000000..3e7e8bd6ff --- /dev/null +++ b/changelog.d/19479.feature @@ -0,0 +1 @@ +[MSC4140: Cancellable delayed events](https://github.com/matrix-org/matrix-spec-proposals/pull/4140): When persisting a delayed event to the timeline, include its `delay_id` in the event's `unsigned` section in `/sync` responses to the event sender. diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index fa40fdcfad..595f9cf7eb 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -57,6 +57,7 @@ enum EventInternalMetadataData { PolicyServerSpammy(bool), Redacted(bool), TxnId(Box), + DelayId(Box), TokenId(i64), DeviceId(Box), } @@ -115,6 +116,10 @@ impl EventInternalMetadataData { pyo3::intern!(py, "txn_id"), o.into_pyobject(py).unwrap_infallible().into_any(), ), + EventInternalMetadataData::DelayId(o) => ( + pyo3::intern!(py, "delay_id"), + o.into_pyobject(py).unwrap_infallible().into_any(), + ), EventInternalMetadataData::TokenId(o) => ( pyo3::intern!(py, "token_id"), o.into_pyobject(py).unwrap_infallible().into_any(), @@ -179,6 +184,12 @@ impl EventInternalMetadataData { .map(String::into_boxed_str) .with_context(|| format!("'{key_str}' has invalid type"))?, ), + "delay_id" => EventInternalMetadataData::DelayId( + value + .extract() + .map(String::into_boxed_str) + .with_context(|| format!("'{key_str}' has invalid type"))?, + ), "token_id" => EventInternalMetadataData::TokenId( value .extract() @@ -472,6 +483,17 @@ impl EventInternalMetadata { set_property!(self, TxnId, obj.into_boxed_str()); } + /// The delay ID, set only if the event was a delayed event. + #[getter] + fn get_delay_id(&self) -> PyResult<&str> { + let s = get_property!(self, DelayId)?; + Ok(s) + } + #[setter] + fn set_delay_id(&mut self, obj: String) { + set_property!(self, DelayId, obj.into_boxed_str()); + } + /// The access token ID of the user who sent this event, if any. #[getter] fn get_token_id(&self) -> PyResult { diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 1bf4d632c0..89eb2182af 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -420,7 +420,7 @@ class SerializeEventConfig: # Function to convert from federation format to client format event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 # The entity that requested the event. This is used to determine whether to include - # the transaction_id in the unsigned section of the event. + # the transaction_id and delay_id in the unsigned section of the event. requester: Requester | None = None # List of event fields to include. If empty, all fields will be returned. only_event_fields: list[str] | None = None @@ -483,44 +483,49 @@ def serialize_event( config=config, ) - # If we have a txn_id saved in the internal_metadata, we should include it in the - # unsigned section of the event if it was sent by the same session as the one - # requesting the event. - txn_id: str | None = getattr(e.internal_metadata, "txn_id", None) - if ( - txn_id is not None - and config.requester is not None - and config.requester.user.to_string() == e.sender - ): - # Some events do not have the device ID stored in the internal metadata, - # this includes old events as well as those created by appservice, guests, - # or with tokens minted with the admin API. For those events, fallback - # to using the access token instead. - event_device_id: str | None = getattr(e.internal_metadata, "device_id", None) - if event_device_id is not None: - if event_device_id == config.requester.device_id: - d["unsigned"]["transaction_id"] = txn_id + # If we have applicable fields saved in the internal_metadata, include them in the + # unsigned section of the event if the event was sent by the same session (or when + # appropriate, just the same sender) as the one requesting the event. + if config.requester is not None and config.requester.user.to_string() == e.sender: + txn_id: str | None = getattr(e.internal_metadata, "txn_id", None) + if txn_id is not None: + # Some events do not have the device ID stored in the internal metadata, + # this includes old events as well as those created by appservice, guests, + # or with tokens minted with the admin API. For those events, fallback + # to using the access token instead. + event_device_id: str | None = getattr( + e.internal_metadata, "device_id", None + ) + if event_device_id is not None: + if event_device_id == config.requester.device_id: + d["unsigned"]["transaction_id"] = txn_id - else: - # Fallback behaviour: only include the transaction ID if the event - # was sent from the same access token. - # - # For regular users, the access token ID can be used to determine this. - # This includes access tokens minted with the admin API. - # - # For guests and appservice users, we can't check the access token ID - # so assume it is the same session. - event_token_id: int | None = getattr(e.internal_metadata, "token_id", None) - if ( - ( - event_token_id is not None - and config.requester.access_token_id is not None - and event_token_id == config.requester.access_token_id + else: + # Fallback behaviour: only include the transaction ID if the event + # was sent from the same access token. + # + # For regular users, the access token ID can be used to determine this. + # This includes access tokens minted with the admin API. + # + # For guests and appservice users, we can't check the access token ID + # so assume it is the same session. + event_token_id: int | None = getattr( + e.internal_metadata, "token_id", None ) - or config.requester.is_guest - or config.requester.app_service - ): - d["unsigned"]["transaction_id"] = txn_id + if ( + ( + event_token_id is not None + and config.requester.access_token_id is not None + and event_token_id == config.requester.access_token_id + ) + or config.requester.is_guest + or config.requester.app_service + ): + d["unsigned"]["transaction_id"] = txn_id + + delay_id: str | None = getattr(e.internal_metadata, "delay_id", None) + if delay_id is not None: + d["unsigned"]["org.matrix.msc4140.delay_id"] = delay_id # invite_room_state and knock_room_state are a list of stripped room state events # that are meant to provide metadata about a room to an invitee/knocker. They are diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 7e41716f1e..4a9f646d4d 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -560,6 +560,7 @@ class DelayedEventsHandler: action=membership, content=event.content, origin_server_ts=event.origin_server_ts, + delay_id=event.delay_id, ) else: event_dict: JsonDict = { @@ -585,6 +586,7 @@ class DelayedEventsHandler: requester, event_dict, txn_id=txn_id, + delay_id=event.delay_id, ) event_id = sent_event.event_id except ShadowBanError: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 99ce120736..eb01622515 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -585,6 +585,7 @@ class EventCreationHandler: state_map: StateMap[str] | None = None, for_batch: bool = False, current_state_group: int | None = None, + delay_id: str | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. If bool for_batch is true, will @@ -600,7 +601,7 @@ class EventCreationHandler: Args: requester event_dict: An entire event - txn_id + txn_id: The transaction ID. prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -639,6 +640,8 @@ class EventCreationHandler: current_state_group: the current state group, used only for creating events for batch persisting + delay_id: The delay ID of this event, if it was a delayed event. + Raises: ResourceLimitError if server is blocked to some resource being exceeded @@ -726,6 +729,9 @@ class EventCreationHandler: if txn_id is not None: builder.internal_metadata.txn_id = txn_id + if delay_id is not None: + builder.internal_metadata.delay_id = delay_id + builder.internal_metadata.outlier = outlier event, unpersisted_context = await self.create_new_client_event( @@ -966,6 +972,7 @@ class EventCreationHandler: ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + delay_id: str | None = None, ) -> tuple[EventBase, int]: """ Creates an event, then sends it. @@ -994,6 +1001,7 @@ class EventCreationHandler: depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + delay_id: The delay ID of this event, if it was a delayed event. Returns: The event, and its stream ordering (if deduplication happened, @@ -1090,6 +1098,7 @@ class EventCreationHandler: ignore_shadow_ban=ignore_shadow_ban, outlier=outlier, depth=depth, + delay_id=delay_id, ) async def _create_and_send_nonmember_event_locked( @@ -1103,6 +1112,7 @@ class EventCreationHandler: ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + delay_id: str | None = None, ) -> tuple[EventBase, int]: room_id = event_dict["room_id"] @@ -1131,6 +1141,7 @@ class EventCreationHandler: state_event_ids=state_event_ids, outlier=outlier, depth=depth, + delay_id=delay_id, ) context = await unpersisted_context.persist(event) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0c6be72716..b2e678e90e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -408,6 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, origin_server_ts: int | None = None, + delay_id: str | None = None, ) -> tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -440,6 +441,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): opposed to being inline with the current DAG. origin_server_ts: The origin_server_ts to use if a new event is created. Uses the current timestamp if set to None. + delay_id: The delay ID of this event, if it was a delayed event. Returns: Tuple of event ID and stream ordering position @@ -492,6 +494,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth=depth, require_consent=require_consent, outlier=outlier, + delay_id=delay_id, ) context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( @@ -587,6 +590,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + delay_id: str | None = None, ) -> tuple[str, int]: """Update a user's membership in a room. @@ -617,6 +621,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): based on the prev_events. origin_server_ts: The origin_server_ts to use if a new event is created. Uses the current timestamp if set to None. + delay_id: The delay ID of this event, if it was a delayed event. Returns: A tuple of the new event ID and stream ID. @@ -679,6 +684,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids=state_event_ids, depth=depth, origin_server_ts=origin_server_ts, + delay_id=delay_id, ) return result @@ -701,6 +707,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + delay_id: str | None = None, ) -> tuple[str, int]: """Helper for update_membership. @@ -733,6 +740,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): based on the prev_events. origin_server_ts: The origin_server_ts to use if a new event is created. Uses the current timestamp if set to None. + delay_id: The delay ID of this event, if it was a delayed event. Returns: A tuple of the new event ID and stream ID. @@ -943,6 +951,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + delay_id=delay_id, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) @@ -1201,6 +1210,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + delay_id=delay_id, ) async def check_for_any_membership_in_room( diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 0add391c65..185f29694b 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -38,6 +38,8 @@ class EventInternalMetadata: txn_id: str """The transaction ID, if it was set when the event was created.""" + delay_id: str + """The delay ID, set only if the event was a delayed event.""" token_id: int """The access token ID of the user who sent this event, if any.""" device_id: str diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index efa69a393a..da904ce1f5 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -22,7 +22,7 @@ from twisted.internet.testing import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin -from synapse.rest.client import delayed_events, login, room, versions +from synapse.rest.client import delayed_events, login, room, sync, versions from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util.clock import Clock @@ -59,6 +59,7 @@ class DelayedEventsTestCase(HomeserverTestCase): delayed_events.register_servlets, login.register_servlets, room.register_servlets, + sync.register_servlets, ] def default_config(self) -> JsonDict: @@ -106,6 +107,9 @@ class DelayedEventsTestCase(HomeserverTestCase): self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + assert delay_id is not None + events = self._get_delayed_events() self.assertEqual(1, len(events), events) content = self._get_delayed_event_content(events[0]) @@ -128,6 +132,56 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(setter_expected, content.get(setter_key), content) + self._find_sent_delayed_event(self.user1_access_token, delay_id, True) + self._find_sent_delayed_event(self.user2_access_token, delay_id, False) + + def test_delayed_member_events_are_sent_on_timeout(self) -> None: + channel = self.make_request( + "PUT", + _get_path_for_delayed_state( + self.room_id, + "m.room.member", + self.user2_user_id, + 900, + ), + { + "membership": "leave", + "reason": "Delayed kick", + }, + self.user1_access_token, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + assert delay_id is not None + + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + content = self._get_delayed_event_content(events[0]) + self.assertEqual("leave", content.get("membership"), content) + self.assertEqual("Delayed kick", content.get("reason"), content) + + content = self.helper.get_state( + self.room_id, + "m.room.member", + self.user1_access_token, + state_key=self.user2_user_id, + ) + self.assertEqual("join", content.get("membership"), content) + + self.reactor.advance(1) + self.assertListEqual([], self._get_delayed_events()) + content = self.helper.get_state( + self.room_id, + "m.room.member", + self.user1_access_token, + state_key=self.user2_user_id, + ) + self.assertEqual("leave", content.get("membership"), content) + self.assertEqual("Delayed kick", content.get("reason"), content) + + self._find_sent_delayed_event(self.user1_access_token, delay_id, True) + self._find_sent_delayed_event(self.user2_access_token, delay_id, False) + def test_get_delayed_events_auth(self) -> None: channel = self.make_request("GET", PATH_PREFIX) self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result) @@ -254,6 +308,9 @@ class DelayedEventsTestCase(HomeserverTestCase): expect_code=HTTPStatus.NOT_FOUND, ) + self._find_sent_delayed_event(self.user1_access_token, delay_id, False) + self._find_sent_delayed_event(self.user2_access_token, delay_id, False) + @parameterized.expand((True, False)) @unittest.override_config( {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} @@ -327,6 +384,9 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(content_value, content.get(content_property_name), content) + self._find_sent_delayed_event(self.user1_access_token, delay_id, True) + self._find_sent_delayed_event(self.user2_access_token, delay_id, False) + @parameterized.expand((True, False)) @unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}}) def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None: @@ -406,6 +466,9 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(setter_expected, content.get(setter_key), content) + self._find_sent_delayed_event(self.user1_access_token, delay_id, True) + self._find_sent_delayed_event(self.user2_access_token, delay_id, False) + @parameterized.expand((True, False)) @unittest.override_config( {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} @@ -450,6 +513,8 @@ class DelayedEventsTestCase(HomeserverTestCase): self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + assert delay_id is not None events = self._get_delayed_events() self.assertEqual(1, len(events), events) @@ -474,6 +539,9 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(setter_expected, content.get(setter_key), content) + self._find_sent_delayed_event(self.user1_access_token, delay_id, True) + self._find_sent_delayed_event(self.user2_access_token, delay_id, False) + def test_delayed_state_is_cancelled_by_new_state_from_other_user( self, ) -> None: @@ -489,6 +557,8 @@ class DelayedEventsTestCase(HomeserverTestCase): self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + assert delay_id is not None events = self._get_delayed_events() self.assertEqual(1, len(events), events) @@ -513,6 +583,9 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(setter_expected, content.get(setter_key), content) + self._find_sent_delayed_event(self.user1_access_token, delay_id, False) + self._find_sent_delayed_event(self.user2_access_token, delay_id, False) + def _get_delayed_events(self) -> list[JsonDict]: channel = self.make_request( "GET", @@ -549,6 +622,39 @@ class DelayedEventsTestCase(HomeserverTestCase): body["action"] = action return self.make_request("POST", path, body) + def _find_sent_delayed_event( + self, access_token: str, delay_id: str, should_find: bool + ) -> None: + """Call /sync and look for a synced event with a specified delay_id. + At most one event will ever have a matching delay_id. + + Args: + access_token: The access token of the user to call /sync for. + delay_id: The delay_id to search for in synced events. + should_find: Whether /sync should include an event with a matching delay_id. + """ + channel = self.make_request("GET", "/sync", access_token=access_token) + self.assertEqual(HTTPStatus.OK, channel.code) + + rooms = channel.json_body["rooms"] + events = [] + for membership in "join", "leave": + if membership in rooms: + events += rooms[membership][self.room_id]["timeline"]["events"] + + found = False + for event in events: + if event["unsigned"].get("org.matrix.msc4140.delay_id") == delay_id: + if not should_find: + self.fail( + "Found event with matching delay_id, but expected to not find one" + ) + if found: + self.fail("Found multiple events with matching delay_id") + found = True + if should_find and not found: + self.fail("Did not find event with matching delay_id") + def _get_path_for_delayed_state( room_id: str, event_type: str, state_key: str, delay_ms: int