From 53ca01db28a61f89a9e6a4e38df3fa256812464a Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:46:11 +0000 Subject: [PATCH] MSC4242: State DAGs (Federation) Builds off https://github.com/element-hq/synapse/pull/19424 Adds federation compatibility for state DAG rooms. Overview: - Adds extra HTTP API fields as per the MSC. - Adds methods for walking and extracting the state DAG for a room (for `/get_missing_events` and `/send_join` respectively). - Adds impl for processing the federation requests, as well as `/send`. --- synapse/federation/federation_client.py | 62 +- synapse/federation/federation_server.py | 72 +- synapse/federation/sender/__init__.py | 7 +- synapse/federation/transport/client.py | 50 +- .../federation/transport/server/federation.py | 2 + synapse/handlers/federation.py | 50 +- synapse/handlers/federation_event.py | 1037 +++++++++++++---- synapse/storage/database.py | 1 - .../databases/main/event_federation.py | 106 +- .../test_federation_out_of_band_membership.py | 1 + tests/federation/test_federation_server.py | 82 +- tests/handlers/test_federation.py | 1 + tests/handlers/test_federation_event.py | 303 ++++- tests/handlers/test_room_member.py | 1 + tests/storage/test_event_federation.py | 188 ++- tests/storage/test_stream.py | 1 + 16 files changed, 1684 insertions(+), 280 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index ba738ad65e..bc77b0bac6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -42,7 +42,12 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + Direction, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( CodeMessageException, Codes, @@ -119,6 +124,8 @@ class SendJoinResult: 
origin: str state: list[EventBase] auth_chain: list[EventBase] + # Only valid for state DAG rooms (MSC4242) + state_dag: list[EventBase] | None # True if 'state' elides non-critical membership events partial_state: bool @@ -658,7 +665,10 @@ class FederationClient(FederationBase): @trace @tag_args async def get_room_state_ids( - self, destination: str, room_id: str, event_id: str + self, + destination: str, + room_id: str, + event_id: str, ) -> tuple[list[str], list[str]]: """Calls the /state_ids endpoint to fetch the state at a particular point in the room, and the auth events for the given event @@ -670,7 +680,9 @@ class FederationClient(FederationBase): InvalidResponseError: if fields in the response have the wrong type. """ result = await self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id + destination, + room_id, + event_id=event_id, ) state_event_ids = result["pdu_ids"] @@ -1178,7 +1190,6 @@ class FederationClient(FederationBase): response = await self._do_send_join( room_version, destination, pdu, omit_members=partial_state ) - # If an event was returned (and expected to be returned): # # * Ensure it has the same event ID (note that the event ID is a hash @@ -1201,14 +1212,22 @@ class FederationClient(FederationBase): event = pdu state = response.state - auth_chain = response.auth_events - + auth_events = response.auth_events create_event = None for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break + if room_version.msc4242_state_dags and response.state_dag: + # assign to auth_events to reuse the below code which ultimately just does + # sig/hash checks. We'll set the right field in SendJoinResult later. + auth_events = response.state_dag + for e in response.state_dag: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + if create_event is None: # If the state doesn't have a create event then the room is # invalid, and it would fail auth checks anyway. 
@@ -1227,7 +1246,7 @@ class FederationClient(FederationBase): ) logger.info( - "Processing from send_join %d events", len(state) + len(auth_chain) + "Processing from send_join %d events", len(state) + len(auth_events) ) # We now go and check the signatures and hashes for the event. Note @@ -1246,7 +1265,7 @@ class FederationClient(FederationBase): valid_pdus_map[valid_pdu.event_id] = valid_pdu await concurrently_execute( - _execute, itertools.chain(state, auth_chain), 10000 + _execute, itertools.chain(state, auth_events), 10000 ) # NB: We *need* to copy to ensure that we don't have multiple @@ -1259,27 +1278,28 @@ class FederationClient(FederationBase): signed_auth = [ valid_pdus_map[p.event_id] - for p in auth_chain + for p in auth_events if p.event_id in valid_pdus_map ] # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. + # TODO(kegan): It's unclear why we only need to do this for state and not auth_events for s in signed_state: s.internal_metadata = s.internal_metadata.copy() - # double-check that the auth chain doesn't include a different create event - auth_chain_create_events = [ + # double-check that the auth events doesn't include a different create event + auth_events_create_events = [ e.event_id for e in signed_auth if (e.type, e.state_key) == (EventTypes.Create, "") ] - if auth_chain_create_events and auth_chain_create_events != [ + if auth_events_create_events and auth_events_create_events != [ create_event.event_id ]: raise InvalidResponseError( "Unexpected create event(s) in auth chain: %s" - % (auth_chain_create_events,) + % (auth_events_create_events,) ) servers_in_room = None @@ -1301,10 +1321,21 @@ class FederationClient(FederationBase): # Fix things up in case the remote homeserver is badly behaved. 
servers_in_room.add(destination) + signed_auth_events = signed_auth + signed_state_dag = None + if room_version.msc4242_state_dags: + # We previously set the state dag to auth_events so re-assign it correctly + signed_state_dag = signed_auth + # Ensure the caller cannot accidentally use these values even if the server + # returned them. + signed_state = [] + signed_auth_events = [] + return SendJoinResult( event=event, state=signed_state, - auth_chain=signed_auth, + auth_chain=signed_auth_events, + state_dag=signed_state_dag, origin=destination, partial_state=response.members_omitted, servers_in_room=servers_in_room or frozenset(), @@ -1608,6 +1639,7 @@ class FederationClient(FederationBase): limit: int, min_depth: int, timeout: int, + state_dag: bool = False, ) -> list[EventBase]: """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. @@ -1623,6 +1655,7 @@ class FederationClient(FederationBase): limit: Maximum number of events to return. min_depth: Minimum depth of events to return. 
timeout: Max time to wait in ms + state_dag: True to walk the state DAG (MSC4242 rooms) """ try: content = await self.transport_layer.get_missing_events( @@ -1633,6 +1666,7 @@ class FederationClient(FederationBase): limit=limit, min_depth=min_depth, timeout=timeout, + state_dag=state_dag, ) room_version = await self.store.get_room_version(room_id) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index b909f1e595..1a515a32be 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,6 +28,8 @@ from typing import ( Callable, Collection, Mapping, + Sequence, + cast, ) from prometheus_client import Counter, Gauge, Histogram @@ -53,7 +55,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.crypto.event_signing import compute_event_signature -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.events.snapshot import EventPersistencePair from synapse.federation.federation_base import ( FederationBase, @@ -596,7 +598,10 @@ class FederationServer(FederationBase): ) async def on_room_state_request( - self, origin: str, room_id: str, event_id: str + self, + origin: str, + room_id: str, + event_id: str, ) -> tuple[int, JsonDict]: await self._event_auth_handler.assert_host_in_room(room_id, origin) origin_host, _ = parse_server_name(origin) @@ -620,7 +625,10 @@ class FederationServer(FederationBase): @trace @tag_args async def on_state_ids_request( - self, origin: str, room_id: str, event_id: str + self, + origin: str, + room_id: str, + event_id: str, ) -> tuple[int, JsonDict]: if not event_id: raise NotImplementedError("Specify an event") @@ -641,17 +649,27 @@ class FederationServer(FederationBase): @trace @tag_args async def _on_state_ids_request_compute( - self, room_id: str, event_id: str + self, + room_id: str, + event_id: str, ) -> JsonDict: - state_ids = 
await self.handler.get_state_ids_for_pdu(room_id, event_id) + state_ids = await self.handler.get_state_ids_for_pdu( + room_id, + event_id, + ) auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)} async def _on_context_state_request_compute( - self, room_id: str, event_id: str + self, + room_id: str, + event_id: str, ) -> dict[str, list]: pdus: Collection[EventBase] - event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + event_ids = await self.handler.get_state_ids_for_pdu( + room_id, + event_id, + ) pdus = await self.store.get_events_as_list(event_ids) auth_chain = await self.store.get_auth_chain( @@ -759,6 +777,10 @@ class FederationServer(FederationBase): origin, content, Membership.JOIN, room_id ) + if event.room_version.msc4242_state_dags and caller_supports_partial_state: + # TODO(kegan): for now, MSC4242 won't support partial state for ease of prototyping. + caller_supports_partial_state = False + prev_state_ids = await context.get_prev_state_ids() state_event_ids: Collection[str] @@ -769,15 +791,31 @@ class FederationServer(FederationBase): event, prev_state_ids, summary ) servers_in_room = await self.state.get_hosts_in_room_at_events( - room_id, event_ids=event.prev_event_ids() + room_id, + event_ids=event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids(), ) else: state_event_ids = prev_state_ids.values() servers_in_room = None - auth_chain_event_ids = await self.store.get_auth_chain_ids( - room_id, state_event_ids - ) + state_dag = None + auth_chain_event_ids = set() + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + # NOTE: we don't return the state dag for forward extremities that aren't part of this + # join event to make it easier for the receiving server to set their own forward + # extremities (they are equal to the join event's prev_state_events). 
This means we may + # fail to sync concurrent forks not on the path to the join event, but this is an + # outstanding problem in general. + state_dag = await self.store.get_state_dag( + room_id, set(event.prev_state_events) + ) + else: + auth_chain_event_ids = await self.store.get_auth_chain_ids( + room_id, state_event_ids + ) # if the caller has opted in, we can omit any auth_chain events which are # already in state_event_ids @@ -794,13 +832,18 @@ class FederationServer(FederationBase): resp = { "event": event_json, "state": serialize_and_filter_pdus(state_events, time_now), - "auth_chain": serialize_and_filter_pdus(auth_chain_events, time_now), "members_omitted": caller_supports_partial_state, } + if state_dag is None: + resp["auth_chain"] = serialize_and_filter_pdus(auth_chain_events, time_now) + else: + resp["state_dag"] = serialize_and_filter_pdus( + cast(Sequence[EventBase], state_dag.values()), time_now + ) + del resp["state"] if servers_in_room is not None: resp["servers_in_room"] = list(servers_in_room) - return resp async def on_make_leave_request( @@ -1097,6 +1140,7 @@ class FederationServer(FederationBase): earliest_events: list[str], latest_events: list[str], limit: int, + walk_state_dag: bool = False, ) -> dict[str, list]: async with self._server_linearizer.queue((origin, room_id)): origin_host, _ = parse_server_name(origin) @@ -1111,7 +1155,7 @@ class FederationServer(FederationBase): ) missing_events = await self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit + origin, room_id, earliest_events, latest_events, limit, walk_state_dag ) if len(missing_events) < 5: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index f7240c2f7f..ea5c63114c 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -148,7 +148,7 @@ from twisted.internet import defer import synapse.metrics from synapse.api.constants import EventTypes, Membership from 
synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.federation.sender.per_destination_queue import ( CATCHUP_RETRY_INTERVAL, PerDestinationQueue, @@ -660,7 +660,10 @@ class FederationSender(AbstractFederationSender): # banned then it won't receive the event because it won't # be in the room after the ban. destinations = await self.state.get_hosts_in_room_at_events( - event.room_id, event_ids=event.prev_event_ids() + event.room_id, + event_ids=event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids(), ) except Exception: logger.exception( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 35d3c30c69..62d09114a2 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -69,7 +69,10 @@ class TransportLayerClient: self.client.shutdown() async def get_room_state_ids( - self, destination: str, room_id: str, event_id: str + self, + destination: str, + room_id: str, + event_id: str, ) -> JsonDict: """Requests the IDs of all state for a given room at the given event. @@ -78,22 +81,26 @@ class TransportLayerClient: to get the state from. room_id: the room we want the state of event_id: The event we want the context at. - Returns: Results in a dict received from the remote homeserver. 
""" logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state_ids/%s", room_id) + qps = {"event_id": event_id} return await self.client.get_json( destination, path=path, - args={"event_id": event_id}, + args=qps, try_trailing_slash_on_400=True, ) async def get_room_state( - self, room_version: RoomVersion, destination: str, room_id: str, event_id: str + self, + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, ) -> "StateRequestResponse": """Requests the full state for a given room at the given event. @@ -103,15 +110,15 @@ class TransportLayerClient: to get the state from. room_id: the room we want the state of event_id: The event we want the context at. - Returns: Results in a dict received from the remote homeserver. """ path = _create_v1_path("/state/%s", room_id) + qps = {"event_id": event_id} return await self.client.get_json( destination, path=path, - args={"event_id": event_id}, + args=qps, # This can take a looooooong time for large rooms. Give this a generous # timeout of 10 minutes to avoid the partial state resync timing out early # and trying a bunch of servers who haven't seen our join yet. @@ -787,18 +794,21 @@ class TransportLayerClient: limit: int, min_depth: int, timeout: int, + state_dag: bool, ) -> JsonDict: path = _create_v1_path("/get_missing_events/%s", room_id) - + data = { + "limit": int(limit), + "min_depth": int(min_depth), + "earliest_events": earliest_events, + "latest_events": latest_events, + } + if state_dag: + data["org.matrix.msc4242.state_dag"] = True return await self.client.post_json( destination=destination, path=path, - data={ - "limit": int(limit), - "min_depth": int(min_depth), - "earliest_events": earliest_events, - "latest_events": latest_events, - }, + data=data, timeout=timeout, ) @@ -993,6 +1003,9 @@ class SendJoinResponse: state: list[EventBase] # The raw join event from the /send_join response. event_dict: JsonDict + # MSC4242: State DAGs. 
Always included for state dag rooms, else an empty list. + # Replaces auth_events. + state_dag: list[EventBase] # The parsed join event from the /send_join response. This will be None if # "event" is not included in the response. event: EventBase | None = None @@ -1079,7 +1092,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]): MAX_RESPONSE_SIZE = 500 * 1024 * 1024 def __init__(self, room_version: RoomVersion, v1_api: bool): - self._response = SendJoinResponse([], [], event_dict={}) + self._response = SendJoinResponse([], [], event_dict={}, state_dag=[]) self._room_version = room_version self._coros: list[Generator[None, bytes, None]] = [] @@ -1123,6 +1136,15 @@ class SendJoinParser(ByteParser[SendJoinResponse]): ) ) + if room_version.msc4242_state_dags: + self._coros.append( + ijson.items_coro( + _event_list_parser(room_version, self._response.state_dag), + prefix + "state_dag.item", + use_float=True, + ) + ) + def write(self, data: bytes) -> int: for c in self._coros: c.send(data) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index a7c297c0b7..da759b3c39 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -621,6 +621,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): limit = int(content.get("limit", 10)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) + walk_state_dag = content.get("org.matrix.msc4242.state_dag", False) result = await self.handler.on_get_missing_events( origin, @@ -628,6 +629,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): earliest_events=earliest_events, latest_events=latest_events, limit=limit, + walk_state_dag=walk_state_dag, ) return 200, result diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 14805ac80f..88c1abe9e7 100644 --- a/synapse/handlers/federation.py +++ 
b/synapse/handlers/federation.py @@ -669,6 +669,13 @@ class FederationHandler: logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join state: %s", state) + # If the 'state_dag' field is set, everything will be derived from it. + if ret.state_dag: + logger.debug("do_invite_join state_dag: %s", ret.state_dag) + state = ret.state_dag + # TODO(kegan): is this actually important for processing? Shouldn't we be using + # the actual DAG here? + state.sort(key=lambda e: e.depth) logger.debug("do_invite_join event: %s", event) @@ -965,6 +972,7 @@ class FederationHandler: # Note that this requires the /send_join request to come back to the # same server. prev_event_ids = None + prev_state_events = None if room_version.restricted_join_rule: # Note that the room's state can change out from under us and render our # nice join rules-conformant event non-conformant by the time we build the @@ -981,6 +989,9 @@ class FederationHandler: state_ids = await self._state_storage_controller.get_current_state_ids( room_id ) + prev_state_events = list( + await self.store.get_state_dag_extremities(room_id) + ) if await self._event_auth_handler.has_restricted_join_rules( state_ids, room_version ): @@ -1021,6 +1032,7 @@ class FederationHandler: ) = await self.event_creation_handler.create_new_client_event( builder=builder, prev_event_ids=prev_event_ids, + prev_state_events=prev_state_events, ) except SynapseError as e: logger.warning("Failed to create join to %s because %s", room_id, e) @@ -1301,7 +1313,11 @@ class FederationHandler: @trace @tag_args - async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> list[str]: + async def get_state_ids_for_pdu( + self, + room_id: str, + event_id: str, + ) -> list[str]: """Returns the state at the event. i.e. 
not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) if event.internal_metadata.outlier: @@ -1412,11 +1428,17 @@ class FederationHandler: earliest_events: list[str], latest_events: list[str], limit: int, + walk_state_dag: bool, ) -> list[EventBase]: # We allow partially joined rooms since in this case we are filtering out # non-local events in `filter_events_for_server`. await self._event_auth_handler.assert_host_in_room(room_id, origin, True) + if walk_state_dag: + return await self.on_get_missing_events_state_dag( + room_id, earliest_events, latest_events, limit + ) + # Only allow up to 20 events to be retrieved per request. limit = min(limit, 20) @@ -1439,6 +1461,32 @@ class FederationHandler: return missing_events + async def on_get_missing_events_state_dag( + self, + room_id: str, + earliest_events: list[str], + latest_events: list[str], + limit: int, + ) -> list[EventBase]: + """Processes a /get_missing_events request for the state DAG. + + This is similar to processing the normal DAG with a few notable exceptions: + * There is no max 20 limit applied. As the entire state DAG needs to be filled in, we + cannot arbitrarily set a low limit. If the state DAG delta is 1000s of events, we need + to rely on the sender to set sensible limits depending on the bandwidth/round trip + tradeoff. We rely on existing server rate limits here to prevent abuse. + * We do not filter any events in the state DAG. History visibility does not filter out + delivery of auth chain events, so neither should this. All of the returned events will + be treated as outliers and as such will not be delivered to clients. 
+ """ + missing_events = await self.store.get_missing_events_state_dag( + room_id=room_id, + earliest_events=earliest_events, + latest_events=latest_events, + limit=limit, + ) + return missing_events + async def exchange_third_party_invite( self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict ) -> None: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index e314180e12..701933aeb0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -60,7 +60,7 @@ from synapse.event_auth import ( check_state_independent_auth_rules, validate_event_for_room_version, ) -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242, event_exists_in_state_dag from synapse.events.snapshot import ( EventContext, EventPersistencePair, @@ -522,8 +522,10 @@ class FederationEventHandler: Args: origin: Where the events came from room_id: - auth_events - state + auth_events: The auth chain from the send_join response. For MSC4242 State DAG rooms, + this is ignored. + state The current state of the room. For MSC4242 rooms, this is the state DAG and so + includes the auth chains for state. event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. @@ -555,6 +557,14 @@ class FederationEventHandler: if room_version.identifier != room_version_id: raise SynapseError(400, "Room version mismatch") + if room_version.msc4242_state_dags: + # We should be provided with a connected state DAG, so check it before persisting them. + # If we find a gap, the response is invalid so refuse the join. + if not is_state_dag_connected( + [ev for ev in state if isinstance(ev, FrozenEventVMSC4242)] + ): + raise SynapseError(502, "State DAG is not connected") + # persist the auth chain and state events. # # any invalid events here will be marked as rejected, and we'll carry on. 
@@ -566,9 +576,14 @@ class FederationEventHandler: # signatures right now doesn't mean that we will *never* be able to, so it # is premature to reject them. # - await self._auth_and_persist_outliers( - room_id, itertools.chain(auth_events, state) + has_rejected_events = await self._auth_and_persist_outliers( + room_id, + # for state DAG rooms, auth_events = [] so this is fine + itertools.chain(auth_events, state), + from_send_join=True, ) + if room_version.msc4242_state_dags and has_rejected_events: + raise SynapseError(502, "State DAG included rejected events") # and now persist the join event itself. logger.info( @@ -620,6 +635,19 @@ class FederationEventHandler: ), ) ) + elif room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + # we don't blindly trust the 'current state', and instead do state res to calculate it, + # which we can do because we just did _auth_and_persist_outliers on the state DAG. + state_before = ( + await self._state_handler.resolve_state_groups_for_events( + event.room_id, + event.prev_state_events, + ) + ) + state_ids_before_event = await state_before.get_state( + self._state_storage_controller + ) else: state_ids_before_event = { (e.type, e.state_key): e.event_id for e in state @@ -1119,6 +1147,10 @@ class FederationEventHandler: In other words: we should only call this method if `event` has been *pulled* as part of a batch of missing prev events, or similar. + For state DAG rooms, we simply fill in the state DAG instead of asking the remote server + for the state after each missing `prev_event`, thereby avoiding security issues that can + occur if the remote server lies about the room state. + Params: dest: the remote server to ask for state at the missing prevs. Typically, this will be the server we got `event` from. 
@@ -1136,6 +1168,13 @@ class FederationEventHandler: """ room_id = event.room_id event_id = event.event_id + room_version = await self._store.get_room_version_id(room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + if room_version_obj.msc4242_state_dags: + return await self._compute_event_context_with_maybe_missing_prevs_state_dag( + dest, event + ) prevs = set(event.prev_event_ids()) seen = await self._store.have_events_in_timeline(prevs) @@ -1195,7 +1234,10 @@ class FederationEventHandler: # by the get_pdu_cache in federation_client. remote_state_map = ( await self._get_state_ids_after_missing_prev_event( - dest, room_id, p + dest, + room_id, + p, + event.room_version, ) ) @@ -1212,7 +1254,6 @@ class FederationEventHandler: # we don't need this any more, let's delete it. del ours - room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, @@ -1237,6 +1278,24 @@ class FederationEventHandler: event, state_ids_before_event=state_map, partial_state=partial_state ) + async def _compute_event_context_with_maybe_missing_prevs_state_dag( + self, dest: str, event: EventBase + ) -> EventContext: + # before we can compute the event context we need to calculate the auth event IDs + # for this event. Before we can do that, we need to calculate the auth state before + # this event. Before we can do that, we need to fill in the state DAG for the + # prev_state_events in this event. 
+ assert isinstance(event, FrozenEventVMSC4242) + missed_events = await self._fetch_missing_state_dag_events(dest, event) + await self._auth_and_persist_outliers(event.room_id, missed_events) + ( + ctx, + calculated_auth_event_ids, + ) = await self._calculate_state_dag_context(event) + event.internal_metadata.calculated_auth_event_ids = calculated_auth_event_ids + + return ctx + @trace @tag_args async def _get_state_ids_after_missing_prev_event( @@ -1244,6 +1303,7 @@ class FederationEventHandler: destination: str, room_id: str, event_id: str, + room_version: RoomVersion, ) -> StateMap[str]: """Requests all of the room state at a given event from a remote homeserver. @@ -1251,7 +1311,7 @@ class FederationEventHandler: destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. - + room_version: The version of the room Returns: The event ids of the state *after* the given event. @@ -1271,7 +1331,9 @@ class FederationEventHandler: state_event_ids, auth_event_ids, ) = await self._federation_client.get_room_state_ids( - destination, room_id, event_id=event_id + destination, + room_id, + event_id=event_id, ) logger.debug( @@ -1280,6 +1342,92 @@ class FederationEventHandler: len(auth_event_ids), ) + # ensure all these events are in the DB before we continue further. + await self._persist_state_response( + destination, room_id, event_id, state_event_ids, auth_event_ids + ) + + # if we couldn't get the prev event in question, that's a problem. + remote_event = await self._store.get_event( + event_id, + allow_none=True, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # We now need to fill out the state map, which involves fetching the + # type and state key for each event ID in the state. 
+ state_map = {} + + event_metadata = await self._store.get_metadata_for_events(state_event_ids) + for state_event_id, metadata in event_metadata.items(): + if metadata.room_id != room_id: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + # + # This can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + state_event_id, + metadata.room_id, + room_id, + ) + continue + + if metadata.state_key is None: + logger.warning( + "Remote server gave us non-state event in state: %s", state_event_id + ) + continue + + state_map[(metadata.event_type, metadata.state_key)] = state_event_id + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. + failed_to_fetch = set(state_event_ids) - event_metadata.keys() + # `event_id` could be missing from `event_metadata` because it's not necessarily + # a state event. We've already checked that we've fetched it above. 
+ failed_to_fetch.discard(event_id) + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + set_tag( + SynapseTags.RESULT_PREFIX + "failed_to_fetch", + str(failed_to_fetch), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "failed_to_fetch.length", + str(len(failed_to_fetch)), + ) + + if remote_event.is_state() and remote_event.rejected_reason is None: + state_map[(remote_event.type, remote_event.state_key)] = ( + remote_event.event_id + ) + + return state_map + + async def _persist_state_response( + self, + destination: str, + room_id: str, + event_id: str, + state_event_ids: list[str], + auth_event_ids: list[str], + ) -> None: + """Ensure that all events from a /state{_ids} response have been seen and are in the database. + This function side-effects by inserting missed events into the database. + """ # Start by checking events we already have in the DB desired_events = set(state_event_ids) desired_events.add(event_id) @@ -1344,76 +1492,6 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=missing_event_ids ) - # We now need to fill out the state map, which involves fetching the - # type and state key for each event ID in the state. - state_map = {} - - event_metadata = await self._store.get_metadata_for_events(state_event_ids) - for state_event_id, metadata in event_metadata.items(): - if metadata.room_id != room_id: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. 
- # - # This can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - state_event_id, - metadata.room_id, - room_id, - ) - continue - - if metadata.state_key is None: - logger.warning( - "Remote server gave us non-state event in state: %s", state_event_id - ) - continue - - state_map[(metadata.event_type, metadata.state_key)] = state_event_id - - # if we couldn't get the prev event in question, that's a problem. - remote_event = await self._store.get_event( - event_id, - allow_none=True, - allow_rejected=True, - redact_behaviour=EventRedactBehaviour.as_is, - ) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - - # missing state at that event is a warning, not a blocker - # XXX: this doesn't sound right? it means that we'll end up with incomplete - # state. - failed_to_fetch = desired_events - event_metadata.keys() - # `event_id` could be missing from `event_metadata` because it's not necessarily - # a state event. We've already checked that we've fetched it above. 
- failed_to_fetch.discard(event_id) - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state events for %s %s", - event_id, - failed_to_fetch, - ) - set_tag( - SynapseTags.RESULT_PREFIX + "failed_to_fetch", - str(failed_to_fetch), - ) - set_tag( - SynapseTags.RESULT_PREFIX + "failed_to_fetch.length", - str(len(failed_to_fetch)), - ) - - if remote_event.is_state() and remote_event.rejected_reason is None: - state_map[(remote_event.type, remote_event.state_key)] = ( - remote_event.event_id - ) - - return state_map - @trace @tag_args async def _get_state_and_persist( @@ -1422,7 +1500,7 @@ class FederationEventHandler: """Get the complete room state at a given event, and persist any new events as outliers""" room_version = await self._store.get_room_version(room_id) - auth_events, state_events = await self._federation_client.get_room_state( + state_events, auth_events = await self._federation_client.get_room_state( destination, room_id, event_id=event_id, room_version=room_version ) logger.info("/state returned %i events", len(auth_events) + len(state_events)) @@ -1629,10 +1707,10 @@ class FederationEventHandler: @trace @tag_args - async def _get_events_and_persist( + async def _get_events_from_remote( self, destination: str, room_id: str, event_ids: StrCollection - ) -> None: - """Fetch the given events from a server, and persist them as outliers. + ) -> list[EventBase]: + """Fetch the given events from a server. This function *does not* recursively get missing auth events of the newly fetched events. Callers must include in the `event_ids` argument @@ -1671,13 +1749,32 @@ class FederationEventHandler: ) await concurrently_execute(get_event, event_ids, 5) + return events + + @trace + @tag_args + async def _get_events_and_persist( + self, destination: str, room_id: str, event_ids: StrCollection + ) -> None: + """Fetch the given events from a server, and persist them as outliers. 
+ + This function *does not* recursively get missing auth events of the + newly fetched events. Callers must include in the `event_ids` argument + any missing events from the auth chain. + + Logs a warning if we can't find the given event. + """ + events = await self._get_events_from_remote(destination, room_id, event_ids) logger.info("Fetched %i events of %i requested", len(events), len(event_ids)) await self._auth_and_persist_outliers(room_id, events) @trace async def _auth_and_persist_outliers( - self, room_id: str, events: Iterable[EventBase] - ) -> None: + self, + room_id: str, + events: Iterable[EventBase], + from_send_join: bool = False, + ) -> bool: """Persist a batch of outlier events fetched from remote servers. We first sort the events to make sure that we process each event's auth_events @@ -1690,8 +1787,15 @@ class FederationEventHandler: room_id: the room that the events are meant to be in (though this has not yet been checked) events: the events that have been fetched + from_send_join: If True, the events in `events` are from a /send_join response. This + allows some optimisations to be performed as we know we don't need to query the database + for events. State DAG rooms use this. + Returns: + True if some outlier events were rejected. """ event_map = {event.event_id: event for event in events} + if len(event_map) == 0: + return False # nothing to do event_ids = event_map.keys() set_tag( @@ -1702,6 +1806,9 @@ class FederationEventHandler: SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", str(len(event_ids)), ) + is_state_dag_room = event_map[ + list(event_ids)[0] + ].room_version.msc4242_state_dags # filter out any events we have already seen. This might happen because # the events were eagerly pushed to us (eg, during a room join), or because @@ -1718,9 +1825,18 @@ class FederationEventHandler: # We need to persist an event's auth events before the event. 
auth_graph = { - ev.event_id: [e_id for e_id in ev.auth_event_ids() if e_id in event_map] + ev.event_id: [ + e_id + for e_id in ( + ev.auth_event_ids() + if not isinstance(ev, FrozenEventVMSC4242) + else ev.prev_state_events + ) + if e_id in event_map + ] for ev in event_map.values() } + # XXX: confusing name: this isn't _only_ auth events! sorted_auth_event_ids = sorted_topologically(event_map.keys(), auth_graph) sorted_auth_events = [event_map[e_id] for e_id in sorted_auth_event_ids] logger.info( @@ -1729,92 +1845,234 @@ class FederationEventHandler: shortstr(e.event_id for e in sorted_auth_events), ) - # get all the auth events for all the events in this batch. By now, they should - # have been persisted. - auth_event_ids = { - aid for event in sorted_auth_events for aid in event.auth_event_ids() - } - auth_map = { - ev.event_id: ev - for ev in sorted_auth_events - if ev.event_id in auth_event_ids - } - - missing_events = auth_event_ids.difference(auth_map) - if missing_events: - persisted_events = await self._store.get_events( - missing_events, - allow_rejected=True, - redact_behaviour=EventRedactBehaviour.as_is, - ) - auth_map.update(persisted_events) + has_rejected_events = False events_and_contexts_to_persist: list[EventPersistencePair] = [] - async def prep(event: EventBase) -> None: - with nested_logging_context(suffix=event.event_id): - auth = [] - for auth_event_id in event.auth_event_ids(): - ae = auth_map.get(auth_event_id) - if not ae: - # the fact we can't find the auth event doesn't mean it doesn't - # exist, which means it is premature to reject `event`. Instead we - # just ignore it for now. - logger.warning( - "Dropping event %s, which relies on auth_event %s, which could not be found", - event, - auth_event_id, - ) - # Drop the event from the auth_map too, else we may incorrectly persist - # events which depend on this dropped event. 
- auth_map.pop(event.event_id, None) - return - auth.append(ae) + if is_state_dag_room: + event_id_to_state_group: dict[str, int] = {} + state_group_to_state_map: dict[int, StateMap[str]] = {} + processed_event_map: dict[str, EventBase] = {} - # we're not bothering about room state, so flag the event as an outlier. - event.internal_metadata.outlier = True - - context = EventContext.for_outlier(self._storage_controllers) - try: - validate_event_for_room_version(event) - await check_state_independent_auth_rules( - self._store, event, batched_auth_events=auth_map + # other room versions now fetch missed auth_events and check that the event is allowed + # according to event-supplied auth_events. MSC4242 rooms instead _calculate_ the auth + # state and then determine if the event is allowed. + async def process(event: EventBase) -> tuple[EventBase, EventContext]: + assert isinstance(event, FrozenEventVMSC4242) + with nested_logging_context(suffix=event.event_id): + known_prev_state_maps = None + if from_send_join: + state_groups = [ + event_id_to_state_group[event_id] + for event_id in event.prev_state_events + ] + known_prev_state_maps = { + sg: state_group_to_state_map[sg] for sg in state_groups + } + # The following is very similar to EventBuilder.build which also calculates + # what the auth events should be. + ( + context, + calculated_auth_event_ids, + ) = await self._calculate_state_dag_context( + event, + known_prev_state_maps=known_prev_state_maps, + event_map=processed_event_map, ) - check_state_dependent_auth_rules(event, auth) - except AuthError as e: - logger.warning("Rejecting %r because %s", event, e) - context.rejected = RejectedReason.AUTH_ERROR - except EventSizeError as e: - if e.unpersistable: - # This event is completely unpersistable. - raise e - # Otherwise, we are somewhat lenient and just persist the event - # as rejected, for moderate compatibility with older Synapse - # versions. 
- logger.warning("While validating received event %r: %s", event, e) - context.rejected = RejectedReason.OVERSIZED_EVENT + # set metadata fields + event.internal_metadata.calculated_auth_event_ids = ( + calculated_auth_event_ids + ) + event.internal_metadata.outlier = True + # fetch the calculated auth events for checking auth rules + batched_auth_events = None + if from_send_join: + # we already have the auth events as we're processing a /send_join response. + # So let's pull them out now. + calculated_auth_events = { + event_id: event_map[event_id] + for event_id in calculated_auth_event_ids + } + batched_auth_events = { + event_id: event_map[event_id] + for event_id in ( + calculated_auth_event_ids + event.prev_state_events + ) + } + if context._state_group is not None: + event_id_to_state_group[event.event_id] = ( + context._state_group + ) + state_group_to_state_map[ + context._state_group + ] = await context.get_current_state_ids() - events_and_contexts_to_persist.append((event, context)) + else: + calculated_auth_events = await self._store.get_events( + calculated_auth_event_ids, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) + try: + validate_event_for_room_version(event) + await check_state_independent_auth_rules( + self._store, event, batched_auth_events + ) + check_state_dependent_auth_rules( + event, calculated_auth_events.values() + ) + except AuthError as e: + logger.warning("Rejecting %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + except EventSizeError as e: + if e.unpersistable: + # This event is completely unpersistable. + raise e + # TODO(kegan): check before merging if we can be stricter here. + # Otherwise, we are somewhat lenient and just persist the event + # as rejected, for moderate compatibility with older Synapse + # versions. 
+ logger.warning( + "While validating received event %r: %s", event, e + ) + context.rejected = RejectedReason.OVERSIZED_EVENT - for i, event in enumerate(sorted_auth_events): - await prep(event) + processed_event_map[event.event_id] = event + return (event, context) - # The above function is typically not async, and so won't yield to - # the reactor. For large rooms let's yield to the reactor - # occasionally to ensure we don't block other work. - if (i + 1) % 1000 == 0: - await self._clock.sleep(Duration(seconds=0)) + # we cannot prep in a batch then persist in a batch like with other room versions + # because we need to have persisted the state before the prev_state_events _before_ + # we can work out what the calculated auth_events are for the next event. If we try + # to do what other room versions do, _calculate_state_dag_context will fail due + # to not knowing the state groups at prev_state_events. + # The exception to this is when processing a /send_join response as we have all the + # events in-memory. + if from_send_join: + for i, event in enumerate(sorted_auth_events): + event, context = await process(event) + if context.rejected is not None: + has_rejected_events = True + events_and_contexts_to_persist.append((event, context)) + # The above function is typically not async, and so won't yield to + # the reactor. For large rooms let's yield to the reactor + # occasionally to ensure we don't block other work. + if (i + 1) % 1000 == 0: + await self._clock.sleep(Duration(seconds=0)) - # Also persist the new event in batches for similar reasons as above. - for batch in batch_iter(events_and_contexts_to_persist, 1000): - await self.persist_events_and_notify( - room_id, - batch, - # Mark these events as backfilled as they're historic events that will - # eventually be backfilled. For example, missing events we fetch - # during backfill should be marked as backfilled as well. 
- backfilled=True, - ) + for batch in batch_iter(events_and_contexts_to_persist, 1000): + await self.persist_events_and_notify( + room_id, + batch, + # Mark these events as backfilled as they're historic events that will + # eventually be backfilled. For example, missing events we fetch + # during backfill should be marked as backfilled as well. + backfilled=True, + ) + else: + for event in sorted_auth_events: + event_and_context = await process(event) + await self.persist_events_and_notify( + room_id, + [event_and_context], + # Mark these events as backfilled as they're historic events that will + # eventually be backfilled. For example, missing events we fetch + # during backfill should be marked as backfilled as well. + backfilled=True, + ) + if event_and_context[1].rejected is not None: + has_rejected_events = True + else: + # get all the auth events for all the events in this batch. By now, they should + # have been persisted. + auth_event_ids = { + aid for event in sorted_auth_events for aid in event.auth_event_ids() + } + auth_map = { + ev.event_id: ev + for ev in sorted_auth_events + if ev.event_id in auth_event_ids + } + + missing_events = auth_event_ids.difference(auth_map) + if missing_events: + persisted_events = await self._store.get_events( + missing_events, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) + auth_map.update(persisted_events) + + async def prep(event: EventBase) -> bool: + with nested_logging_context(suffix=event.event_id): + auth = [] + for auth_event_id in event.auth_event_ids(): + ae = auth_map.get(auth_event_id) + if not ae: + # the fact we can't find the auth event doesn't mean it doesn't + # exist, which means it is premature to reject `event`. Instead we + # just ignore it for now. 
+ logger.warning( + "Dropping event %s, which relies on auth_event %s, which could not be found", + event, + auth_event_id, + ) + # Drop the event from the auth_map too, else we may incorrectly persist + # events which depend on this dropped event. + auth_map.pop(event.event_id, None) + return False + auth.append(ae) + + # we're not bothering about room state, so flag the event as an outlier. + event.internal_metadata.outlier = True + + context = EventContext.for_outlier(self._storage_controllers) + try: + validate_event_for_room_version(event) + await check_state_independent_auth_rules( + self._store, event, batched_auth_events=auth_map + ) + check_state_dependent_auth_rules(event, auth) + except AuthError as e: + logger.warning("Rejecting %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + except EventSizeError as e: + if e.unpersistable: + # This event is completely unpersistable. + raise e + # Otherwise, we are somewhat lenient and just persist the event + # as rejected, for moderate compatibility with older Synapse + # versions. + logger.warning( + "While validating received event %r: %s", event, e + ) + context.rejected = RejectedReason.OVERSIZED_EVENT + + events_and_contexts_to_persist.append((event, context)) + return context.rejected is not None + + for i, event in enumerate(sorted_auth_events): + is_rejected = await prep(event) + if is_rejected: + has_rejected_events = True + + # The above function is typically not async, and so won't yield to + # the reactor. For large rooms let's yield to the reactor + # occasionally to ensure we don't block other work. + if (i + 1) % 1000 == 0: + await self._clock.sleep(Duration(seconds=0)) + + # Also persist the new event in batches for similar reasons as above. + for batch in batch_iter(events_and_contexts_to_persist, 1000): + await self.persist_events_and_notify( + room_id, + batch, + # Mark these events as backfilled as they're historic events that will + # eventually be backfilled. 
For example, missing events we fetch + # during backfill should be marked as backfilled as well. + backfilled=True, + ) + + return has_rejected_events @trace async def _check_event_auth( @@ -1863,7 +2121,9 @@ class FederationEventHandler: # caller rather than swallow with `context.rejected` (since we cannot be # certain that there is a permanent problem with the event). claimed_auth_events = await self._load_or_fetch_auth_events_for_event( - origin, event + origin, + event, + context, ) set_tag( SynapseTags.RESULT_PREFIX + "claimed_auth_events", @@ -1888,6 +2148,11 @@ class FederationEventHandler: context.rejected = RejectedReason.AUTH_ERROR return + if event.room_version.msc4242_state_dags: + # Step 4 also did Step 5 as we calculated the auth events from the state before the event + # so we're done. + return + # now check the auth rules pass against the room state before the event # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu: # 5. Passes authorization rules based on the state before the event, @@ -1987,6 +2252,78 @@ class FederationEventHandler: current_state_list = list(current_state.values()) await self._get_room_member_handler().kick_guest_users(current_state_list) + async def _calculate_state_dag_context( + self, + event: FrozenEventVMSC4242, + existing_context: EventContext | None = None, + known_prev_state_maps: dict[int, StateMap[str]] | None = None, + event_map: dict[str, EventBase] | None = None, + ) -> tuple[EventContext, list[str]]: + """Calculates the state events for an event based on any provided existing context. + If no context is provided, calculates the state before the prev state events and + returns a new one. + + Args: + event: The event to calculate the state at. + existing_context: The precalculated state if known. + known_prev_state_maps: The state maps for all the prev_state_events, if known. 
+ This avoids an event ID to state group lookup which allow batch processing if the events + aren't yet persisted. + Returns: + The persisted EventContext and the calculated auth event IDs. + """ + if len(event.prev_state_events) == 0 and event.type != EventTypes.Create: + raise SynapseError(502, f"event {event.event_id} has no prev_state_events") + + if existing_context: + state_ids = await existing_context.get_prev_state_ids() + return existing_context, self._event_auth_handler.compute_auth_events( + event, state_ids + ) + # load the room state at the prev_state_events. This is the expensive bit but + # it's a one-time cost as we remember the calculated auth events. + # TODO(kegan): we could be smarter here and cache what the calculated auth events are + # so we don't need to do state res, keyed off ([prev_state_events], sender), since the + # auth events for a given event is determined by its position in the state DAG and who is + # performing the operation. Without this cache, the same user sending 5 events incurs + # 5 state res operations. The cache must also be accessible to EventBuilder.build for local + # event creation, so somewhere in state handler perhaps? 
+ if known_prev_state_maps is not None and len(known_prev_state_maps) > 0: + if len(known_prev_state_maps) == 1: + state_ids = list(known_prev_state_maps.values())[0] + else: + res = await self._state_resolution_handler.resolve_state_groups( + event.room_id, + event.room_version.identifier, + known_prev_state_maps, + event_map=event_map, + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), + ) + state_ids = await res.get_state(self._state_storage_controller) + else: + state_ids = await self._state_handler.compute_state_after_events( + event.room_id, + event.prev_state_events, + state_filter=None, # can't apply this as we need to persist the state group afterwards + await_full_state=False, + ) + # we should always have some kind of resolved state after the prev_state_events, except for + # the create event. If we don't, yell loudly. + if event.type != EventTypes.Create: + assert len(state_ids) > 0 + context = await self._state_handler.compute_event_context( + event, + state_ids, + # Ideally this would always be False but code asserts this is None if there are no + # state_ids, which can happen normally if we reach the create event, so appease it. + partial_state=None + if len(event.prev_state_events) == 0 and event.type == EventTypes.Create + else False, + ) + return context, self._event_auth_handler.compute_auth_events(event, state_ids) + async def _check_for_soft_fail( self, event: EventBase, @@ -2010,9 +2347,15 @@ class FederationEventHandler: # current state, it may have been derived from state resolution between # partial and full state and may not be accurate. 
return + is_state_dag_room = event.room_version.msc4242_state_dags - extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id) - prev_event_ids = set(event.prev_event_ids()) + if is_state_dag_room: + extrem_ids = await self._store.get_state_dag_extremities(event.room_id) + assert isinstance(event, FrozenEventVMSC4242) + prev_event_ids = set(event.prev_state_events) + else: + extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id) + prev_event_ids = set(event.prev_event_ids()) if extrem_ids == prev_event_ids: # If they're the same then the current state is the same as the @@ -2026,27 +2369,17 @@ class FederationEventHandler: auth_types = auth_types_for_event(room_version_obj, event) # Calculate the "current state". - seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids) - has_missing_prevs = bool(prev_event_ids - seen_event_ids) - if has_missing_prevs: - # We don't have all the prev_events of this event, which means we have a - # gap in the graph, and the new event is going to become a new backwards - # extremity. - # - # In this case we want to be a little careful as we might have been - # down for a while and have an incorrect view of the current state, - # however we still want to do checks as gaps are easy to - # maliciously manufacture. - # - # So we use a "current state" that is actually a state - # resolution across the current forward extremities and the - # given state at the event. This should correctly handle cases - # like bans, especially with state res v2. - - state_sets_d = await self._state_storage_controller.get_state_groups_ids( + if is_state_dag_room: + # Because we may have just come back online after a long time, we don't know + # which is newer: our forward extremities or the event's state. As such, + # we do state resolution across those state sets to try to ensure we are seeing + # the 'current' state, particularly for catching bans. 
The block comment below + # says much the same thing but conditionally applies it based on missing prev_events, + # but in a state DAG world we always have prev_state_events. + state_sets_fwd = await self._state_storage_controller.get_state_groups_ids( event.room_id, extrem_ids ) - state_sets: list[StateMap[str]] = list(state_sets_d.values()) + state_sets: list[StateMap[str]] = list(state_sets_fwd.values()) state_ids = await context.get_prev_state_ids() state_sets.append(state_ids) current_state_ids = ( @@ -2061,11 +2394,48 @@ class FederationEventHandler: ) ) else: - current_state_ids = ( - await self._state_storage_controller.get_current_state_ids( - event.room_id, StateFilter.from_types(auth_types) + seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids) + has_missing_prevs = bool(prev_event_ids - seen_event_ids) + if has_missing_prevs: + # We don't have all the prev_events of this event, which means we have a + # gap in the graph, and the new event is going to become a new backwards + # extremity. + # + # In this case we want to be a little careful as we might have been + # down for a while and have an incorrect view of the current state, + # however we still want to do checks as gaps are easy to + # maliciously manufacture. + # + # So we use a "current state" that is actually a state + # resolution across the current forward extremities and the + # given state at the event. This should correctly handle cases + # like bans, especially with state res v2. 
+ + state_sets_d = ( + await self._state_storage_controller.get_state_groups_ids( + event.room_id, extrem_ids + ) + ) + state_sets = list(state_sets_d.values()) + state_ids = await context.get_prev_state_ids() + state_sets.append(state_ids) + current_state_ids = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_sets, + event_map=None, + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), + ) + ) + else: + current_state_ids = ( + await self._state_storage_controller.get_current_state_ids( + event.room_id, StateFilter.from_types(auth_types) + ) ) - ) logger.debug( "Doing soft-fail check for %s: state %s", @@ -2101,7 +2471,10 @@ class FederationEventHandler: event.internal_metadata.soft_failed = True async def _load_or_fetch_auth_events_for_event( - self, destination: str | None, event: EventBase + self, + destination: str | None, + event: EventBase, + context: EventContext, ) -> Collection[EventBase]: """Fetch this event's auth_events, from database or remote @@ -2123,6 +2496,8 @@ class FederationEventHandler: event: the event whose auth_events we want + context: The state before the event, if known. + Returns: all of the events listed in `event.auth_events_ids`, after deduplication @@ -2132,7 +2507,11 @@ class FederationEventHandler: AuthError if we were unable to fetch the auth_events for any reason. """ - event_auth_event_ids = set(event.auth_event_ids()) + event_auth_event_ids = set( + event.auth_event_ids() + if not isinstance(event, FrozenEventVMSC4242) + else event.prev_state_events + ) event_auth_events = await self._store.get_events( event_auth_event_ids, allow_rejected=True ) @@ -2140,6 +2519,21 @@ class FederationEventHandler: event_auth_events.keys() ) if not missing_auth_event_ids: + if isinstance(event, FrozenEventVMSC4242): + # we still need to calculate the auth_events before returning, as we only made + # sure we know of the prev_state_events. 
+ _, calculated_auth_event_ids = await self._calculate_state_dag_context( + event, + context, + ) + event.internal_metadata.calculated_auth_event_ids = ( + calculated_auth_event_ids + ) + calculated_events = await self._store.get_events( + calculated_auth_event_ids, + allow_rejected=True, # TODO(kegan): probably False? + ) + return calculated_events.values() return event_auth_events.values() if destination is None: # this shouldn't happen: destination must be set unless we know we have already @@ -2155,9 +2549,16 @@ class FederationEventHandler: missing_auth_event_ids, ) try: - await self._get_remote_auth_chain_for_event( - destination, event.room_id, event.event_id - ) + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + missed_events = await self._fetch_missing_state_dag_events( + destination, event + ) + await self._auth_and_persist_outliers(event.room_id, missed_events) + else: + await self._get_remote_auth_chain_for_event( + destination, event.room_id, event.event_id + ) except Exception as e: logger.warning("Failed to get auth chain for %s: %s", event, e) # in this case, it's very likely we still won't have all the auth @@ -2170,6 +2571,20 @@ class FederationEventHandler: missing_auth_event_ids.difference_update(extra_auth_events.keys()) event_auth_events.update(extra_auth_events) if not missing_auth_event_ids: + if isinstance(event, FrozenEventVMSC4242): + # we still need to calculate the auth_events before returning. + _, calculated_auth_event_ids = await self._calculate_state_dag_context( + event, + context, + ) + event.internal_metadata.calculated_auth_event_ids = ( + calculated_auth_event_ids + ) + calculated_events = await self._store.get_events( + calculated_auth_event_ids, + allow_rejected=True, # TODO(kegan): probably False? + ) + return calculated_events.values() return event_auth_events.values() # we still don't have all the auth events. 
@@ -2213,6 +2628,225 @@ class FederationEventHandler: await self._auth_and_persist_outliers(room_id, remote_auth_events) + @trace + @tag_args + async def _fetch_missing_state_dag_events( + self, + destination: str, + event: FrozenEventVMSC4242, + ) -> Iterable[EventBase]: + """If we are missing some of an event's prev state events, request them until we fill in + the complete state DAG. + + Args: + destination: where to fetch the state dag from + event: The event we're missing prev_state_events for. + Returns: + The missed state DAG events. + """ + # The logic for this is gnarly but the general idea is to hit /get_missing_events with + # the event ID (hence known as the "back set") to walk back up the state DAG + # breadth first. We then need to see if we have seen any of these events, in which case they + # can be removed from the back set. + # When the back set size reaches 0, we have filled in the entire DAG. It's gnarly because + # we don't have a clear idea when we have filled in the state DAG without checking with the + # database for potentially a lot of events. Consider the following graph: + # .-- E <- F <- G + # A <- B <- C <- D + # `-- H <- I <- J + # Assume this server knows A-G and receives J. Knowing the forwards extremities in the room + # (G) does nothing to help us know when we have connected up the state DAG at D. + # Thus, this function queries the database for all returned events to see if we have seen + # them, and then filters them out. + # + # This function terminates when either: + # - the back set size is 0 (we filled in the gap) + # - the back set entries do not change after a round of /get_missing_events + # (we're not making forward progress), in which case this indicates the remote server is + # lying to us and not sending us events we need. + # + # There are tradeoffs here between # round trips, amount of memory consumed and # DB hits. + # We could reduce memory consumed by persisting intermediate events in a staging area + # on disk. 
We could reduce DB hits by only querying a subset of events (and using the + # topological ordering) which may mean we try to process events we've already seen. We try + # to reduce the # round trips by exponentially increasing the limit in each request. + # Better algorithms exist here (search for "set reconciliation") but as of today, we don't + # do any of them. + + # this function is expensive. See if we need to do it at all. + seen = await self._store.have_seen_events( + event.room_id, event.prev_state_events + ) + if seen == set(event.prev_state_events): + return [] + + room_id = event.room_id + # we allow this amount of time for each event we're going to receive. + # This dynamically adjusts the timeout to account for very large responses. + timeout_ms_per_event = 100 + iteration = 0 + limit = 8 + # we maintain 3 sets: the back set is what the next /gme request will be, and the + # /gme response events get bucketed into one of these 3 sets (seen, missed, back) sets. + missed_events: dict[str, FrozenEventVMSC4242] = {} + seen_event_ids: set[str] = set() + # we operate on state events, but we may have originally hit the backwards extremity with + # a message event. If we do, we need to grab the prev_state_events first to seed the + # back set. /get_missing_events will not return the event we provide to it in latest_events. 
+ back_set: dict[str, FrozenEventVMSC4242] = {} + if event_exists_in_state_dag(event): + back_set = {event.event_id: event} + else: + prev_state_events = await self._get_events_from_remote( + destination, room_id, event.prev_state_events + ) + back_set = { + e.event_id: e + for e in prev_state_events + if e.event_id not in seen and isinstance(e, FrozenEventVMSC4242) + } + # the back set now consists of state events we have not seen, so ensure we return them + # to the caller + missed_events = {e.event_id: e for e in back_set.values()} + + while len(back_set) > 0: + logger.info( + "missed=%s seen=%s back_set=%s", + missed_events.keys(), + seen_event_ids, + back_set.keys(), + ) + # remember which events we're querying for. If we don't make forward progress we'll bail + before_back_set = set(back_set) + max_events_per_req = limit * pow(2, iteration) # 8x1, 8x2, 8x4, ... + try: + # 10s base then +(100ms x # events) on top e.g 64 events = +6400ms = 16.4s + timeout = 10000 + (max_events_per_req * timeout_ms_per_event) + remote_events = await self._federation_client.get_missing_events( + destination, + room_id, + earliest_events_ids=[], + latest_events=back_set.values(), + limit=max_events_per_req, + min_depth=0, + timeout=timeout, + state_dag=True, + ) + remote_events_map = { + ev.event_id: ev + for ev in remote_events + if isinstance(ev, FrozenEventVMSC4242) + } + except RequestSendFailed as e1: + logger.warning( + "Failed to get missing state dag events from remote: %s", e1 + ) + # by returning nothing we all but guarantee that the processing of the event + # received over federation will fail. We'll try doing this again the next time + # this server sends an event to us. + # TODO(kegan): Having a staging area of auth events we have got but not yet authed + # would help us stop doing repeat work. + return [] + + # bucket the remote events into seen / unseen.
We include each event's prev_state_events + # here because that way we might be able to skip another request i.e we know we have + # seen the prev_state_events so don't bother fetching them again. + remote_event_ids = { + event_id + for event in remote_events_map.values() + for event_id in event.prev_state_events + } + remote_event_ids.update(remote_events_map.keys()) + seen_remotes = await self._store.have_seen_events( + room_id, + remote_event_ids, + ) + seen_event_ids.update(seen_remotes) + unseen_remotes = set(remote_events_map).difference(seen_remotes) + + # all unseen events must be returned + missed_events.update( + {k: v for (k, v) in remote_events_map.items() if k in unseen_remotes} + ) + + # now figure out what the new back set is. In the common case, remote events will have + # a long chain of new events e.g A <- B <- C <- D so we want to walk up this graph + # if all the events are unseen. If there are seen events (e.g A) then when we reach A + # we terminate that branch as we have filled in the gap. + # In order to avoid mutating the dict whilst iterating, we iterate over the back set + # snapshot we took earlier, and try to exhaust it (i.e it maps an event ID to a new + # earlier event ID(s) or None if we filled in the gap.) + back_queue = list(before_back_set) + new_back_set: set[str] = set() + while back_queue: + # If the prevs are: + # - All seen: we've filled in the gap, don't add this event to the back set. + # - All fetched: add all the prevs to the back set. + # - Mixed seen/fetched: add all the fetched prevs to the back set + # - Any unseen: keep this event in the back set. 
+ back_event_id = back_queue.pop(0) + back_event = missed_events.get( + back_event_id, back_set.get(back_event_id) + ) + if back_event is None: + continue + seen_all_prevs = all( + pae in seen_event_ids for pae in back_event.prev_state_events + ) + if seen_all_prevs: + continue + has_any_unseen_prev = any( + pae not in seen_event_ids and pae not in missed_events + for pae in back_event.prev_state_events + ) + if has_any_unseen_prev: + new_back_set.add(back_event_id) + continue + + # if we reach here then we have a mixture of seen/fetched prevs. Add the fetched + # prevs to the queue + back_queue.extend( + pae + for pae in back_event.prev_state_events + if pae not in seen_event_ids + ) + + # if there is an event with lots of prev_state_events, so many that our limit won't + # pull them all in, then we have a problem if we give up trying to walk backwards. + # Ensure the limit is at least that large before giving up. + max_prev_state_events_on_single_event = max( + [len(ev.prev_state_events) for (_, ev) in back_set.items()] + ) + if ( + new_back_set == before_back_set + and max_events_per_req > max_prev_state_events_on_single_event + ): + # we didn't make forward progress, give up. + logger.warning( + "Failed to make forward progress when walking back through state dag, stuck at back set %s", + before_back_set, + ) + return [] + iteration += 1 # let the limit exponentially increase + + # edge case: the initial event we put as latest_events has so many prev_state_events that + # we did not make forward progress yet. In this case, missed_events does NOT have the + # initial event, as it was not returned from /get_missing_events. Therefore, we just + # special case this scenario and set the back set accordingly. 
+ if new_back_set == {event.event_id}: + back_set = {event.event_id: event} + else: + back_set = { + event_id: missed_events[event_id] for event_id in new_back_set + } + + logger.info( + "fetch_missing_state_dag_events returning %s events from %s", + len(missed_events), + event.event_id, + ) + return missed_events.values() + @trace async def _run_push_actions_and_persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False @@ -2411,10 +3045,23 @@ class FederationEventHandler: ) raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") - if len(ev.auth_event_ids()) > 10: + if not isinstance(ev, FrozenEventVMSC4242) and len(ev.auth_event_ids()) > 10: logger.warning( "Rejecting event %s which has %i auth_events", ev.event_id, len(ev.auth_event_ids()), ) raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + + +def is_state_dag_connected(state_dag: list[FrozenEventVMSC4242]) -> bool: + assert len(state_dag) > 0 + have_event_ids = {ev.event_id for ev in state_dag} + all_prev_state_events = { + ev_id + for state_events in [ev.prev_state_events for ev in state_dag] + for ev_id in state_events + } + # have_event_ids should consist of all events including forward extremities, + # making all_prev_state_events a strict subset. 
+ return all_prev_state_events.issubset(have_event_ids) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6e38b55686..06c5bd23c5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1713,7 +1713,6 @@ class DatabasePool: ", ".join(key_names), latter, ) - txn.execute_values(sql, args, fetch=False) else: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 415926eb0a..ea2bb93380 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -37,7 +37,7 @@ from prometheus_client import Counter, Gauge from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion -from synapse.events import EventBase, make_event_from_dict +from synapse.events import EventBase, FrozenEventVMSC4242, make_event_from_dict from synapse.logging.opentracing import tag_args, trace from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -258,6 +258,59 @@ class EventFederationWorkerStore( include_given, ) + async def get_state_dag( + self, room_id: str, forward_extrems: set[str] + ) -> dict[str, FrozenEventVMSC4242]: + """Get the current state DAG for the given room. + + This function is called when calculating a /send_join response. + This does not check that the room is a state DAG room, so check this + before calling this function! + + This function guarantees that the returned state DAG is connected. + + Args: + room_id: The room to get the state dag for + Returns: + A map of event_id => event + """ + + def _get_state_events_txn(txn: LoggingTransaction, room_id: str) -> list[str]: + sql = """ + SELECT event_id FROM msc4242_state_dag_edges WHERE room_id = ?
+ """ + txn.execute(sql, (room_id,)) + event_ids = [ev_id for (ev_id,) in txn] + return event_ids + + # Pull out every event ID recorded in the msc4242_state_dag_edges table for this room. + # This may include events which are not reachable from the given forward extremities but that's + # okay as we'll filter them out next. + event_ids = await self.db_pool.runInteraction( + "_get_state_events_txn", + _get_state_events_txn, + room_id, + ) + event_map = await self.get_events(event_ids) + # Filter the returned state events to only include ones on the paths back from the forward + # extremities. + result: dict[str, FrozenEventVMSC4242] = {} + next = set(forward_extrems) # copy: the walk below must not drain the caller's set + seen: set[str] = set() + while len(next) > 0: + # Pull the event and add the prev_state_events. + # We must have the event. + event_id = next.pop() + if event_id in seen: + continue + seen.add(event_id) + ev = event_map[event_id] + assert isinstance(ev, FrozenEventVMSC4242) + result[event_id] = ev + for pae in ev.prev_state_events: + next.add(pae) + return result + def _get_auth_chain_ids_using_cover_index_txn( self, txn: LoggingTransaction, @@ -1987,6 +2040,57 @@ class EventFederationWorkerStore( event_results.reverse() return event_results + async def get_missing_events_state_dag( + self, + room_id: str, + earliest_events: list[str], + latest_events: list[str], + limit: int, + ) -> list[EventBase]: + ids = await self.db_pool.runInteraction( + "get_missing_events_state_dag", + self._get_missing_events_state_dag, + room_id, + earliest_events, + latest_events, + limit, + ) + return await self.get_events_as_list(ids) + + def _get_missing_events_state_dag( + self, + txn: LoggingTransaction, + room_id: str, + earliest_events: list[str], + latest_events: list[str], + limit: int, + ) -> list[str]: + seen_events = set(earliest_events) + # ascii sort + front_queue = sorted(set(latest_events) - seen_events) + event_results: list[str] = [] + # TODO(kegan): use a recursive CTE?
+ query = ( + "SELECT prev_state_event_id FROM msc4242_state_dag_edges " + "WHERE room_id = ? AND event_id = ? " + "ORDER BY prev_state_event_id ASC " + "LIMIT ?" + ) + + while front_queue and len(event_results) < limit: + event_id = front_queue.pop(0) + txn.execute(query, (room_id, event_id, limit - len(event_results))) + # None check because the m.room.create event has NULL prev_state_events + new_results = [ + t[0] for t in txn if t[0] is not None and t[0] not in seen_events + ] + for next in new_results: + front_queue.append(next) + seen_events |= set(new_results) + event_results.extend(new_results) + + return event_results + @trace @tag_args async def get_successor_events(self, event_id: str) -> list[str]: diff --git a/tests/federation/test_federation_out_of_band_membership.py b/tests/federation/test_federation_out_of_band_membership.py index a1ab72b7a1..f556460af4 100644 --- a/tests/federation/test_federation_out_of_band_membership.py +++ b/tests/federation/test_federation_out_of_band_membership.py @@ -395,6 +395,7 @@ class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase): user1_invite_membership_event, ], event_dict=user1_join_membership_event_signed.get_pdu_json(), + state_dag=[], event=user1_join_membership_event_signed, members_omitted=False, servers_in_room=[ diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 5cd4e154cf..3ca3477217 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -399,36 +399,54 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - # we should get complete room state back - returned_state = [ - (ev["type"], ev["state_key"]) for ev in channel.json_body["state"] - ] - self.assertCountEqual( - returned_state, - [ - ("m.room.create", ""), - ("m.room.power_levels", ""), - ("m.room.join_rules", ""), - 
("m.room.history_visibility", ""), - ("m.room.member", f"@kermit_v{room_version}:test"), - ("m.room.member", f"@fozzie_v{room_version}:test"), - # nb: *not* the joining user - ], - ) + if KNOWN_ROOM_VERSIONS[room_version].msc4242_state_dags: + # we should get complete state dag back + returned_state_dag = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["state_dag"] + ] + self.assertCountEqual( + returned_state_dag, + [ + ("m.room.create", ""), + ("m.room.power_levels", ""), + ("m.room.join_rules", ""), + ("m.room.history_visibility", ""), + ("m.room.member", f"@kermit_v{room_version}:test"), + ("m.room.member", f"@fozzie_v{room_version}:test"), + # nb: *not* the joining user + ], + ) + else: + # we should get complete room state back + returned_state = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["state"] + ] + self.assertCountEqual( + returned_state, + [ + ("m.room.create", ""), + ("m.room.power_levels", ""), + ("m.room.join_rules", ""), + ("m.room.history_visibility", ""), + ("m.room.member", f"@kermit_v{room_version}:test"), + ("m.room.member", f"@fozzie_v{room_version}:test"), + # nb: *not* the joining user + ], + ) - # also check the auth chain - returned_auth_chain_events = [ - (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"] - ] - self.assertCountEqual( - returned_auth_chain_events, - [ - ("m.room.create", ""), - ("m.room.member", f"@kermit_v{room_version}:test"), - ("m.room.power_levels", ""), - ("m.room.join_rules", ""), - ], - ) + # also check the auth chain + returned_auth_chain_events = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"] + ] + self.assertCountEqual( + returned_auth_chain_events, + [ + ("m.room.create", ""), + ("m.room.member", f"@kermit_v{room_version}:test"), + ("m.room.power_levels", ""), + ("m.room.join_rules", ""), + ], + ) # the room should show that the new user is a member r = self.get_success(self._storage_controllers.state.get_current_state(room_id)) 
@@ -438,18 +456,12 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): @override_config({"use_frozen_dicts": True}) def test_send_join_with_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=True""" - if room_version == RoomVersions.MSC4242v12.identifier: - # TODO: This room version doesn't work over federation in this PR. - return self._test_send_join_common(room_version) @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) @override_config({"use_frozen_dicts": False}) def test_send_join_without_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=False""" - if room_version == RoomVersions.MSC4242v12.identifier: - # TODO: This room version doesn't work over federation in this PR. - return self._test_send_join_common(room_version) def test_send_join_partial_state(self) -> None: diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 7085531548..25f0522887 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -652,6 +652,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): ], partial_state=True, servers_in_room={"example.com"}, + state_dag=None, ) ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 1aaa86e2e8..745a02a807 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -18,19 +18,23 @@ # [This file includes modifications made by New Vector Limited] # # +from typing import Iterable from unittest import mock from twisted.internet.testing import MemoryReactor from synapse.api.errors import AuthError, StoreError -from synapse.api.room_versions import RoomVersion +from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.event_auth import ( check_state_dependent_auth_rules, check_state_independent_auth_rules, ) -from synapse.events import make_event_from_dict 
+from synapse.events import EventBase, FrozenEventVMSC4242, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.federation.transport.client import StateRequestResponse +from synapse.handlers.federation_event import ( + is_state_dag_connected, +) from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room @@ -1184,3 +1188,298 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): bert_member_event.event_id, "Rejected kick event unexpectedly became part of room state.", ) + + +MSC4242_ROOM_ID = "!msc4242:example.com" +counter = 1 + + +class FederationEventMSC4242AuthDAGTests(unittest.FederatingHomeserverTestCase): + def test_is_state_dag_connected(self) -> None: + linear_1 = msc4242_event([]) + linear_2 = msc4242_event([linear_1.event_id]) + linear_3 = msc4242_event([linear_2.event_id]) + self.assertEqual(is_state_dag_connected([linear_3, linear_1, linear_2]), True) + + fork_1 = msc4242_event([]) + fork_2 = msc4242_event([fork_1.event_id]) + fork_3 = msc4242_event([fork_1.event_id]) + fork_4 = msc4242_event([fork_1.event_id]) + fork_5 = msc4242_event([fork_2.event_id, fork_4.event_id]) + self.assertEqual( + is_state_dag_connected([fork_1, fork_2, fork_3, fork_4, fork_5]), True + ) + + unconnected_1 = msc4242_event([]) + unconnected_2 = msc4242_event([unconnected_1.event_id]) + unconnected_3 = msc4242_event(["$unknown"]) + self.assertEqual( + is_state_dag_connected([unconnected_1, unconnected_2, unconnected_3]), False + ) + + def test_fetch_missing_state_dag_events_linear(self) -> None: + linear = self.make_state_dag( + { + "A": [], + "B": ["A"], + "C": ["B"], + } + ) + seen_events: set[str] = set() + get_missing_events_req_resps = { + ("C",): ["A", "B"], + } + self.prepare_handler(seen_events, linear, get_missing_events_req_resps) + self.assert_fetch_missing_state_dag_events(linear, "C", ["A", "B"]) + + def 
test_fetch_missing_state_dag_events_linear_seen(self) -> None: + linear = self.make_state_dag( + { + "A": [], + "B": ["A"], + "C": ["B"], + "D": ["C"], + } + ) + seen_events: set[str] = {"A", "B"} + get_missing_events_req_resps = { + ("D",): ["C"], + ("C",): ["B"], + } + self.prepare_handler(seen_events, linear, get_missing_events_req_resps) + self.assert_fetch_missing_state_dag_events(linear, "D", ["C"]) + + def test_fetch_missing_state_dag_events_fork_merge(self) -> None: + fork_merge = self.make_state_dag( + { + "A": [], + "B": ["A"], + "C": ["A"], + "D": ["C"], + "E": ["B"], + "F": ["D", "E"], + } + ) + seen_events: set[str] = set() + get_missing_events_req_resps = { + ("F",): ["D", "E"], + ( + "D", + "E", + ): ["C", "A"], + ("E",): ["B", "A"], + } + self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps) + self.assert_fetch_missing_state_dag_events( + fork_merge, "F", ["A", "B", "C", "D", "E"] + ) + + def test_fetch_missing_state_dag_events_give_up_no_forward_progress(self) -> None: + fork_merge = self.make_state_dag( + { + "A": [], + "B": ["A"], + "C": ["A"], + "D": ["C", "B"], + } + ) + seen_events: set[str] = set() + get_missing_events_req_resps = { + ("D",): ["C"], # never provide B + } + self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps) + self.assert_fetch_missing_state_dag_events(fork_merge, "D", []) + + def test_fetch_missing_state_dag_events_seen(self) -> None: + fork_merge = self.make_state_dag( + { + "A": [], + "B": ["A"], + "C": ["A"], + "D": ["C"], + "E": ["B"], + "F": ["D", "E"], + } + ) + seen_events: set[str] = {"A", "B"} + get_missing_events_req_resps = { + ("F",): ["D", "E"], + ("D",): ["C", "A"], # NB: not D,E as we've seen E's prev_state_events => B. 
+ ("E",): ["B", "A"], + } + self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps) + self.assert_fetch_missing_state_dag_events(fork_merge, "F", ["C", "D", "E"]) + + # test that we maintain a visited set so we don't needlessly make expensive /get_missing_events + # calls in cases like: + # A <- B <- C <- D <------------------------ I + # `----- E <-- F <-- G <-- H <--` + # In this scenario, there's two paths to A: via D and via EFGH. Both paths converge at C and then + # go to A. The first path to reach A should be remembered so we don't make additional requests. + def test_fetch_missing_state_dag_events_memoise(self) -> None: + memoise_concurrent = self.make_state_dag( + { + "A": [], + "B": ["A"], + "C": ["B"], + "D": ["C"], + "E": ["C"], + "F": ["E"], + "G": ["F"], + "H": ["G"], + "I": ["D", "H"], + } + ) + get_missing_events_req_resps = { + ("I",): ["D", "H"], + ( + "D", + "H", + ): ["C", "G"], + ("C", "G"): ["B", "F"], + ("B", "F"): ["A", "E"], + # we should never request C on its own as we should remember we have visited C already. 
+ } + self.prepare_handler(set(), memoise_concurrent, get_missing_events_req_resps) + self.assert_fetch_missing_state_dag_events( + memoise_concurrent, + "I", + [ + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + ], + ) + + def test_fetch_missing_state_dag_seen_all(self) -> None: + seen_all = self.make_state_dag( + { + "A": [], + "B": ["A"], + "C": ["A"], + "D": ["C"], + } + ) + seen_events: set[str] = {"A", "B", "C", "D"} + get_missing_events_req_resps: dict[tuple, list[str]] = {} + self.prepare_handler(seen_events, seen_all, get_missing_events_req_resps) + self.assert_fetch_missing_state_dag_events(seen_all, "D", []) + + def assert_fetch_missing_state_dag_events( + self, + graph: dict[str, FrozenEventVMSC4242], + start_event_id: str, + want_event_ids: list[str], + ) -> None: + """Run _fetch_missing_state_dag_events with the param start_event_id on the graph provided + and ensure the result matches want_event_ids.""" + got = self.get_success( + self.hs.get_federation_event_handler()._fetch_missing_state_dag_events( + "unknown", graph[start_event_id] + ) + ) + self.assertEqual( + {ev.event_id for ev in got}, {graph[x].event_id for x in want_event_ids} + ) + + def prepare_handler( + self, + seen_fake_events: set[str], + graph: dict[str, FrozenEventVMSC4242], + gme_req_resps: dict[tuple, list[str]], + ) -> None: + """Setup mocks on federation event handler to return the right data at the right time. + Args: + seen_fake_events: The events seen by this homeserver already (in the database) + graph: The prepared state DAG graph mapping from fake event ID to real event. + gme_req_resps: The /get_missing_events responses to return. 
The keys are the events + provided as 'latest_events' and the values will be the events returned.""" + h = self.hs.get_federation_event_handler() + h._federation_client = mock.Mock( + spec=[ + "get_missing_events", + ] + ) + for x in graph: + print(f"{x} => {graph[x].event_id}") + + async def get_missing_events( + destination: str, + room_id: str, + earliest_events_ids: Iterable[str], + latest_events: Iterable[EventBase], + limit: int, + min_depth: int, + timeout: int, + state_dag: bool = False, + ) -> list[EventBase]: + assert state_dag + assert room_id == MSC4242_ROOM_ID + target_key = [ev.event_id for ev in latest_events] + target_key.sort() + assert target_key + # test which response to return + for req, resp in gme_req_resps.items(): + key = [graph[x].event_id for x in req] + key.sort() + if key == target_key: + return [graph[x] for x in resp] + raise AssertionError( + f"get_missing_events with latest={target_key} but no matches found (tested {len(gme_req_resps)})" + ) + + h._federation_client.get_missing_events.side_effect = get_missing_events + + seen_events = { + graph[fake_event_id].event_id for fake_event_id in seen_fake_events + } + + async def have_seen_events(room_id: str, event_ids: Iterable[str]) -> set[str]: + return seen_events.intersection(set(event_ids)) + + main_store = self.hs.get_datastores().main + main_store.have_seen_events = mock.AsyncMock() # type: ignore[method-assign] + main_store.have_seen_events.side_effect = have_seen_events + + def make_state_dag( + self, graph: dict[str, list[str]] + ) -> dict[str, FrozenEventVMSC4242]: + """Create a state dag from a graph of event_id -> prev_state_events. + Returns a map of fake ID to real event. 
+ The map must be sorted topologically else this raises an exception.""" + fake_to_real_id: dict[str, str] = {} + result: dict[str, FrozenEventVMSC4242] = {} + for fake_id, fake_prev_state_events in graph.items(): + # lookup fails if graph not sorted topologically + paes = [fake_to_real_id[fpae] for fpae in fake_prev_state_events] + ev = msc4242_event(paes) + fake_to_real_id[fake_id] = ev.event_id + result[fake_id] = ev + return result + + +def msc4242_event(prev_state_events: list[str]) -> FrozenEventVMSC4242: + global counter + counter += 1 + ev = make_event_from_dict( + { + "type": "m.room.member", + "state_key": "@alice:example.com", + "content": { + "membership": "join", + }, + "sender": "@alice:example.com", + "origin_server_ts": counter, # ensure hashed event changes + "room_id": MSC4242_ROOM_ID, + "prev_state_events": prev_state_events, + "prev_events": [], # we shouldn't look at this field + }, + RoomVersions.MSC4242v12, + ) + assert isinstance(ev, FrozenEventVMSC4242) + return ev diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index d8d7caaf1b..a1412d4cf3 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -172,6 +172,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): auth_chain=[create_event], partial_state=False, servers_in_room=frozenset(), + state_dag=None, ) ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 59e914ca8b..e4e8cd4569 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -19,6 +19,7 @@ # import datetime +from collections import namedtuple from typing import ( Collection, Iterable, @@ -38,8 +39,9 @@ from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersion, + RoomVersions, ) -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.rest import admin from 
synapse.rest.client import login, room from synapse.server import HomeServer @@ -1414,6 +1416,174 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # elapsed past the backoff range so there is no events to backoff from. self.assertEqual(event_ids_with_backoff, {}) + def test_get_state_dag(self) -> None: + """ + Test that MSC4242 state dag rooms can return the complete state dag on request. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as( + room_creator=user_id, + tok=tok, + room_version=RoomVersions.MSC4242v12.identifier, + ) + resp = self.helper.send_state( + room_id, + "m.room.join_rules", + {"join_rule": "knock"}, + tok=tok, + ) + latest = resp["event_id"] + state_dag = self.get_success( + self.store.get_state_dag(room_id, {latest}), + ) + # create <- member <- pl <- join_rules <- his vis <- join_rules + self.assertEqual(len(state_dag), 6) + want_types = [ + EventTypes.Create, + EventTypes.Member, + EventTypes.PowerLevels, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, + EventTypes.JoinRules, + ] + curr = {latest} + while len(curr) > 0: + event_id = curr.pop() + ev = state_dag[event_id] + want_type = want_types.pop() + self.assertEqual(ev.type, want_type) + curr.update(ev.prev_state_events) + + def test_get_missing_events_state_dag(self) -> None: + # Primarily testing to make sure that we sort events + # correctly when there are multiple prev_state_events + # .- C -- D ---.
+ # A <- B E + # `- R -- W --` + # `-- T -` + graph = { + "A": [], + "B": ["A"], + "C": ["B"], + "R": ["B"], + "D": ["C"], + "W": ["R"], + "T": ["R"], + "E": ["W", "D", "T"], + } + room_id = "@state_dag:local" + events = [ + cast( + FrozenEventVMSC4242, + StateDAGFakeEvent( + event_id, + room_id, + EventTypes.Create if len(graph[event_id]) == 0 else "foo", + graph[event_id], + ), + ) + for event_id in graph + ] + + def insert(txn: LoggingTransaction) -> None: + # store these to satisfy fk constraints + self.persist_events.db_pool.simple_insert_many_txn( + txn, + table="events", + keys=( + "instance_name", + "stream_ordering", + "topological_ordering", + "depth", + "event_id", + "room_id", + "type", + "processed", + "outlier", + "origin_server_ts", + "received_ts", + "sender", + "contains_url", + "state_key", + "rejection_reason", + ), + values=[ + ( + "test", + event.internal_metadata.stream_ordering, + event.depth, # topological_ordering + event.depth, # depth + event.event_id, + event.room_id, + event.type, + True, # processed + False, # outlier + 1337, + 1741622420, + event.sender, + False, # contains url + event.state_key, + False, # rejected + ) + for event in events + ], + ) + for ev in events: + self.persist_events._store_state_dag_edges( + txn, + ev, + ) + + # satisfy fk constraints + self.get_success( + self.store.store_room(room_id, "foo", False, RoomVersions.MSC4242v12) + ) + self.get_success( + self.store.db_pool.runInteraction( + "_store_state_dag_edges", + insert, + ) + ) + + # .- C -- D ---. 
+ # A <- B E + # `- R -- W --` + # `-- T -` + TestCase = namedtuple("TestCase", "latest want limit") + test_cases = [ + TestCase(latest=["E"], want=["D", "T", "W"], limit=3), + TestCase(latest=["W", "T", "D"], want=["C", "R"], limit=2), + # breadth first and new entries are added to the end, sorted lexicographically + TestCase(latest=["E"], want=["D", "T", "W", "C", "R", "B", "A"], limit=100), + # we should sort the latest values initially + TestCase(latest=["E", "C"], want=["B", "D", "T", "W"], limit=4), + TestCase(latest=["C", "E"], want=["B", "D", "T", "W"], limit=4), + # dupes are ignored + TestCase( + latest=["E", "E", "C", "C", "C"], want=["B", "D", "T", "W"], limit=4 + ), + # include latest events in response + TestCase(latest=["W", "E"], want=["D", "T", "W", "R"], limit=4), + ] + + for test_case in test_cases: + + def do_test(txn: LoggingTransaction) -> None: + got = self.store._get_missing_events_state_dag( + txn, + room_id, + [], + test_case.latest, + test_case.limit, + ) + self.assertEqual(got, test_case.want) + + self.get_success( + self.store.db_pool.runInteraction("test_case", do_test), + ) + @attr.s(auto_attribs=True) class FakeEvent: @@ -1431,3 +1601,19 @@ class FakeEvent: def is_state(self) -> bool: return True + + +@attr.s(auto_attribs=True) +class StateDAGFakeEvent: + event_id: str + room_id: str + type: str + prev_state_events: list[str] + state_key = "foo" + depth = 1 + sender = "foo" + + internal_metadata = EventInternalMetadata({}) + + def is_state(self) -> bool: + return True diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index d51fa1f8ba..8fce7c03e9 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -1455,6 +1455,7 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase( auth_chain=[create_event, creator_join_event], partial_state=False, servers_in_room=frozenset(), + state_dag=None, ) )