Send sticky events when catching up over federation

This commit is contained in:
Kegan Dougal
2025-09-26 09:29:52 +01:00
parent ad6a2b9e0c
commit 33d80be69f
5 changed files with 162 additions and 1 deletions
@@ -101,6 +101,7 @@ class PerDestinationQueue:
self._instance_name = hs.get_instance_name()
self._federation_shard_config = hs.config.worker.federation_shard_config
self._state = hs.get_state_handler()
self.msc4354_enabled = hs.config.experimental.msc4354_enabled
self._should_send_on_this_instance = True
if not self._federation_shard_config.should_handle(
@@ -558,6 +559,33 @@ class PerDestinationQueue:
# send.
extrem_events = await self._store.get_events_as_list(extrems)
if self.msc4354_enabled:
# we also want to send sticky events that are still active in this room
sticky_event_ids = (
await self._store.get_sticky_event_ids_sent_by_self(
pdu.room_id,
last_successful_stream_ordering,
)
)
# skip any that are actually the forward extremities we want to send anyway
sticky_events = await self._store.get_events_as_list(
[
event_id
for event_id in sticky_event_ids
if event_id not in extrems
]
)
if sticky_events:
# *prepend* these to the extrem list, so they are processed first.
# This ensures they will show up before the forward extrem in stream order
extrem_events = sticky_events + extrem_events
logger.info(
"Sending %d missed sticky events to %s: %r",
len(sticky_events),
self._destination,
pdu.room_id,
)
new_pdus = []
for p in extrem_events:
# We pulled this from the DB, so it'll be non-null
@@ -209,6 +209,43 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
)
return cast(List[Tuple[int, str, str]], txn.fetchall())
async def get_sticky_event_ids_sent_by_self(
self, room_id: str, from_stream_pos: int
) -> List[str]:
"""Get sticky event IDs which have been sent by users on this homeserver.
Used when sending sticky events eagerly to newly joined servers, or when catching up over federation.
Args:
room_id: The room to fetch sticky events in.
from_stream_pos: The stream position to return events from. May be 0 for newly joined servers.
Returns:
A list of event IDs, which may be empty.
"""
return await self.db_pool.runInteraction(
"get_sticky_event_ids_sent_by_self",
self._get_sticky_event_ids_sent_by_self_txn,
room_id,
from_stream_pos,
)
def _get_sticky_event_ids_sent_by_self_txn(
self, txn: LoggingTransaction, room_id: str, from_stream_pos: int
) -> List[str]:
now_ms = self._now()
txn.execute(
"""
SELECT sticky_events.event_id, sticky_events.sender, events.stream_ordering FROM sticky_events
INNER JOIN events ON events.event_id = sticky_events.event_id
WHERE soft_failed=FALSE AND expires_at > ? AND sticky_events.room_id = ?
""",
(now_ms, room_id),
)
rows = cast(List[Tuple[str, str, int]], txn.fetchall())
return [
row[0] for row in rows if row[2] > from_stream_pos and self.hs.is_mine_id(row[1])
]
async def reevaluate_soft_failed_sticky_events(
self,
room_id: str,
@@ -380,7 +380,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> List[str]:
"""
Returns at most 50 event IDs and their corresponding stream_orderings
that correspond to the oldest events that have not yet been sent to
that correspond to the newest events that have not yet been sent to
the destination.
Args:
@@ -1,3 +1,4 @@
import time
from typing import Callable, Collection, List, Optional, Tuple
from unittest import mock
from unittest.mock import AsyncMock, Mock
@@ -19,6 +20,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase
@@ -452,6 +454,60 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# has been successfully sent.
self.assertCountEqual(woken, set(server_names[:-1]))
@unittest.override_config({"experimental_features": {"msc4354_enabled": True}})
def test_sends_sticky_events(self) -> None:
"""Test that we send sticky events in addition to the latest event in the room when catching up."""
# make the clock used when generating origin_server_ts the same as the clock used to check expiry
self.reactor.advance(time.time())
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
# Make a room with a local user, and two servers. One will go offline
# and one will send some events.
self.register_user("u1", "you the one")
u1_token = self.login("u1", "you the one")
room_1 = self.helper.create_room_as("u1", tok=u1_token)
self.get_success(
event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
)
event_1 = self.get_success(
event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
)
# now we send a sticky event that we expect to be bundled with the fwd extrem event
sticky_event_id = self.helper.send_sticky_event(
room_1, "m.room.sticky", 60000, tok=u1_token
)["event_id"]
# ..and other uninteresting events
self.helper.send(room_1, "you hear me!!", tok=u1_token)
# Now simulate us receiving an event from the still online remote.
fwd_extrem_event = self.get_success(
event_injection.inject_event(
self.hs,
type=EventTypes.Message,
sender="@user:host3",
room_id=room_1,
content={"msgtype": "m.text", "body": "Hello"},
)
)
assert event_1.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering
)
)
self.get_success(per_dest_queue._catch_up_transmission_loop())
# We expect the sticky event and the fwd extrem to be sent
self.assertEqual(len(sent_pdus), 2)
# We expect the sticky event to appear before the fwd extrem
self.assertEqual(sent_pdus[0].event_id, sticky_event_id)
self.assertEqual(sent_pdus[1].event_id, fwd_extrem_event.event_id)
self.assertFalse(per_dest_queue._catching_up)
def test_not_latest_event(self) -> None:
"""Test that we send the latest event in the room even if its not ours."""
+40
View File
@@ -456,6 +456,46 @@ class RestHelper:
return channel.json_body
def send_sticky_event(
self,
room_id: str,
type: str,
duration_ms: int,
content: Optional[dict] = None,
txn_id: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> JsonDict:
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
path = "/_matrix/client/r0/rooms/%s/send/%s/%s?msc4354_stick_duration_ms=%d" % (
room_id,
type,
txn_id,
duration_ms,
)
if tok:
path = path + "&access_token=%s" % tok
channel = make_request(
self.reactor,
self.site,
"PUT",
path,
content or {},
custom_headers=custom_headers,
)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
channel.code,
channel.result["body"],
)
return channel.json_body
def get_event(
self,
room_id: str,