From 33d80be69f8f6161eee93925c29d5e73b2b63206 Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 26 Sep 2025 09:29:52 +0100 Subject: [PATCH] Send sticky events when catching up over federation --- .../sender/per_destination_queue.py | 28 ++++++++++ .../storage/databases/main/sticky_events.py | 37 ++++++++++++ .../storage/databases/main/transactions.py | 2 +- tests/federation/test_federation_catch_up.py | 56 +++++++++++++++++++ tests/rest/client/utils.py | 40 +++++++++++++ 5 files changed, 162 insertions(+), 1 deletion(-) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 4c844d403a..524ad8cb03 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -101,6 +101,7 @@ class PerDestinationQueue: self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config self._state = hs.get_state_handler() + self.msc4354_enabled = hs.config.experimental.msc4354_enabled self._should_send_on_this_instance = True if not self._federation_shard_config.should_handle( @@ -558,6 +559,33 @@ class PerDestinationQueue: # send. extrem_events = await self._store.get_events_as_list(extrems) + if self.msc4354_enabled: + # we also want to send sticky events that are still active in this room + sticky_event_ids = ( + await self._store.get_sticky_event_ids_sent_by_self( + pdu.room_id, + last_successful_stream_ordering, + ) + ) + # skip any that are actually the forward extremities we want to send anyway + sticky_events = await self._store.get_events_as_list( + [ + event_id + for event_id in sticky_event_ids + if event_id not in extrems + ] + ) + if sticky_events: + # *prepend* these to the extrem list, so they are processed first. + # This ensures they will show up before the forward extrem in stream order + extrem_events = sticky_events + extrem_events + logger.info( + "Sending %d missed sticky events to %s: %r", + len(sticky_events), + self._destination, + pdu.room_id, + ) + new_pdus = [] for p in extrem_events: # We pulled this from the DB, so it'll be non-null diff --git a/synapse/storage/databases/main/sticky_events.py b/synapse/storage/databases/main/sticky_events.py index 2f04623a86..5b4c18ab38 100644 --- a/synapse/storage/databases/main/sticky_events.py +++ b/synapse/storage/databases/main/sticky_events.py @@ -209,6 +209,43 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor ) return cast(List[Tuple[int, str, str]], txn.fetchall()) + async def get_sticky_event_ids_sent_by_self( + self, room_id: str, from_stream_pos: int + ) -> List[str]: + """Get sticky event IDs which have been sent by users on this homeserver. + + Used when sending sticky events eagerly to newly joined servers, or when catching up over federation. + + Args: + room_id: The room to fetch sticky events in. + from_stream_pos: The stream position to return events from. May be 0 for newly joined servers. + Returns: + A list of event IDs, which may be empty. + """ + return await self.db_pool.runInteraction( + "get_sticky_event_ids_sent_by_self", + self._get_sticky_event_ids_sent_by_self_txn, + room_id, + from_stream_pos, + ) + + def _get_sticky_event_ids_sent_by_self_txn( + self, txn: LoggingTransaction, room_id: str, from_stream_pos: int + ) -> List[str]: + now_ms = self._now() + txn.execute( + """ + SELECT sticky_events.event_id, sticky_events.sender, events.stream_ordering FROM sticky_events + INNER JOIN events ON events.event_id = sticky_events.event_id + WHERE soft_failed=FALSE AND expires_at > ? AND sticky_events.room_id = ? + """, + (now_ms, room_id), + ) + rows = cast(List[Tuple[str, str, int]], txn.fetchall()) + return [ + row[0] for row in rows if row[2] > from_stream_pos and self.hs.is_mine_id(row[1]) + ] + async def reevaluate_soft_failed_sticky_events( self, room_id: str, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index bfc324b80d..a439c41aab 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -380,7 +380,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) -> List[str]: """ Returns at most 50 event IDs and their corresponding stream_orderings - that correspond to the oldest events that have not yet been sent to + that correspond to the newest events that have not yet been sent to the destination. Args: diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index f99911b102..a2bc7e59b5 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -1,3 +1,4 @@ +import time from typing import Callable, Collection, List, Optional, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock @@ -19,6 +20,7 @@ from synapse.types import JsonDict from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination +from tests import unittest from tests.test_utils import event_injection from tests.unittest import FederatingHomeserverTestCase @@ -452,6 +454,60 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # has been successfully sent. self.assertCountEqual(woken, set(server_names[:-1])) + @unittest.override_config({"experimental_features": {"msc4354_enabled": True}}) + def test_sends_sticky_events(self) -> None: + """Test that we send sticky events in addition to the latest event in the room when catching up.""" + # make the clock used when generating origin_server_ts the same as the clock used to check expiry + self.reactor.advance(time.time()) + per_dest_queue, sent_pdus = self.make_fake_destination_queue() + + # Make a room with a local user, and two servers. One will go offline + # and one will send some events. + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + event_1 = self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join") + ) + + # now we send a sticky event that we expect to be bundled with the fwd extrem event + sticky_event_id = self.helper.send_sticky_event( + room_1, "m.room.sticky", 60000, tok=u1_token + )["event_id"] + # ..and other uninteresting events + self.helper.send(room_1, "you hear me!!", tok=u1_token) + + # Now simulate us receiving an event from the still online remote. + fwd_extrem_event = self.get_success( + event_injection.inject_event( + self.hs, + type=EventTypes.Message, + sender="@user:host3", + room_id=room_1, + content={"msgtype": "m.text", "body": "Hello"}, + ) + ) + + assert event_1.internal_metadata.stream_ordering is not None + self.get_success( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( + "host2", event_1.internal_metadata.stream_ordering + ) + ) + + self.get_success(per_dest_queue._catch_up_transmission_loop()) + + # We expect the sticky event and the fwd extrem to be sent + self.assertEqual(len(sent_pdus), 2) + # We expect the sticky event to appear before the fwd extrem + self.assertEqual(sent_pdus[0].event_id, sticky_event_id) + self.assertEqual(sent_pdus[1].event_id, fwd_extrem_event.event_id) + self.assertFalse(per_dest_queue._catching_up) + def test_not_latest_event(self) -> None: """Test that we send the latest event in the room even if its not ours.""" diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index bb214759d9..c34b1c4973 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -456,6 +456,46 @@ class RestHelper: return channel.json_body + def send_sticky_event( + self, + room_id: str, + type: str, + duration_ms: int, + content: Optional[dict] = None, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ) -> JsonDict: + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s?msc4354_stick_duration_ms=%d" % ( + room_id, + type, + txn_id, + duration_ms, + ) + if tok: + path = path + "&access_token=%s" % tok + + channel = make_request( + self.reactor, + self.site, + "PUT", + path, + content or {}, + custom_headers=custom_headers, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def get_event( self, room_id: str,