mirror of
https://github.com/element-hq/synapse.git
synced 2026-06-06 17:42:10 +00:00
Send sticky events when catching up over federation
This commit is contained in:
@@ -101,6 +101,7 @@ class PerDestinationQueue:
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._federation_shard_config = hs.config.worker.federation_shard_config
|
||||
self._state = hs.get_state_handler()
|
||||
self.msc4354_enabled = hs.config.experimental.msc4354_enabled
|
||||
|
||||
self._should_send_on_this_instance = True
|
||||
if not self._federation_shard_config.should_handle(
|
||||
@@ -558,6 +559,33 @@ class PerDestinationQueue:
|
||||
# send.
|
||||
extrem_events = await self._store.get_events_as_list(extrems)
|
||||
|
||||
if self.msc4354_enabled:
|
||||
# we also want to send sticky events that are still active in this room
|
||||
sticky_event_ids = (
|
||||
await self._store.get_sticky_event_ids_sent_by_self(
|
||||
pdu.room_id,
|
||||
last_successful_stream_ordering,
|
||||
)
|
||||
)
|
||||
# skip any that are actually the forward extremities we want to send anyway
|
||||
sticky_events = await self._store.get_events_as_list(
|
||||
[
|
||||
event_id
|
||||
for event_id in sticky_event_ids
|
||||
if event_id not in extrems
|
||||
]
|
||||
)
|
||||
if sticky_events:
|
||||
# *prepend* these to the extrem list, so they are processed first.
|
||||
# This ensures they will show up before the forward extrem in stream order
|
||||
extrem_events = sticky_events + extrem_events
|
||||
logger.info(
|
||||
"Sending %d missed sticky events to %s: %r",
|
||||
len(sticky_events),
|
||||
self._destination,
|
||||
pdu.room_id,
|
||||
)
|
||||
|
||||
new_pdus = []
|
||||
for p in extrem_events:
|
||||
# We pulled this from the DB, so it'll be non-null
|
||||
|
||||
@@ -209,6 +209,43 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
|
||||
)
|
||||
return cast(List[Tuple[int, str, str]], txn.fetchall())
|
||||
|
||||
async def get_sticky_event_ids_sent_by_self(
|
||||
self, room_id: str, from_stream_pos: int
|
||||
) -> List[str]:
|
||||
"""Get sticky event IDs which have been sent by users on this homeserver.
|
||||
|
||||
Used when sending sticky events eagerly to newly joined servers, or when catching up over federation.
|
||||
|
||||
Args:
|
||||
room_id: The room to fetch sticky events in.
|
||||
from_stream_pos: The stream position to return events from. May be 0 for newly joined servers.
|
||||
Returns:
|
||||
A list of event IDs, which may be empty.
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_sticky_event_ids_sent_by_self",
|
||||
self._get_sticky_event_ids_sent_by_self_txn,
|
||||
room_id,
|
||||
from_stream_pos,
|
||||
)
|
||||
|
||||
def _get_sticky_event_ids_sent_by_self_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, from_stream_pos: int
|
||||
) -> List[str]:
|
||||
now_ms = self._now()
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT sticky_events.event_id, sticky_events.sender, events.stream_ordering FROM sticky_events
|
||||
INNER JOIN events ON events.event_id = sticky_events.event_id
|
||||
WHERE soft_failed=FALSE AND expires_at > ? AND sticky_events.room_id = ?
|
||||
""",
|
||||
(now_ms, room_id),
|
||||
)
|
||||
rows = cast(List[Tuple[str, str, int]], txn.fetchall())
|
||||
return [
|
||||
row[0] for row in rows if row[2] > from_stream_pos and self.hs.is_mine_id(row[1])
|
||||
]
|
||||
|
||||
async def reevaluate_soft_failed_sticky_events(
|
||||
self,
|
||||
room_id: str,
|
||||
|
||||
@@ -380,7 +380,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns at most 50 event IDs and their corresponding stream_orderings
|
||||
that correspond to the oldest events that have not yet been sent to
|
||||
that correspond to the newest events that have not yet been sent to
|
||||
the destination.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from typing import Callable, Collection, List, Optional, Tuple
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
@@ -19,6 +20,7 @@ from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import event_injection
|
||||
from tests.unittest import FederatingHomeserverTestCase
|
||||
|
||||
@@ -452,6 +454,60 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||
# has been successfully sent.
|
||||
self.assertCountEqual(woken, set(server_names[:-1]))
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4354_enabled": True}})
|
||||
def test_sends_sticky_events(self) -> None:
|
||||
"""Test that we send sticky events in addition to the latest event in the room when catching up."""
|
||||
# make the clock used when generating origin_server_ts the same as the clock used to check expiry
|
||||
self.reactor.advance(time.time())
|
||||
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
|
||||
|
||||
# Make a room with a local user, and two servers. One will go offline
|
||||
# and one will send some events.
|
||||
self.register_user("u1", "you the one")
|
||||
u1_token = self.login("u1", "you the one")
|
||||
room_1 = self.helper.create_room_as("u1", tok=u1_token)
|
||||
|
||||
self.get_success(
|
||||
event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
|
||||
)
|
||||
event_1 = self.get_success(
|
||||
event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
|
||||
)
|
||||
|
||||
# now we send a sticky event that we expect to be bundled with the fwd extrem event
|
||||
sticky_event_id = self.helper.send_sticky_event(
|
||||
room_1, "m.room.sticky", 60000, tok=u1_token
|
||||
)["event_id"]
|
||||
# ..and other uninteresting events
|
||||
self.helper.send(room_1, "you hear me!!", tok=u1_token)
|
||||
|
||||
# Now simulate us receiving an event from the still online remote.
|
||||
fwd_extrem_event = self.get_success(
|
||||
event_injection.inject_event(
|
||||
self.hs,
|
||||
type=EventTypes.Message,
|
||||
sender="@user:host3",
|
||||
room_id=room_1,
|
||||
content={"msgtype": "m.text", "body": "Hello"},
|
||||
)
|
||||
)
|
||||
|
||||
assert event_1.internal_metadata.stream_ordering is not None
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
|
||||
"host2", event_1.internal_metadata.stream_ordering
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(per_dest_queue._catch_up_transmission_loop())
|
||||
|
||||
# We expect the sticky event and the fwd extrem to be sent
|
||||
self.assertEqual(len(sent_pdus), 2)
|
||||
# We expect the sticky event to appear before the fwd extrem
|
||||
self.assertEqual(sent_pdus[0].event_id, sticky_event_id)
|
||||
self.assertEqual(sent_pdus[1].event_id, fwd_extrem_event.event_id)
|
||||
self.assertFalse(per_dest_queue._catching_up)
|
||||
|
||||
def test_not_latest_event(self) -> None:
|
||||
"""Test that we send the latest event in the room even if its not ours."""
|
||||
|
||||
|
||||
@@ -456,6 +456,46 @@ class RestHelper:
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def send_sticky_event(
|
||||
self,
|
||||
room_id: str,
|
||||
type: str,
|
||||
duration_ms: int,
|
||||
content: Optional[dict] = None,
|
||||
txn_id: Optional[str] = None,
|
||||
tok: Optional[str] = None,
|
||||
expect_code: int = HTTPStatus.OK,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
) -> JsonDict:
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/%s/%s?msc4354_stick_duration_ms=%d" % (
|
||||
room_id,
|
||||
type,
|
||||
txn_id,
|
||||
duration_ms,
|
||||
)
|
||||
if tok:
|
||||
path = path + "&access_token=%s" % tok
|
||||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"PUT",
|
||||
path,
|
||||
content or {},
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
|
||||
expect_code,
|
||||
channel.code,
|
||||
channel.result["body"],
|
||||
)
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def get_event(
|
||||
self,
|
||||
room_id: str,
|
||||
|
||||
Reference in New Issue
Block a user