MSC4242: State DAGs (Federation)

Builds off https://github.com/element-hq/synapse/pull/19424

Adds federation compatibility for state DAG rooms. Overview:
 - Adds extra HTTP API fields as per the MSC.
 - Adds methods for walking and extracting the state DAG for a room (for `/get_missing_events` and `/send_join` respectively).
 - Adds impl for processing the federation requests, as well as `/send`.
This commit is contained in:
Kegan Dougal
2026-02-03 14:46:11 +00:00
parent dc3db60d36
commit 53ca01db28
16 changed files with 1684 additions and 280 deletions
+48 -14
View File
@@ -42,7 +42,12 @@ from typing import (
import attr
from prometheus_client import Counter
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.constants import (
Direction,
EventContentFields,
EventTypes,
Membership,
)
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -119,6 +124,8 @@ class SendJoinResult:
origin: str
state: list[EventBase]
auth_chain: list[EventBase]
# Only valid for state DAG rooms (MSC4242)
state_dag: list[EventBase] | None
# True if 'state' elides non-critical membership events
partial_state: bool
@@ -658,7 +665,10 @@ class FederationClient(FederationBase):
@trace
@tag_args
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
self,
destination: str,
room_id: str,
event_id: str,
) -> tuple[list[str], list[str]]:
"""Calls the /state_ids endpoint to fetch the state at a particular point
in the room, and the auth events for the given event
@@ -670,7 +680,9 @@ class FederationClient(FederationBase):
InvalidResponseError: if fields in the response have the wrong type.
"""
result = await self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id
destination,
room_id,
event_id=event_id,
)
state_event_ids = result["pdu_ids"]
@@ -1178,7 +1190,6 @@ class FederationClient(FederationBase):
response = await self._do_send_join(
room_version, destination, pdu, omit_members=partial_state
)
# If an event was returned (and expected to be returned):
#
# * Ensure it has the same event ID (note that the event ID is a hash
@@ -1201,14 +1212,22 @@ class FederationClient(FederationBase):
event = pdu
state = response.state
auth_chain = response.auth_events
auth_events = response.auth_events
create_event = None
for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
if room_version.msc4242_state_dags and response.state_dag:
# assign to auth_events to reuse the below code which ultimately just does
# sig/hash checks. We'll set the right field in SendJoinResult later.
auth_events = response.state_dag
for e in response.state_dag:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
if create_event is None:
# If the state doesn't have a create event then the room is
# invalid, and it would fail auth checks anyway.
@@ -1227,7 +1246,7 @@ class FederationClient(FederationBase):
)
logger.info(
"Processing from send_join %d events", len(state) + len(auth_chain)
"Processing from send_join %d events", len(state) + len(auth_events)
)
# We now go and check the signatures and hashes for the event. Note
@@ -1246,7 +1265,7 @@ class FederationClient(FederationBase):
valid_pdus_map[valid_pdu.event_id] = valid_pdu
await concurrently_execute(
_execute, itertools.chain(state, auth_chain), 10000
_execute, itertools.chain(state, auth_events), 10000
)
# NB: We *need* to copy to ensure that we don't have multiple
@@ -1259,27 +1278,28 @@ class FederationClient(FederationBase):
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
for p in auth_events
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
# TODO(kegan): It's unclear why we only need to do this for state and not auth_events
for s in signed_state:
s.internal_metadata = s.internal_metadata.copy()
# double-check that the auth chain doesn't include a different create event
auth_chain_create_events = [
# double-check that the auth events doesn't include a different create event
auth_events_create_events = [
e.event_id
for e in signed_auth
if (e.type, e.state_key) == (EventTypes.Create, "")
]
if auth_chain_create_events and auth_chain_create_events != [
if auth_events_create_events and auth_events_create_events != [
create_event.event_id
]:
raise InvalidResponseError(
"Unexpected create event(s) in auth chain: %s"
% (auth_chain_create_events,)
% (auth_events_create_events,)
)
servers_in_room = None
@@ -1301,10 +1321,21 @@ class FederationClient(FederationBase):
# Fix things up in case the remote homeserver is badly behaved.
servers_in_room.add(destination)
signed_auth_events = signed_auth
signed_state_dag = None
if room_version.msc4242_state_dags:
# We previously set the state dag to auth_events so re-assign it correctly
signed_state_dag = signed_auth
# Ensure the caller cannot accidentally use these values even if the server
# returned them.
signed_state = []
signed_auth_events = []
return SendJoinResult(
event=event,
state=signed_state,
auth_chain=signed_auth,
auth_chain=signed_auth_events,
state_dag=signed_state_dag,
origin=destination,
partial_state=response.members_omitted,
servers_in_room=servers_in_room or frozenset(),
@@ -1608,6 +1639,7 @@ class FederationClient(FederationBase):
limit: int,
min_depth: int,
timeout: int,
state_dag: bool = False,
) -> list[EventBase]:
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
@@ -1623,6 +1655,7 @@ class FederationClient(FederationBase):
limit: Maximum number of events to return.
min_depth: Minimum depth of events to return.
timeout: Max time to wait in ms
state_dag: True to walk the state DAG (MSC4242 rooms)
"""
try:
content = await self.transport_layer.get_missing_events(
@@ -1633,6 +1666,7 @@ class FederationClient(FederationBase):
limit=limit,
min_depth=min_depth,
timeout=timeout,
state_dag=state_dag,
)
room_version = await self.store.get_room_version(room_id)
+58 -14
View File
@@ -28,6 +28,8 @@ from typing import (
Callable,
Collection,
Mapping,
Sequence,
cast,
)
from prometheus_client import Counter, Gauge, Histogram
@@ -53,7 +55,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase
from synapse.events import EventBase, FrozenEventVMSC4242
from synapse.events.snapshot import EventPersistencePair
from synapse.federation.federation_base import (
FederationBase,
@@ -596,7 +598,10 @@ class FederationServer(FederationBase):
)
async def on_room_state_request(
self, origin: str, room_id: str, event_id: str
self,
origin: str,
room_id: str,
event_id: str,
) -> tuple[int, JsonDict]:
await self._event_auth_handler.assert_host_in_room(room_id, origin)
origin_host, _ = parse_server_name(origin)
@@ -620,7 +625,10 @@ class FederationServer(FederationBase):
@trace
@tag_args
async def on_state_ids_request(
self, origin: str, room_id: str, event_id: str
self,
origin: str,
room_id: str,
event_id: str,
) -> tuple[int, JsonDict]:
if not event_id:
raise NotImplementedError("Specify an event")
@@ -641,17 +649,27 @@ class FederationServer(FederationBase):
@trace
@tag_args
async def _on_state_ids_request_compute(
self, room_id: str, event_id: str
self,
room_id: str,
event_id: str,
) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
state_ids = await self.handler.get_state_ids_for_pdu(
room_id,
event_id,
)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
self,
room_id: str,
event_id: str,
) -> dict[str, list]:
pdus: Collection[EventBase]
event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
event_ids = await self.handler.get_state_ids_for_pdu(
room_id,
event_id,
)
pdus = await self.store.get_events_as_list(event_ids)
auth_chain = await self.store.get_auth_chain(
@@ -759,6 +777,10 @@ class FederationServer(FederationBase):
origin, content, Membership.JOIN, room_id
)
if event.room_version.msc4242_state_dags and caller_supports_partial_state:
# TODO(kegan): for now, MSC4242 won't support partial state for ease of prototyping.
caller_supports_partial_state = False
prev_state_ids = await context.get_prev_state_ids()
state_event_ids: Collection[str]
@@ -769,15 +791,31 @@ class FederationServer(FederationBase):
event, prev_state_ids, summary
)
servers_in_room = await self.state.get_hosts_in_room_at_events(
room_id, event_ids=event.prev_event_ids()
room_id,
event_ids=event.prev_state_events
if isinstance(event, FrozenEventVMSC4242)
else event.prev_event_ids(),
)
else:
state_event_ids = prev_state_ids.values()
servers_in_room = None
auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)
state_dag = None
auth_chain_event_ids = set()
if event.room_version.msc4242_state_dags:
assert isinstance(event, FrozenEventVMSC4242)
# NOTE: we don't return the state dag for forward extremities that aren't part of this
# join event to make it easier for the receiving server to set their own forward
# extremities (they are equal to the join event's prev_state_events). This means we may
# fail to sync concurrent forks not on the path to the join event, but this is an
# outstanding problem in general.
state_dag = await self.store.get_state_dag(
room_id, set(event.prev_state_events)
)
else:
auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)
# if the caller has opted in, we can omit any auth_chain events which are
# already in state_event_ids
@@ -794,13 +832,18 @@ class FederationServer(FederationBase):
resp = {
"event": event_json,
"state": serialize_and_filter_pdus(state_events, time_now),
"auth_chain": serialize_and_filter_pdus(auth_chain_events, time_now),
"members_omitted": caller_supports_partial_state,
}
if state_dag is None:
resp["auth_chain"] = serialize_and_filter_pdus(auth_chain_events, time_now)
else:
resp["state_dag"] = serialize_and_filter_pdus(
cast(Sequence[EventBase], state_dag.values()), time_now
)
del resp["state"]
if servers_in_room is not None:
resp["servers_in_room"] = list(servers_in_room)
return resp
async def on_make_leave_request(
@@ -1097,6 +1140,7 @@ class FederationServer(FederationBase):
earliest_events: list[str],
latest_events: list[str],
limit: int,
walk_state_dag: bool = False,
) -> dict[str, list]:
async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
@@ -1111,7 +1155,7 @@ class FederationServer(FederationBase):
)
missing_events = await self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit
origin, room_id, earliest_events, latest_events, limit, walk_state_dag
)
if len(missing_events) < 5:
+5 -2
View File
@@ -148,7 +148,7 @@ from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, Membership
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.events import EventBase, FrozenEventVMSC4242
from synapse.federation.sender.per_destination_queue import (
CATCHUP_RETRY_INTERVAL,
PerDestinationQueue,
@@ -660,7 +660,10 @@ class FederationSender(AbstractFederationSender):
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
event.room_id,
event_ids=event.prev_state_events
if isinstance(event, FrozenEventVMSC4242)
else event.prev_event_ids(),
)
except Exception:
logger.exception(
+36 -14
View File
@@ -69,7 +69,10 @@ class TransportLayerClient:
self.client.shutdown()
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
self,
destination: str,
room_id: str,
event_id: str,
) -> JsonDict:
"""Requests the IDs of all state for a given room at the given event.
@@ -78,22 +81,26 @@ class TransportLayerClient:
to get the state from.
room_id: the room we want the state of
event_id: The event we want the context at.
Returns:
Results in a dict received from the remote homeserver.
"""
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state_ids/%s", room_id)
qps = {"event_id": event_id}
return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
args=qps,
try_trailing_slash_on_400=True,
)
async def get_room_state(
self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
self,
room_version: RoomVersion,
destination: str,
room_id: str,
event_id: str,
) -> "StateRequestResponse":
"""Requests the full state for a given room at the given event.
@@ -103,15 +110,15 @@ class TransportLayerClient:
to get the state from.
room_id: the room we want the state of
event_id: The event we want the context at.
Returns:
Results in a dict received from the remote homeserver.
"""
path = _create_v1_path("/state/%s", room_id)
qps = {"event_id": event_id}
return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
args=qps,
# This can take a looooooong time for large rooms. Give this a generous
# timeout of 10 minutes to avoid the partial state resync timing out early
# and trying a bunch of servers who haven't seen our join yet.
@@ -787,18 +794,21 @@ class TransportLayerClient:
limit: int,
min_depth: int,
timeout: int,
state_dag: bool,
) -> JsonDict:
path = _create_v1_path("/get_missing_events/%s", room_id)
data = {
"limit": int(limit),
"min_depth": int(min_depth),
"earliest_events": earliest_events,
"latest_events": latest_events,
}
if state_dag:
data["org.matrix.msc4242.state_dag"] = True
return await self.client.post_json(
destination=destination,
path=path,
data={
"limit": int(limit),
"min_depth": int(min_depth),
"earliest_events": earliest_events,
"latest_events": latest_events,
},
data=data,
timeout=timeout,
)
@@ -993,6 +1003,9 @@ class SendJoinResponse:
state: list[EventBase]
# The raw join event from the /send_join response.
event_dict: JsonDict
# MSC4242: State DAGs. Always included for state dag rooms, else an empty list.
# Replaces auth_events.
state_dag: list[EventBase]
# The parsed join event from the /send_join response. This will be None if
# "event" is not included in the response.
event: EventBase | None = None
@@ -1079,7 +1092,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
MAX_RESPONSE_SIZE = 500 * 1024 * 1024
def __init__(self, room_version: RoomVersion, v1_api: bool):
self._response = SendJoinResponse([], [], event_dict={})
self._response = SendJoinResponse([], [], event_dict={}, state_dag=[])
self._room_version = room_version
self._coros: list[Generator[None, bytes, None]] = []
@@ -1123,6 +1136,15 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
)
)
if room_version.msc4242_state_dags:
self._coros.append(
ijson.items_coro(
_event_list_parser(room_version, self._response.state_dag),
prefix + "state_dag.item",
use_float=True,
)
)
def write(self, data: bytes) -> int:
for c in self._coros:
c.send(data)
@@ -621,6 +621,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
limit = int(content.get("limit", 10))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
walk_state_dag = content.get("org.matrix.msc4242.state_dag", False)
result = await self.handler.on_get_missing_events(
origin,
@@ -628,6 +629,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
earliest_events=earliest_events,
latest_events=latest_events,
limit=limit,
walk_state_dag=walk_state_dag,
)
return 200, result
+49 -1
View File
@@ -669,6 +669,13 @@ class FederationHandler:
logger.debug("do_invite_join auth_chain: %s", auth_chain)
logger.debug("do_invite_join state: %s", state)
# If the 'state_dag' field is set, everything will be derived from it.
if ret.state_dag:
logger.debug("do_invite_join state_dag: %s", ret.state_dag)
state = ret.state_dag
# TODO(kegan): is this actually important for processing? Shouldn't we be using
# the actual DAG here?
state.sort(key=lambda e: e.depth)
logger.debug("do_invite_join event: %s", event)
@@ -965,6 +972,7 @@ class FederationHandler:
# Note that this requires the /send_join request to come back to the
# same server.
prev_event_ids = None
prev_state_events = None
if room_version.restricted_join_rule:
# Note that the room's state can change out from under us and render our
# nice join rules-conformant event non-conformant by the time we build the
@@ -981,6 +989,9 @@ class FederationHandler:
state_ids = await self._state_storage_controller.get_current_state_ids(
room_id
)
prev_state_events = list(
await self.store.get_state_dag_extremities(room_id)
)
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
@@ -1021,6 +1032,7 @@ class FederationHandler:
) = await self.event_creation_handler.create_new_client_event(
builder=builder,
prev_event_ids=prev_event_ids,
prev_state_events=prev_state_events,
)
except SynapseError as e:
logger.warning("Failed to create join to %s because %s", room_id, e)
@@ -1301,7 +1313,11 @@ class FederationHandler:
@trace
@tag_args
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> list[str]:
async def get_state_ids_for_pdu(
self,
room_id: str,
event_id: str,
) -> list[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
if event.internal_metadata.outlier:
@@ -1412,11 +1428,17 @@ class FederationHandler:
earliest_events: list[str],
latest_events: list[str],
limit: int,
walk_state_dag: bool,
) -> list[EventBase]:
# We allow partially joined rooms since in this case we are filtering out
# non-local events in `filter_events_for_server`.
await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
if walk_state_dag:
return await self.on_get_missing_events_state_dag(
room_id, earliest_events, latest_events, limit
)
# Only allow up to 20 events to be retrieved per request.
limit = min(limit, 20)
@@ -1439,6 +1461,32 @@ class FederationHandler:
return missing_events
async def on_get_missing_events_state_dag(
self,
room_id: str,
earliest_events: list[str],
latest_events: list[str],
limit: int,
) -> list[EventBase]:
"""Processes a /get_missing_events request for the state DAG.
This is similar to processing the normal DAG with a few notable exceptions:
* There is no max 20 limit applied. As the entire state DAG needs to be filled in, we
cannot arbitrarily set a low limit. If the state DAG delta is 1000s of events, we need
to rely on the sender to set sensible limits depending on the bandwidth/round trip
tradeoff. We rely on existing server rate limits here to prevent abuse.
* We do not filter any events in the state DAG. History visibility does not filter out
delivery of auth chain events, so neither should this. All of the returned events will
be treated as outliers and as such will not be delivered to clients.
"""
missing_events = await self.store.get_missing_events_state_dag(
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
limit=limit,
)
return missing_events
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
) -> None:
File diff suppressed because it is too large Load Diff
-1
View File
@@ -1713,7 +1713,6 @@ class DatabasePool:
", ".join(key_names),
latter,
)
txn.execute_values(sql, args, fetch=False)
else:
@@ -37,7 +37,7 @@ from prometheus_client import Counter, Gauge
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.events import EventBase, FrozenEventVMSC4242, make_event_from_dict
from synapse.logging.opentracing import tag_args, trace
from synapse.metrics import SERVER_NAME_LABEL
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -258,6 +258,59 @@ class EventFederationWorkerStore(
include_given,
)
async def get_state_dag(
self, room_id: str, forward_extrems: set[str]
) -> dict[str, FrozenEventVMSC4242]:
"""Get the current state DAG for the given room.
This function is called when calculating a /send_join response.
This does not check that the room is an state DAG room, so check this
before calling this function!
This functions guarantees that the returned state DAG is connected.
Args:
room_id: The room to get the state dag for
Returns:
A map of event_id => event
"""
def _get_state_events_txn(txn: LoggingTransaction, room_id: str) -> list[str]:
sql = """
SELECT event_id FROM msc4242_state_dag_edges WHERE room_id = ?
"""
txn.execute(sql, (room_id,))
event_ids = [ev_id for (ev_id,) in txn]
return event_ids
# Pull out all auth events for this room using the events_by_room_and_type index.
# This is going to pull in erroneous events e.g m.room.members with no state keys but that's
# okay as we'll filter them out next.
event_ids = await self.db_pool.runInteraction(
"_get_state_events_txn",
_get_state_events_txn,
room_id,
)
event_map = await self.get_events(event_ids)
# Filter the returned state events to only include ones on the paths back from the forward
# extremities.
result: dict[str, FrozenEventVMSC4242] = {}
next = forward_extrems
seen: set[str] = set()
while len(next) > 0:
# Pull the event and add the prev_state_events.
# We must have the event.
event_id = next.pop()
if event_id in seen:
continue
seen.add(event_id)
ev = event_map[event_id]
assert isinstance(ev, FrozenEventVMSC4242)
result[event_id] = ev
for pae in ev.prev_state_events:
next.add(pae)
return result
def _get_auth_chain_ids_using_cover_index_txn(
self,
txn: LoggingTransaction,
@@ -1987,6 +2040,57 @@ class EventFederationWorkerStore(
event_results.reverse()
return event_results
async def get_missing_events_state_dag(
self,
room_id: str,
earliest_events: list[str],
latest_events: list[str],
limit: int,
) -> list[EventBase]:
ids = await self.db_pool.runInteraction(
"get_missing_events_state_dag",
self._get_missing_events_state_dag,
room_id,
earliest_events,
latest_events,
limit,
)
return await self.get_events_as_list(ids)
def _get_missing_events_state_dag(
self,
txn: LoggingTransaction,
room_id: str,
earliest_events: list[str],
latest_events: list[str],
limit: int,
) -> list[str]:
seen_events = set(earliest_events)
# ascii sort
front_queue = sorted(set(latest_events) - seen_events)
event_results: list[str] = []
# TODO(kegan): use a recursive CTE?
query = (
"SELECT prev_state_event_id FROM msc4242_state_dag_edges "
"WHERE room_id = ? AND event_id = ? "
"ORDER BY prev_state_event_id ASC "
"LIMIT ?"
)
while front_queue and len(event_results) < limit:
event_id = front_queue.pop(0)
txn.execute(query, (room_id, event_id, limit - len(event_results)))
# None check because the m.room.create event has NULL prev_state_events
new_results = [
t[0] for t in txn if t[0] is not None and t[0] not in seen_events
]
for next in new_results:
front_queue.append(next)
seen_events |= set(new_results)
event_results.extend(new_results)
return event_results
@trace
@tag_args
async def get_successor_events(self, event_id: str) -> list[str]:
@@ -395,6 +395,7 @@ class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase):
user1_invite_membership_event,
],
event_dict=user1_join_membership_event_signed.get_pdu_json(),
state_dag=[],
event=user1_join_membership_event_signed,
members_omitted=False,
servers_in_room=[
+47 -35
View File
@@ -399,36 +399,54 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# we should get complete room state back
returned_state = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
]
self.assertCountEqual(
returned_state,
[
("m.room.create", ""),
("m.room.power_levels", ""),
("m.room.join_rules", ""),
("m.room.history_visibility", ""),
("m.room.member", f"@kermit_v{room_version}:test"),
("m.room.member", f"@fozzie_v{room_version}:test"),
# nb: *not* the joining user
],
)
if KNOWN_ROOM_VERSIONS[room_version].msc4242_state_dags:
# we should get complete state dag back
returned_state_dag = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["state_dag"]
]
self.assertCountEqual(
returned_state_dag,
[
("m.room.create", ""),
("m.room.power_levels", ""),
("m.room.join_rules", ""),
("m.room.history_visibility", ""),
("m.room.member", f"@kermit_v{room_version}:test"),
("m.room.member", f"@fozzie_v{room_version}:test"),
# nb: *not* the joining user
],
)
else:
# we should get complete room state back
returned_state = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
]
self.assertCountEqual(
returned_state,
[
("m.room.create", ""),
("m.room.power_levels", ""),
("m.room.join_rules", ""),
("m.room.history_visibility", ""),
("m.room.member", f"@kermit_v{room_version}:test"),
("m.room.member", f"@fozzie_v{room_version}:test"),
# nb: *not* the joining user
],
)
# also check the auth chain
returned_auth_chain_events = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
]
self.assertCountEqual(
returned_auth_chain_events,
[
("m.room.create", ""),
("m.room.member", f"@kermit_v{room_version}:test"),
("m.room.power_levels", ""),
("m.room.join_rules", ""),
],
)
# also check the auth chain
returned_auth_chain_events = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
]
self.assertCountEqual(
returned_auth_chain_events,
[
("m.room.create", ""),
("m.room.member", f"@kermit_v{room_version}:test"),
("m.room.power_levels", ""),
("m.room.join_rules", ""),
],
)
# the room should show that the new user is a member
r = self.get_success(self._storage_controllers.state.get_current_state(room_id))
@@ -438,18 +456,12 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
@override_config({"use_frozen_dicts": True})
def test_send_join_with_frozen_dicts(self, room_version: str) -> None:
"""Test send_join with USE_FROZEN_DICTS=True"""
if room_version == RoomVersions.MSC4242v12.identifier:
# TODO: This room version doesn't work over federation in this PR.
return
self._test_send_join_common(room_version)
@parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()])
@override_config({"use_frozen_dicts": False})
def test_send_join_without_frozen_dicts(self, room_version: str) -> None:
"""Test send_join with USE_FROZEN_DICTS=False"""
if room_version == RoomVersions.MSC4242v12.identifier:
# TODO: This room version doesn't work over federation in this PR.
return
self._test_send_join_common(room_version)
def test_send_join_partial_state(self) -> None:
+1
View File
@@ -652,6 +652,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
],
partial_state=True,
servers_in_room={"example.com"},
state_dag=None,
)
)
+301 -2
View File
@@ -18,19 +18,23 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Iterable
from unittest import mock
from twisted.internet.testing import MemoryReactor
from synapse.api.errors import AuthError, StoreError
from synapse.api.room_versions import RoomVersion
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.event_auth import (
check_state_dependent_auth_rules,
check_state_independent_auth_rules,
)
from synapse.events import make_event_from_dict
from synapse.events import EventBase, FrozenEventVMSC4242, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.federation.transport.client import StateRequestResponse
from synapse.handlers.federation_event import (
is_state_dag_connected,
)
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -1184,3 +1188,298 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id,
"Rejected kick event unexpectedly became part of room state.",
)
MSC4242_ROOM_ID = "!msc4242:example.com"
counter = 1
class FederationEventMSC4242AuthDAGTests(unittest.FederatingHomeserverTestCase):
def test_is_state_dag_connected(self) -> None:
linear_1 = msc4242_event([])
linear_2 = msc4242_event([linear_1.event_id])
linear_3 = msc4242_event([linear_2.event_id])
self.assertEqual(is_state_dag_connected([linear_3, linear_1, linear_2]), True)
fork_1 = msc4242_event([])
fork_2 = msc4242_event([fork_1.event_id])
fork_3 = msc4242_event([fork_1.event_id])
fork_4 = msc4242_event([fork_1.event_id])
fork_5 = msc4242_event([fork_2.event_id, fork_4.event_id])
self.assertEqual(
is_state_dag_connected([fork_1, fork_2, fork_3, fork_4, fork_5]), True
)
unconnected_1 = msc4242_event([])
unconnected_2 = msc4242_event([unconnected_1.event_id])
unconnected_3 = msc4242_event(["$unknown"])
self.assertEqual(
is_state_dag_connected([unconnected_1, unconnected_2, unconnected_3]), False
)
def test_fetch_missing_state_dag_events_linear(self) -> None:
linear = self.make_state_dag(
{
"A": [],
"B": ["A"],
"C": ["B"],
}
)
seen_events: set[str] = set()
get_missing_events_req_resps = {
("C",): ["A", "B"],
}
self.prepare_handler(seen_events, linear, get_missing_events_req_resps)
self.assert_fetch_missing_state_dag_events(linear, "C", ["A", "B"])
def test_fetch_missing_state_dag_events_linear_seen(self) -> None:
linear = self.make_state_dag(
{
"A": [],
"B": ["A"],
"C": ["B"],
"D": ["C"],
}
)
seen_events: set[str] = {"A", "B"}
get_missing_events_req_resps = {
("D",): ["C"],
("C",): ["B"],
}
self.prepare_handler(seen_events, linear, get_missing_events_req_resps)
self.assert_fetch_missing_state_dag_events(linear, "D", ["C"])
def test_fetch_missing_state_dag_events_fork_merge(self) -> None:
fork_merge = self.make_state_dag(
{
"A": [],
"B": ["A"],
"C": ["A"],
"D": ["C"],
"E": ["B"],
"F": ["D", "E"],
}
)
seen_events: set[str] = set()
get_missing_events_req_resps = {
("F",): ["D", "E"],
(
"D",
"E",
): ["C", "A"],
("E",): ["B", "A"],
}
self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps)
self.assert_fetch_missing_state_dag_events(
fork_merge, "F", ["A", "B", "C", "D", "E"]
)
def test_fetch_missing_state_dag_events_give_up_no_forward_progress(self) -> None:
fork_merge = self.make_state_dag(
{
"A": [],
"B": ["A"],
"C": ["A"],
"D": ["C", "B"],
}
)
seen_events: set[str] = set()
get_missing_events_req_resps = {
("D",): ["C"], # never provide B
}
self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps)
self.assert_fetch_missing_state_dag_events(fork_merge, "D", [])
def test_fetch_missing_state_dag_events_seen(self) -> None:
fork_merge = self.make_state_dag(
{
"A": [],
"B": ["A"],
"C": ["A"],
"D": ["C"],
"E": ["B"],
"F": ["D", "E"],
}
)
seen_events: set[str] = {"A", "B"}
get_missing_events_req_resps = {
("F",): ["D", "E"],
("D",): ["C", "A"], # NB: not D,E as we've seen E's prev_state_events => B.
("E",): ["B", "A"],
}
self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps)
self.assert_fetch_missing_state_dag_events(fork_merge, "F", ["C", "D", "E"])
# test that we maintain a visited set so we don't needlessly make expensive /get_missing_events
# calls in cases like:
# A <- B <- C <- D <------------------------ I
# `----- E <-- F <-- G <-- H <--`
# In this scenario, there's two paths to A: via D and via EFGH. Both paths converge at C and then
# go to A. The first path to reach A should be remembered so we don't make additional requests.
def test_fetch_missing_state_dag_events_memoise(self) -> None:
memoise_concurrent = self.make_state_dag(
{
"A": [],
"B": ["A"],
"C": ["B"],
"D": ["C"],
"E": ["C"],
"F": ["E"],
"G": ["F"],
"H": ["G"],
"I": ["D", "H"],
}
)
get_missing_events_req_resps = {
("I",): ["D", "H"],
(
"D",
"H",
): ["C", "G"],
("C", "G"): ["B", "F"],
("B", "F"): ["A", "E"],
# we should never request C on its own as we should remember we have visited C already.
}
self.prepare_handler(set(), memoise_concurrent, get_missing_events_req_resps)
self.assert_fetch_missing_state_dag_events(
memoise_concurrent,
"I",
[
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
],
)
def test_fetch_missing_state_dag_seen_all(self) -> None:
seen_all = self.make_state_dag(
{
"A": [],
"B": ["A"],
"C": ["A"],
"D": ["C"],
}
)
seen_events: set[str] = {"A", "B", "C", "D"}
get_missing_events_req_resps: dict[tuple, list[str]] = {}
self.prepare_handler(seen_events, seen_all, get_missing_events_req_resps)
self.assert_fetch_missing_state_dag_events(seen_all, "D", [])
def assert_fetch_missing_state_dag_events(
self,
graph: dict[str, FrozenEventVMSC4242],
start_event_id: str,
want_event_ids: list[str],
) -> None:
"""Run _fetch_missing_state_dag_events with the param start_event_id on the graph provided
and ensure the result matches want_event_ids."""
got = self.get_success(
self.hs.get_federation_event_handler()._fetch_missing_state_dag_events(
"unknown", graph[start_event_id]
)
)
self.assertEqual(
{ev.event_id for ev in got}, {graph[x].event_id for x in want_event_ids}
)
def prepare_handler(
self,
seen_fake_events: set[str],
graph: dict[str, FrozenEventVMSC4242],
gme_req_resps: dict[tuple, list[str]],
) -> None:
"""Setup mocks on federation event handler to return the right data at the right time.
Args:
seen_fake_events: The events seen by this homeserver already (in the database)
graph: The prepared state DAG graph mapping from fake event ID to real event.
gme_req_resps: The /get_missing_events responses to return. The keys are the events
provided as 'latest_events' and the values will be the events returned."""
h = self.hs.get_federation_event_handler()
h._federation_client = mock.Mock(
spec=[
"get_missing_events",
]
)
for x in graph:
print(f"{x} => {graph[x].event_id}")
async def get_missing_events(
destination: str,
room_id: str,
earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase],
limit: int,
min_depth: int,
timeout: int,
state_dag: bool = False,
) -> list[EventBase]:
assert state_dag
assert room_id == MSC4242_ROOM_ID
target_key = [ev.event_id for ev in latest_events]
target_key.sort()
assert target_key
# test which response to return
for req, resp in gme_req_resps.items():
key = [graph[x].event_id for x in req]
key.sort()
if key == target_key:
return [graph[x] for x in resp]
raise AssertionError(
f"get_missing_events with latest={target_key} but no matches found (tested {len(gme_req_resps)})"
)
h._federation_client.get_missing_events.side_effect = get_missing_events
seen_events = {
graph[fake_event_id].event_id for fake_event_id in seen_fake_events
}
async def have_seen_events(room_id: str, event_ids: Iterable[str]) -> set[str]:
return seen_events.intersection(set(event_ids))
main_store = self.hs.get_datastores().main
main_store.have_seen_events = mock.AsyncMock() # type: ignore[method-assign]
main_store.have_seen_events.side_effect = have_seen_events
def make_state_dag(
self, graph: dict[str, list[str]]
) -> dict[str, FrozenEventVMSC4242]:
"""Create a state dag from a graph of event_id -> prev_state_events.
Returns a map of fake ID to real event.
The map must be sorted topologically else this raises an exception."""
fake_to_real_id: dict[str, str] = {}
result: dict[str, FrozenEventVMSC4242] = {}
for fake_id, fake_prev_state_events in graph.items():
# lookup fails if graph not sorted topologically
paes = [fake_to_real_id[fpae] for fpae in fake_prev_state_events]
ev = msc4242_event(paes)
fake_to_real_id[fake_id] = ev.event_id
result[fake_id] = ev
return result
def msc4242_event(prev_state_events: list[str]) -> FrozenEventVMSC4242:
global counter
counter += 1
ev = make_event_from_dict(
{
"type": "m.room.member",
"state_key": "@alice:example.com",
"content": {
"membership": "join",
},
"sender": "@alice:example.com",
"origin_server_ts": counter, # ensure hashed event changes
"room_id": MSC4242_ROOM_ID,
"prev_state_events": prev_state_events,
"prev_events": [], # we shouldn't look at this field
},
RoomVersions.MSC4242v12,
)
assert isinstance(ev, FrozenEventVMSC4242)
return ev
+1
View File
@@ -172,6 +172,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
auth_chain=[create_event],
partial_state=False,
servers_in_room=frozenset(),
state_dag=None,
)
)
+187 -1
View File
@@ -19,6 +19,7 @@
#
import datetime
from collections import namedtuple
from typing import (
Collection,
Iterable,
@@ -38,8 +39,9 @@ from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
RoomVersions,
)
from synapse.events import EventBase
from synapse.events import EventBase, FrozenEventVMSC4242
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -1414,6 +1416,174 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# elapsed past the backoff range so there is no events to backoff from.
self.assertEqual(event_ids_with_backoff, {})
def test_get_state_dag(self) -> None:
"""
Test that MSC4242 state dag rooms can return the complete state dag on request.
"""
# Create the room
user_id = self.register_user("alice", "test")
tok = self.login("alice", "test")
room_id = self.helper.create_room_as(
room_creator=user_id,
tok=tok,
room_version=RoomVersions.MSC4242v12.identifier,
)
resp = self.helper.send_state(
room_id,
"m.room.join_rules",
{"join_rule": "knock"},
tok=tok,
)
latest = resp["event_id"]
state_dag = self.get_success(
self.store.get_state_dag(room_id, {latest}),
)
# create <- member <- pl <- join_rules <- his vis <- join_rules
self.assertEquals(len(state_dag), 6)
want_types = [
EventTypes.Create,
EventTypes.Member,
EventTypes.PowerLevels,
EventTypes.JoinRules,
EventTypes.RoomHistoryVisibility,
EventTypes.JoinRules,
]
curr = {latest}
while len(curr) > 0:
event_id = curr.pop()
ev = state_dag[event_id]
want_type = want_types.pop()
self.assertEqual(ev.type, want_type)
curr.update(ev.prev_state_events)
def test_get_missing_events_state_dag(self) -> None:
# Primarily testing to make sure that we sort events
# correctly when there are multiple prev_state_events
# .- C -- D ---.
# A <- B E
# `- R -- W --`
# `-- T -`
graph = {
"A": [],
"B": ["A"],
"C": ["B"],
"R": ["B"],
"D": ["C"],
"W": ["R"],
"T": ["R"],
"E": ["W", "D", "T"],
}
room_id = "@state_dag:local"
events = [
cast(
FrozenEventVMSC4242,
StateDAGFakeEvent(
event_id,
room_id,
EventTypes.Create if len(graph[event_id]) == 0 else "foo",
graph[event_id],
),
)
for event_id in graph
]
def insert(txn: LoggingTransaction) -> None:
# store these to satisfy fk constraints
self.persist_events.db_pool.simple_insert_many_txn(
txn,
table="events",
keys=(
"instance_name",
"stream_ordering",
"topological_ordering",
"depth",
"event_id",
"room_id",
"type",
"processed",
"outlier",
"origin_server_ts",
"received_ts",
"sender",
"contains_url",
"state_key",
"rejection_reason",
),
values=[
(
"test",
event.internal_metadata.stream_ordering,
event.depth, # topological_ordering
event.depth, # depth
event.event_id,
event.room_id,
event.type,
True, # processed
False, # outlier
1337,
1741622420,
event.sender,
False, # contains url
event.state_key,
False, # rejected
)
for event in events
],
)
for ev in events:
self.persist_events._store_state_dag_edges(
txn,
ev,
)
# satisfy fk constraints
self.get_success(
self.store.store_room(room_id, "foo", False, RoomVersions.MSC4242v12)
)
self.get_success(
self.store.db_pool.runInteraction(
"_store_state_dag_edges",
insert,
)
)
# .- C -- D ---.
# A <- B E
# `- R -- W --`
# `-- T -`
TestCase = namedtuple("TestCase", "latest want limit")
test_cases = [
TestCase(latest=["E"], want=["D", "T", "W"], limit=3),
TestCase(latest=["W", "T", "D"], want=["C", "R"], limit=2),
# breadth first and new entries are added to the end, sorted lexicographically
TestCase(latest=["E"], want=["D", "T", "W", "C", "R", "B", "A"], limit=100),
# we should sort the latest values initially
TestCase(latest=["E", "C"], want=["B", "D", "T", "W"], limit=4),
TestCase(latest=["C", "E"], want=["B", "D", "T", "W"], limit=4),
# dupes are ignored
TestCase(
latest=["E", "E", "C", "C", "C"], want=["B", "D", "T", "W"], limit=4
),
# include latest events in response
TestCase(latest=["W", "E"], want=["D", "T", "W", "R"], limit=4),
]
for test_case in test_cases:
def do_test(txn: LoggingTransaction) -> None:
got = self.store._get_missing_events_state_dag(
txn,
room_id,
[],
test_case.latest,
test_case.limit,
)
self.assertEquals(got, test_case.want)
self.get_success(
self.store.db_pool.runInteraction("test_case", do_test),
)
@attr.s(auto_attribs=True)
class FakeEvent:
@@ -1431,3 +1601,19 @@ class FakeEvent:
def is_state(self) -> bool:
return True
@attr.s(auto_attribs=True)
class StateDAGFakeEvent:
event_id: str
room_id: str
type: str
prev_state_events: list[str]
state_key = "foo"
depth = 1
sender = "foo"
internal_metadata = EventInternalMetadata({})
def is_state(self) -> bool:
return True
+1
View File
@@ -1455,6 +1455,7 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase(
auth_chain=[create_event, creator_join_event],
partial_state=False,
servers_in_room=frozenset(),
state_dag=None,
)
)