From afdb0689840a6edc242cdf1b58b6723f7bbf1c04 Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Fri, 24 Apr 2026 11:42:18 +0100 Subject: [PATCH] Use a CTE in `_get_missing_events_state_dag_txn` which is more efficient Works in both sqlite and postgres. --- .../databases/main/event_federation.py | 101 +++++++++++++----- tests/storage/test_event_federation.py | 4 +- 2 files changed, 77 insertions(+), 28 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 65bbd6000e..fbf64f4964 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1295,34 +1295,81 @@ class EventFederationWorkerStore( latest_event_ids: list[str], limit: int, ) -> list[str]: - seen_event_ids = set(earliest_event_ids) - # lexicographical sort to ensure that responses are deterministic (for caching/tests) - front_queue = sorted(set(latest_event_ids) - seen_event_ids) - event_id_results: list[str] = [] - # TODO(kegan): use a recursive CTE? - # The limit is usually pretty low, so it's cheaper to select the events we need via querying - # rather than selecting all events and filtering python-side. - query = ( - "SELECT prev_state_event_id FROM msc4242_state_dag_edges " - "WHERE room_id = ? AND event_id = ? " - "ORDER BY prev_state_event_id ASC " - "LIMIT ?" - ) - while front_queue and len(event_id_results) < limit: - front_event_id = front_queue.pop(0) - txn.execute(query, (room_id, front_event_id, limit - len(event_id_results))) - # None check because the m.room.create event has NULL prev_state_events - prev_state_event_ids_for_front_event = [ - t[0] for t in txn if t[0] is not None and t[0] not in seen_event_ids - ] - for next in ( - prev_state_event_ids_for_front_event - ): # Sort lexicographically for determinism - front_queue.append(next) - seen_event_ids |= set(prev_state_event_ids_for_front_event) - event_id_results.extend(prev_state_event_ids_for_front_event) + """Walk the state DAG backward from `latest_event_ids`, stopping at + `limit` results or the beginning of the DAG, whichever comes first. - return event_id_results + Earliest events are treated as already-visited: they are not emitted, + and their predecessors are not traversed via them. + + Results are deterministic and ordered by BFS, with lexicographic tie-breaking among + siblings at the same # hops away. + + Executes as a single recursive CTE in both SQLite and Postgres. + """ + earliest_set = set(earliest_event_ids) + seed_ids = sorted(set(latest_event_ids) - earliest_set) + if not seed_ids or limit <= 0: + return [] + + seed_clause, seed_args = make_in_list_sql_clause( + self.database_engine, "e.event_id", seed_ids + ) + + # `make_in_list_sql_clause` doesn't handle empty iterables, so guard + # the no-earliest case explicitly with an always-true clause. + if earliest_event_ids: + earliest_clause, earliest_args = make_in_list_sql_clause( + self.database_engine, + "e.prev_state_event_id", + earliest_event_ids, + negative=True, + ) + else: + # No-op clause that is TRUE for every row, in both dialects. + earliest_clause = "1=1" + earliest_args = [] + + query = f""" + WITH RECURSIVE walk(event_id, hops) AS ( + SELECT + e.prev_state_event_id, + 1 + FROM msc4242_state_dag_edges e + WHERE e.room_id = ? + AND {seed_clause} + AND e.prev_state_event_id IS NOT NULL + AND {earliest_clause} + + UNION + + SELECT + e.prev_state_event_id, + w.hops + 1 + FROM walk w + JOIN msc4242_state_dag_edges e + ON e.room_id = ? + AND e.event_id = w.event_id + WHERE e.prev_state_event_id IS NOT NULL + AND {earliest_clause} + AND w.hops < ? + ) + SELECT event_id, MIN(hops) AS hops + FROM walk + GROUP BY event_id + ORDER BY hops, event_id + LIMIT ? + """ + + params: list = [room_id] + params.extend(seed_args) + params.extend(earliest_args) + params.append(room_id) + params.extend(earliest_args) + params.append(limit) + params.append(limit) + + txn.execute(query, params) + return [row[0] for row in txn] @trace @tag_args diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 48a27f77cc..c8cba83511 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -1542,7 +1542,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): latest=["E", "E", "C", "C", "C"], want=["B", "D", "T", "W"], limit=4 ), # include latest events in response. W included because reachable from E. - TestCase(latest=["W", "E"], want=["D", "T", "W", "R"], limit=4), + # sort order is based on # hops not processing order of parents + # (which would produce D,T,W,R as E is processed first, then W). + TestCase(latest=["W", "E"], want=["D", "R", "T", "W"], limit=4), ] for test_case in test_cases: got = self.get_success(