Use a recursive CTE in _get_missing_events_state_dag_txn, which is more efficient than issuing one query per event

Works in both sqlite and postgres.
This commit is contained in:
Kegan Dougal
2026-04-24 11:42:18 +01:00
parent 1abc489849
commit afdb068984
2 changed files with 77 additions and 28 deletions
@@ -1295,34 +1295,81 @@ class EventFederationWorkerStore(
latest_event_ids: list[str],
limit: int,
) -> list[str]:
seen_event_ids = set(earliest_event_ids)
# lexicographical sort to ensure that responses are deterministic (for caching/tests)
front_queue = sorted(set(latest_event_ids) - seen_event_ids)
event_id_results: list[str] = []
# TODO(kegan): use a recursive CTE?
# The limit is usually pretty low, so it's cheaper to select the events we need via querying
# rather than selecting all events and filtering python-side.
query = (
"SELECT prev_state_event_id FROM msc4242_state_dag_edges "
"WHERE room_id = ? AND event_id = ? "
"ORDER BY prev_state_event_id ASC "
"LIMIT ?"
)
while front_queue and len(event_id_results) < limit:
front_event_id = front_queue.pop(0)
txn.execute(query, (room_id, front_event_id, limit - len(event_id_results)))
# None check because the m.room.create event has NULL prev_state_events
prev_state_event_ids_for_front_event = [
t[0] for t in txn if t[0] is not None and t[0] not in seen_event_ids
]
for next in (
prev_state_event_ids_for_front_event
): # Sort lexicographically for determinism
front_queue.append(next)
seen_event_ids |= set(prev_state_event_ids_for_front_event)
event_id_results.extend(prev_state_event_ids_for_front_event)
"""Walk the state DAG backward from `latest_event_ids`, stopping at
`limit` results or the beginning of the DAG, whichever comes first.
return event_id_results
Earliest events are treated as already-visited: they are not emitted,
and their predecessors are not traversed via them.
Results are deterministic: ordered breadth-first by the minimum number of hops from
the latest events, with lexicographic tie-breaking among events at the same hop count.
Executes as a single recursive CTE in both SQLite and Postgres.
"""
earliest_set = set(earliest_event_ids)
seed_ids = sorted(set(latest_event_ids) - earliest_set)
if not seed_ids or limit <= 0:
return []
seed_clause, seed_args = make_in_list_sql_clause(
self.database_engine, "e.event_id", seed_ids
)
# `make_in_list_sql_clause` doesn't handle empty iterables, so guard
# the no-earliest case explicitly with an always-true clause.
if earliest_event_ids:
earliest_clause, earliest_args = make_in_list_sql_clause(
self.database_engine,
"e.prev_state_event_id",
earliest_event_ids,
negative=True,
)
else:
# No-op clause that is TRUE for every row, in both dialects.
earliest_clause = "1=1"
earliest_args = []
query = f"""
WITH RECURSIVE walk(event_id, hops) AS (
SELECT
e.prev_state_event_id,
1
FROM msc4242_state_dag_edges e
WHERE e.room_id = ?
AND {seed_clause}
AND e.prev_state_event_id IS NOT NULL
AND {earliest_clause}
UNION
SELECT
e.prev_state_event_id,
w.hops + 1
FROM walk w
JOIN msc4242_state_dag_edges e
ON e.room_id = ?
AND e.event_id = w.event_id
WHERE e.prev_state_event_id IS NOT NULL
AND {earliest_clause}
AND w.hops < ?
)
SELECT event_id, MIN(hops) AS hops
FROM walk
GROUP BY event_id
ORDER BY hops, event_id
LIMIT ?
"""
params: list = [room_id]
params.extend(seed_args)
params.extend(earliest_args)
params.append(room_id)
params.extend(earliest_args)
params.append(limit)
params.append(limit)
txn.execute(query, params)
return [row[0] for row in txn]
@trace
@tag_args
+3 -1
View File
@@ -1542,7 +1542,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
latest=["E", "E", "C", "C", "C"], want=["B", "D", "T", "W"], limit=4
),
# include latest events in response. W included because reachable from E.
TestCase(latest=["W", "E"], want=["D", "T", "W", "R"], limit=4),
# sort order is based on # hops not processing order of parents
# (which would produce D,T,W,R as E is processed first, then W).
TestCase(latest=["W", "E"], want=["D", "R", "T", "W"], limit=4),
]
for test_case in test_cases:
got = self.get_success(