mirror of
https://github.com/element-hq/synapse.git
synced 2026-05-24 23:55:21 +00:00
MSC4242: State DAGs (Federation)
Builds off https://github.com/element-hq/synapse/pull/19424 Adds federation compatibility for state DAG rooms. Overview: - Adds extra HTTP API fields as per the MSC. - Adds methods for walking and extracting the state DAG for a room (for `/get_missing_events` and `/send_join` respectively). - Adds impl for processing the federation requests, as well as `/send`.
This commit is contained in:
@@ -42,7 +42,12 @@ from typing import (
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
|
||||
from synapse.api.constants import (
|
||||
Direction,
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
Membership,
|
||||
)
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
@@ -119,6 +124,8 @@ class SendJoinResult:
|
||||
origin: str
|
||||
state: list[EventBase]
|
||||
auth_chain: list[EventBase]
|
||||
# Only valid for state DAG rooms (MSC4242)
|
||||
state_dag: list[EventBase] | None
|
||||
|
||||
# True if 'state' elides non-critical membership events
|
||||
partial_state: bool
|
||||
@@ -658,7 +665,10 @@ class FederationClient(FederationBase):
|
||||
@trace
|
||||
@tag_args
|
||||
async def get_room_state_ids(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Calls the /state_ids endpoint to fetch the state at a particular point
|
||||
in the room, and the auth events for the given event
|
||||
@@ -670,7 +680,9 @@ class FederationClient(FederationBase):
|
||||
InvalidResponseError: if fields in the response have the wrong type.
|
||||
"""
|
||||
result = await self.transport_layer.get_room_state_ids(
|
||||
destination, room_id, event_id=event_id
|
||||
destination,
|
||||
room_id,
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
state_event_ids = result["pdu_ids"]
|
||||
@@ -1178,7 +1190,6 @@ class FederationClient(FederationBase):
|
||||
response = await self._do_send_join(
|
||||
room_version, destination, pdu, omit_members=partial_state
|
||||
)
|
||||
|
||||
# If an event was returned (and expected to be returned):
|
||||
#
|
||||
# * Ensure it has the same event ID (note that the event ID is a hash
|
||||
@@ -1201,14 +1212,22 @@ class FederationClient(FederationBase):
|
||||
event = pdu
|
||||
|
||||
state = response.state
|
||||
auth_chain = response.auth_events
|
||||
|
||||
auth_events = response.auth_events
|
||||
create_event = None
|
||||
for e in state:
|
||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||
create_event = e
|
||||
break
|
||||
|
||||
if room_version.msc4242_state_dags and response.state_dag:
|
||||
# assign to auth_events to reuse the below code which ultimately just does
|
||||
# sig/hash checks. We'll set the right field in SendJoinResult later.
|
||||
auth_events = response.state_dag
|
||||
for e in response.state_dag:
|
||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||
create_event = e
|
||||
break
|
||||
|
||||
if create_event is None:
|
||||
# If the state doesn't have a create event then the room is
|
||||
# invalid, and it would fail auth checks anyway.
|
||||
@@ -1227,7 +1246,7 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Processing from send_join %d events", len(state) + len(auth_chain)
|
||||
"Processing from send_join %d events", len(state) + len(auth_events)
|
||||
)
|
||||
|
||||
# We now go and check the signatures and hashes for the event. Note
|
||||
@@ -1246,7 +1265,7 @@ class FederationClient(FederationBase):
|
||||
valid_pdus_map[valid_pdu.event_id] = valid_pdu
|
||||
|
||||
await concurrently_execute(
|
||||
_execute, itertools.chain(state, auth_chain), 10000
|
||||
_execute, itertools.chain(state, auth_events), 10000
|
||||
)
|
||||
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
@@ -1259,27 +1278,28 @@ class FederationClient(FederationBase):
|
||||
|
||||
signed_auth = [
|
||||
valid_pdus_map[p.event_id]
|
||||
for p in auth_chain
|
||||
for p in auth_events
|
||||
if p.event_id in valid_pdus_map
|
||||
]
|
||||
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
# references being passed on, as that causes... issues.
|
||||
# TODO(kegan): It's unclear why we only need to do this for state and not auth_events
|
||||
for s in signed_state:
|
||||
s.internal_metadata = s.internal_metadata.copy()
|
||||
|
||||
# double-check that the auth chain doesn't include a different create event
|
||||
auth_chain_create_events = [
|
||||
# double-check that the auth events doesn't include a different create event
|
||||
auth_events_create_events = [
|
||||
e.event_id
|
||||
for e in signed_auth
|
||||
if (e.type, e.state_key) == (EventTypes.Create, "")
|
||||
]
|
||||
if auth_chain_create_events and auth_chain_create_events != [
|
||||
if auth_events_create_events and auth_events_create_events != [
|
||||
create_event.event_id
|
||||
]:
|
||||
raise InvalidResponseError(
|
||||
"Unexpected create event(s) in auth chain: %s"
|
||||
% (auth_chain_create_events,)
|
||||
% (auth_events_create_events,)
|
||||
)
|
||||
|
||||
servers_in_room = None
|
||||
@@ -1301,10 +1321,21 @@ class FederationClient(FederationBase):
|
||||
# Fix things up in case the remote homeserver is badly behaved.
|
||||
servers_in_room.add(destination)
|
||||
|
||||
signed_auth_events = signed_auth
|
||||
signed_state_dag = None
|
||||
if room_version.msc4242_state_dags:
|
||||
# We previously set the state dag to auth_events so re-assign it correctly
|
||||
signed_state_dag = signed_auth
|
||||
# Ensure the caller cannot accidentally use these values even if the server
|
||||
# returned them.
|
||||
signed_state = []
|
||||
signed_auth_events = []
|
||||
|
||||
return SendJoinResult(
|
||||
event=event,
|
||||
state=signed_state,
|
||||
auth_chain=signed_auth,
|
||||
auth_chain=signed_auth_events,
|
||||
state_dag=signed_state_dag,
|
||||
origin=destination,
|
||||
partial_state=response.members_omitted,
|
||||
servers_in_room=servers_in_room or frozenset(),
|
||||
@@ -1608,6 +1639,7 @@ class FederationClient(FederationBase):
|
||||
limit: int,
|
||||
min_depth: int,
|
||||
timeout: int,
|
||||
state_dag: bool = False,
|
||||
) -> list[EventBase]:
|
||||
"""Tries to fetch events we are missing. This is called when we receive
|
||||
an event without having received all of its ancestors.
|
||||
@@ -1623,6 +1655,7 @@ class FederationClient(FederationBase):
|
||||
limit: Maximum number of events to return.
|
||||
min_depth: Minimum depth of events to return.
|
||||
timeout: Max time to wait in ms
|
||||
state_dag: True to walk the state DAG (MSC4242 rooms)
|
||||
"""
|
||||
try:
|
||||
content = await self.transport_layer.get_missing_events(
|
||||
@@ -1633,6 +1666,7 @@ class FederationClient(FederationBase):
|
||||
limit=limit,
|
||||
min_depth=min_depth,
|
||||
timeout=timeout,
|
||||
state_dag=state_dag,
|
||||
)
|
||||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
@@ -28,6 +28,8 @@ from typing import (
|
||||
Callable,
|
||||
Collection,
|
||||
Mapping,
|
||||
Sequence,
|
||||
cast,
|
||||
)
|
||||
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
@@ -53,7 +55,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.events import EventBase
|
||||
from synapse.events import EventBase, FrozenEventVMSC4242
|
||||
from synapse.events.snapshot import EventPersistencePair
|
||||
from synapse.federation.federation_base import (
|
||||
FederationBase,
|
||||
@@ -596,7 +598,10 @@ class FederationServer(FederationBase):
|
||||
)
|
||||
|
||||
async def on_room_state_request(
|
||||
self, origin: str, room_id: str, event_id: str
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> tuple[int, JsonDict]:
|
||||
await self._event_auth_handler.assert_host_in_room(room_id, origin)
|
||||
origin_host, _ = parse_server_name(origin)
|
||||
@@ -620,7 +625,10 @@ class FederationServer(FederationBase):
|
||||
@trace
|
||||
@tag_args
|
||||
async def on_state_ids_request(
|
||||
self, origin: str, room_id: str, event_id: str
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> tuple[int, JsonDict]:
|
||||
if not event_id:
|
||||
raise NotImplementedError("Specify an event")
|
||||
@@ -641,17 +649,27 @@ class FederationServer(FederationBase):
|
||||
@trace
|
||||
@tag_args
|
||||
async def _on_state_ids_request_compute(
|
||||
self, room_id: str, event_id: str
|
||||
self,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> JsonDict:
|
||||
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||
state_ids = await self.handler.get_state_ids_for_pdu(
|
||||
room_id,
|
||||
event_id,
|
||||
)
|
||||
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
|
||||
return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}
|
||||
|
||||
async def _on_context_state_request_compute(
|
||||
self, room_id: str, event_id: str
|
||||
self,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> dict[str, list]:
|
||||
pdus: Collection[EventBase]
|
||||
event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||
event_ids = await self.handler.get_state_ids_for_pdu(
|
||||
room_id,
|
||||
event_id,
|
||||
)
|
||||
pdus = await self.store.get_events_as_list(event_ids)
|
||||
|
||||
auth_chain = await self.store.get_auth_chain(
|
||||
@@ -759,6 +777,10 @@ class FederationServer(FederationBase):
|
||||
origin, content, Membership.JOIN, room_id
|
||||
)
|
||||
|
||||
if event.room_version.msc4242_state_dags and caller_supports_partial_state:
|
||||
# TODO(kegan): for now, MSC4242 won't support partial state for ease of prototyping.
|
||||
caller_supports_partial_state = False
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
state_event_ids: Collection[str]
|
||||
@@ -769,15 +791,31 @@ class FederationServer(FederationBase):
|
||||
event, prev_state_ids, summary
|
||||
)
|
||||
servers_in_room = await self.state.get_hosts_in_room_at_events(
|
||||
room_id, event_ids=event.prev_event_ids()
|
||||
room_id,
|
||||
event_ids=event.prev_state_events
|
||||
if isinstance(event, FrozenEventVMSC4242)
|
||||
else event.prev_event_ids(),
|
||||
)
|
||||
else:
|
||||
state_event_ids = prev_state_ids.values()
|
||||
servers_in_room = None
|
||||
|
||||
auth_chain_event_ids = await self.store.get_auth_chain_ids(
|
||||
room_id, state_event_ids
|
||||
)
|
||||
state_dag = None
|
||||
auth_chain_event_ids = set()
|
||||
if event.room_version.msc4242_state_dags:
|
||||
assert isinstance(event, FrozenEventVMSC4242)
|
||||
# NOTE: we don't return the state dag for forward extremities that aren't part of this
|
||||
# join event to make it easier for the receiving server to set their own forward
|
||||
# extremities (they are equal to the join event's prev_state_events). This means we may
|
||||
# fail to sync concurrent forks not on the path to the join event, but this is an
|
||||
# outstanding problem in general.
|
||||
state_dag = await self.store.get_state_dag(
|
||||
room_id, set(event.prev_state_events)
|
||||
)
|
||||
else:
|
||||
auth_chain_event_ids = await self.store.get_auth_chain_ids(
|
||||
room_id, state_event_ids
|
||||
)
|
||||
|
||||
# if the caller has opted in, we can omit any auth_chain events which are
|
||||
# already in state_event_ids
|
||||
@@ -794,13 +832,18 @@ class FederationServer(FederationBase):
|
||||
resp = {
|
||||
"event": event_json,
|
||||
"state": serialize_and_filter_pdus(state_events, time_now),
|
||||
"auth_chain": serialize_and_filter_pdus(auth_chain_events, time_now),
|
||||
"members_omitted": caller_supports_partial_state,
|
||||
}
|
||||
if state_dag is None:
|
||||
resp["auth_chain"] = serialize_and_filter_pdus(auth_chain_events, time_now)
|
||||
else:
|
||||
resp["state_dag"] = serialize_and_filter_pdus(
|
||||
cast(Sequence[EventBase], state_dag.values()), time_now
|
||||
)
|
||||
del resp["state"]
|
||||
|
||||
if servers_in_room is not None:
|
||||
resp["servers_in_room"] = list(servers_in_room)
|
||||
|
||||
return resp
|
||||
|
||||
async def on_make_leave_request(
|
||||
@@ -1097,6 +1140,7 @@ class FederationServer(FederationBase):
|
||||
earliest_events: list[str],
|
||||
latest_events: list[str],
|
||||
limit: int,
|
||||
walk_state_dag: bool = False,
|
||||
) -> dict[str, list]:
|
||||
async with self._server_linearizer.queue((origin, room_id)):
|
||||
origin_host, _ = parse_server_name(origin)
|
||||
@@ -1111,7 +1155,7 @@ class FederationServer(FederationBase):
|
||||
)
|
||||
|
||||
missing_events = await self.handler.on_get_missing_events(
|
||||
origin, room_id, earliest_events, latest_events, limit
|
||||
origin, room_id, earliest_events, latest_events, limit, walk_state_dag
|
||||
)
|
||||
|
||||
if len(missing_events) < 5:
|
||||
|
||||
@@ -148,7 +148,7 @@ from twisted.internet import defer
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.events import EventBase
|
||||
from synapse.events import EventBase, FrozenEventVMSC4242
|
||||
from synapse.federation.sender.per_destination_queue import (
|
||||
CATCHUP_RETRY_INTERVAL,
|
||||
PerDestinationQueue,
|
||||
@@ -660,7 +660,10 @@ class FederationSender(AbstractFederationSender):
|
||||
# banned then it won't receive the event because it won't
|
||||
# be in the room after the ban.
|
||||
destinations = await self.state.get_hosts_in_room_at_events(
|
||||
event.room_id, event_ids=event.prev_event_ids()
|
||||
event.room_id,
|
||||
event_ids=event.prev_state_events
|
||||
if isinstance(event, FrozenEventVMSC4242)
|
||||
else event.prev_event_ids(),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
|
||||
@@ -69,7 +69,10 @@ class TransportLayerClient:
|
||||
self.client.shutdown()
|
||||
|
||||
async def get_room_state_ids(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> JsonDict:
|
||||
"""Requests the IDs of all state for a given room at the given event.
|
||||
|
||||
@@ -78,22 +81,26 @@ class TransportLayerClient:
|
||||
to get the state from.
|
||||
room_id: the room we want the state of
|
||||
event_id: The event we want the context at.
|
||||
|
||||
Returns:
|
||||
Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
|
||||
|
||||
path = _create_v1_path("/state_ids/%s", room_id)
|
||||
qps = {"event_id": event_id}
|
||||
return await self.client.get_json(
|
||||
destination,
|
||||
path=path,
|
||||
args={"event_id": event_id},
|
||||
args=qps,
|
||||
try_trailing_slash_on_400=True,
|
||||
)
|
||||
|
||||
async def get_room_state(
|
||||
self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
|
||||
self,
|
||||
room_version: RoomVersion,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> "StateRequestResponse":
|
||||
"""Requests the full state for a given room at the given event.
|
||||
|
||||
@@ -103,15 +110,15 @@ class TransportLayerClient:
|
||||
to get the state from.
|
||||
room_id: the room we want the state of
|
||||
event_id: The event we want the context at.
|
||||
|
||||
Returns:
|
||||
Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
path = _create_v1_path("/state/%s", room_id)
|
||||
qps = {"event_id": event_id}
|
||||
return await self.client.get_json(
|
||||
destination,
|
||||
path=path,
|
||||
args={"event_id": event_id},
|
||||
args=qps,
|
||||
# This can take a looooooong time for large rooms. Give this a generous
|
||||
# timeout of 10 minutes to avoid the partial state resync timing out early
|
||||
# and trying a bunch of servers who haven't seen our join yet.
|
||||
@@ -787,18 +794,21 @@ class TransportLayerClient:
|
||||
limit: int,
|
||||
min_depth: int,
|
||||
timeout: int,
|
||||
state_dag: bool,
|
||||
) -> JsonDict:
|
||||
path = _create_v1_path("/get_missing_events/%s", room_id)
|
||||
|
||||
data = {
|
||||
"limit": int(limit),
|
||||
"min_depth": int(min_depth),
|
||||
"earliest_events": earliest_events,
|
||||
"latest_events": latest_events,
|
||||
}
|
||||
if state_dag:
|
||||
data["org.matrix.msc4242.state_dag"] = True
|
||||
return await self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data={
|
||||
"limit": int(limit),
|
||||
"min_depth": int(min_depth),
|
||||
"earliest_events": earliest_events,
|
||||
"latest_events": latest_events,
|
||||
},
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
@@ -993,6 +1003,9 @@ class SendJoinResponse:
|
||||
state: list[EventBase]
|
||||
# The raw join event from the /send_join response.
|
||||
event_dict: JsonDict
|
||||
# MSC4242: State DAGs. Always included for state dag rooms, else an empty list.
|
||||
# Replaces auth_events.
|
||||
state_dag: list[EventBase]
|
||||
# The parsed join event from the /send_join response. This will be None if
|
||||
# "event" is not included in the response.
|
||||
event: EventBase | None = None
|
||||
@@ -1079,7 +1092,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||
MAX_RESPONSE_SIZE = 500 * 1024 * 1024
|
||||
|
||||
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
||||
self._response = SendJoinResponse([], [], event_dict={})
|
||||
self._response = SendJoinResponse([], [], event_dict={}, state_dag=[])
|
||||
self._room_version = room_version
|
||||
self._coros: list[Generator[None, bytes, None]] = []
|
||||
|
||||
@@ -1123,6 +1136,15 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||
)
|
||||
)
|
||||
|
||||
if room_version.msc4242_state_dags:
|
||||
self._coros.append(
|
||||
ijson.items_coro(
|
||||
_event_list_parser(room_version, self._response.state_dag),
|
||||
prefix + "state_dag.item",
|
||||
use_float=True,
|
||||
)
|
||||
)
|
||||
|
||||
def write(self, data: bytes) -> int:
|
||||
for c in self._coros:
|
||||
c.send(data)
|
||||
|
||||
@@ -621,6 +621,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
|
||||
limit = int(content.get("limit", 10))
|
||||
earliest_events = content.get("earliest_events", [])
|
||||
latest_events = content.get("latest_events", [])
|
||||
walk_state_dag = content.get("org.matrix.msc4242.state_dag", False)
|
||||
|
||||
result = await self.handler.on_get_missing_events(
|
||||
origin,
|
||||
@@ -628,6 +629,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
|
||||
earliest_events=earliest_events,
|
||||
latest_events=latest_events,
|
||||
limit=limit,
|
||||
walk_state_dag=walk_state_dag,
|
||||
)
|
||||
|
||||
return 200, result
|
||||
|
||||
@@ -669,6 +669,13 @@ class FederationHandler:
|
||||
|
||||
logger.debug("do_invite_join auth_chain: %s", auth_chain)
|
||||
logger.debug("do_invite_join state: %s", state)
|
||||
# If the 'state_dag' field is set, everything will be derived from it.
|
||||
if ret.state_dag:
|
||||
logger.debug("do_invite_join state_dag: %s", ret.state_dag)
|
||||
state = ret.state_dag
|
||||
# TODO(kegan): is this actually important for processing? Shouldn't we be using
|
||||
# the actual DAG here?
|
||||
state.sort(key=lambda e: e.depth)
|
||||
|
||||
logger.debug("do_invite_join event: %s", event)
|
||||
|
||||
@@ -965,6 +972,7 @@ class FederationHandler:
|
||||
# Note that this requires the /send_join request to come back to the
|
||||
# same server.
|
||||
prev_event_ids = None
|
||||
prev_state_events = None
|
||||
if room_version.restricted_join_rule:
|
||||
# Note that the room's state can change out from under us and render our
|
||||
# nice join rules-conformant event non-conformant by the time we build the
|
||||
@@ -981,6 +989,9 @@ class FederationHandler:
|
||||
state_ids = await self._state_storage_controller.get_current_state_ids(
|
||||
room_id
|
||||
)
|
||||
prev_state_events = list(
|
||||
await self.store.get_state_dag_extremities(room_id)
|
||||
)
|
||||
if await self._event_auth_handler.has_restricted_join_rules(
|
||||
state_ids, room_version
|
||||
):
|
||||
@@ -1021,6 +1032,7 @@ class FederationHandler:
|
||||
) = await self.event_creation_handler.create_new_client_event(
|
||||
builder=builder,
|
||||
prev_event_ids=prev_event_ids,
|
||||
prev_state_events=prev_state_events,
|
||||
)
|
||||
except SynapseError as e:
|
||||
logger.warning("Failed to create join to %s because %s", room_id, e)
|
||||
@@ -1301,7 +1313,11 @@ class FederationHandler:
|
||||
|
||||
@trace
|
||||
@tag_args
|
||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> list[str]:
|
||||
async def get_state_ids_for_pdu(
|
||||
self,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> list[str]:
|
||||
"""Returns the state at the event. i.e. not including said event."""
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
if event.internal_metadata.outlier:
|
||||
@@ -1412,11 +1428,17 @@ class FederationHandler:
|
||||
earliest_events: list[str],
|
||||
latest_events: list[str],
|
||||
limit: int,
|
||||
walk_state_dag: bool,
|
||||
) -> list[EventBase]:
|
||||
# We allow partially joined rooms since in this case we are filtering out
|
||||
# non-local events in `filter_events_for_server`.
|
||||
await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
|
||||
|
||||
if walk_state_dag:
|
||||
return await self.on_get_missing_events_state_dag(
|
||||
room_id, earliest_events, latest_events, limit
|
||||
)
|
||||
|
||||
# Only allow up to 20 events to be retrieved per request.
|
||||
limit = min(limit, 20)
|
||||
|
||||
@@ -1439,6 +1461,32 @@ class FederationHandler:
|
||||
|
||||
return missing_events
|
||||
|
||||
async def on_get_missing_events_state_dag(
|
||||
self,
|
||||
room_id: str,
|
||||
earliest_events: list[str],
|
||||
latest_events: list[str],
|
||||
limit: int,
|
||||
) -> list[EventBase]:
|
||||
"""Processes a /get_missing_events request for the state DAG.
|
||||
|
||||
This is similar to processing the normal DAG with a few notable exceptions:
|
||||
* There is no max 20 limit applied. As the entire state DAG needs to be filled in, we
|
||||
cannot arbitrarily set a low limit. If the state DAG delta is 1000s of events, we need
|
||||
to rely on the sender to set sensible limits depending on the bandwidth/round trip
|
||||
tradeoff. We rely on existing server rate limits here to prevent abuse.
|
||||
* We do not filter any events in the state DAG. History visibility does not filter out
|
||||
delivery of auth chain events, so neither should this. All of the returned events will
|
||||
be treated as outliers and as such will not be delivered to clients.
|
||||
"""
|
||||
missing_events = await self.store.get_missing_events_state_dag(
|
||||
room_id=room_id,
|
||||
earliest_events=earliest_events,
|
||||
latest_events=latest_events,
|
||||
limit=limit,
|
||||
)
|
||||
return missing_events
|
||||
|
||||
async def exchange_third_party_invite(
|
||||
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
|
||||
) -> None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1713,7 +1713,6 @@ class DatabasePool:
|
||||
", ".join(key_names),
|
||||
latter,
|
||||
)
|
||||
|
||||
txn.execute_values(sql, args, fetch=False)
|
||||
|
||||
else:
|
||||
|
||||
@@ -37,7 +37,7 @@ from prometheus_client import Counter, Gauge
|
||||
from synapse.api.constants import MAX_DEPTH
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.room_versions import EventFormatVersions, RoomVersion
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events import EventBase, FrozenEventVMSC4242, make_event_from_dict
|
||||
from synapse.logging.opentracing import tag_args, trace
|
||||
from synapse.metrics import SERVER_NAME_LABEL
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
@@ -258,6 +258,59 @@ class EventFederationWorkerStore(
|
||||
include_given,
|
||||
)
|
||||
|
||||
async def get_state_dag(
|
||||
self, room_id: str, forward_extrems: set[str]
|
||||
) -> dict[str, FrozenEventVMSC4242]:
|
||||
"""Get the current state DAG for the given room.
|
||||
|
||||
This function is called when calculating a /send_join response.
|
||||
This does not check that the room is an state DAG room, so check this
|
||||
before calling this function!
|
||||
|
||||
This functions guarantees that the returned state DAG is connected.
|
||||
|
||||
Args:
|
||||
room_id: The room to get the state dag for
|
||||
Returns:
|
||||
A map of event_id => event
|
||||
"""
|
||||
|
||||
def _get_state_events_txn(txn: LoggingTransaction, room_id: str) -> list[str]:
|
||||
sql = """
|
||||
SELECT event_id FROM msc4242_state_dag_edges WHERE room_id = ?
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
event_ids = [ev_id for (ev_id,) in txn]
|
||||
return event_ids
|
||||
|
||||
# Pull out all auth events for this room using the events_by_room_and_type index.
|
||||
# This is going to pull in erroneous events e.g m.room.members with no state keys but that's
|
||||
# okay as we'll filter them out next.
|
||||
event_ids = await self.db_pool.runInteraction(
|
||||
"_get_state_events_txn",
|
||||
_get_state_events_txn,
|
||||
room_id,
|
||||
)
|
||||
event_map = await self.get_events(event_ids)
|
||||
# Filter the returned state events to only include ones on the paths back from the forward
|
||||
# extremities.
|
||||
result: dict[str, FrozenEventVMSC4242] = {}
|
||||
next = forward_extrems
|
||||
seen: set[str] = set()
|
||||
while len(next) > 0:
|
||||
# Pull the event and add the prev_state_events.
|
||||
# We must have the event.
|
||||
event_id = next.pop()
|
||||
if event_id in seen:
|
||||
continue
|
||||
seen.add(event_id)
|
||||
ev = event_map[event_id]
|
||||
assert isinstance(ev, FrozenEventVMSC4242)
|
||||
result[event_id] = ev
|
||||
for pae in ev.prev_state_events:
|
||||
next.add(pae)
|
||||
return result
|
||||
|
||||
def _get_auth_chain_ids_using_cover_index_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
@@ -1987,6 +2040,57 @@ class EventFederationWorkerStore(
|
||||
event_results.reverse()
|
||||
return event_results
|
||||
|
||||
async def get_missing_events_state_dag(
|
||||
self,
|
||||
room_id: str,
|
||||
earliest_events: list[str],
|
||||
latest_events: list[str],
|
||||
limit: int,
|
||||
) -> list[EventBase]:
|
||||
ids = await self.db_pool.runInteraction(
|
||||
"get_missing_events_state_dag",
|
||||
self._get_missing_events_state_dag,
|
||||
room_id,
|
||||
earliest_events,
|
||||
latest_events,
|
||||
limit,
|
||||
)
|
||||
return await self.get_events_as_list(ids)
|
||||
|
||||
def _get_missing_events_state_dag(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
earliest_events: list[str],
|
||||
latest_events: list[str],
|
||||
limit: int,
|
||||
) -> list[str]:
|
||||
seen_events = set(earliest_events)
|
||||
# ascii sort
|
||||
front_queue = sorted(set(latest_events) - seen_events)
|
||||
event_results: list[str] = []
|
||||
# TODO(kegan): use a recursive CTE?
|
||||
query = (
|
||||
"SELECT prev_state_event_id FROM msc4242_state_dag_edges "
|
||||
"WHERE room_id = ? AND event_id = ? "
|
||||
"ORDER BY prev_state_event_id ASC "
|
||||
"LIMIT ?"
|
||||
)
|
||||
|
||||
while front_queue and len(event_results) < limit:
|
||||
event_id = front_queue.pop(0)
|
||||
txn.execute(query, (room_id, event_id, limit - len(event_results)))
|
||||
# None check because the m.room.create event has NULL prev_state_events
|
||||
new_results = [
|
||||
t[0] for t in txn if t[0] is not None and t[0] not in seen_events
|
||||
]
|
||||
for next in new_results:
|
||||
front_queue.append(next)
|
||||
seen_events |= set(new_results)
|
||||
event_results.extend(new_results)
|
||||
|
||||
return event_results
|
||||
|
||||
@trace
|
||||
@tag_args
|
||||
async def get_successor_events(self, event_id: str) -> list[str]:
|
||||
|
||||
@@ -395,6 +395,7 @@ class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase):
|
||||
user1_invite_membership_event,
|
||||
],
|
||||
event_dict=user1_join_membership_event_signed.get_pdu_json(),
|
||||
state_dag=[],
|
||||
event=user1_join_membership_event_signed,
|
||||
members_omitted=False,
|
||||
servers_in_room=[
|
||||
|
||||
@@ -399,36 +399,54 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
|
||||
# we should get complete room state back
|
||||
returned_state = [
|
||||
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
|
||||
]
|
||||
self.assertCountEqual(
|
||||
returned_state,
|
||||
[
|
||||
("m.room.create", ""),
|
||||
("m.room.power_levels", ""),
|
||||
("m.room.join_rules", ""),
|
||||
("m.room.history_visibility", ""),
|
||||
("m.room.member", f"@kermit_v{room_version}:test"),
|
||||
("m.room.member", f"@fozzie_v{room_version}:test"),
|
||||
# nb: *not* the joining user
|
||||
],
|
||||
)
|
||||
if KNOWN_ROOM_VERSIONS[room_version].msc4242_state_dags:
|
||||
# we should get complete state dag back
|
||||
returned_state_dag = [
|
||||
(ev["type"], ev["state_key"]) for ev in channel.json_body["state_dag"]
|
||||
]
|
||||
self.assertCountEqual(
|
||||
returned_state_dag,
|
||||
[
|
||||
("m.room.create", ""),
|
||||
("m.room.power_levels", ""),
|
||||
("m.room.join_rules", ""),
|
||||
("m.room.history_visibility", ""),
|
||||
("m.room.member", f"@kermit_v{room_version}:test"),
|
||||
("m.room.member", f"@fozzie_v{room_version}:test"),
|
||||
# nb: *not* the joining user
|
||||
],
|
||||
)
|
||||
else:
|
||||
# we should get complete room state back
|
||||
returned_state = [
|
||||
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
|
||||
]
|
||||
self.assertCountEqual(
|
||||
returned_state,
|
||||
[
|
||||
("m.room.create", ""),
|
||||
("m.room.power_levels", ""),
|
||||
("m.room.join_rules", ""),
|
||||
("m.room.history_visibility", ""),
|
||||
("m.room.member", f"@kermit_v{room_version}:test"),
|
||||
("m.room.member", f"@fozzie_v{room_version}:test"),
|
||||
# nb: *not* the joining user
|
||||
],
|
||||
)
|
||||
|
||||
# also check the auth chain
|
||||
returned_auth_chain_events = [
|
||||
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
|
||||
]
|
||||
self.assertCountEqual(
|
||||
returned_auth_chain_events,
|
||||
[
|
||||
("m.room.create", ""),
|
||||
("m.room.member", f"@kermit_v{room_version}:test"),
|
||||
("m.room.power_levels", ""),
|
||||
("m.room.join_rules", ""),
|
||||
],
|
||||
)
|
||||
# also check the auth chain
|
||||
returned_auth_chain_events = [
|
||||
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
|
||||
]
|
||||
self.assertCountEqual(
|
||||
returned_auth_chain_events,
|
||||
[
|
||||
("m.room.create", ""),
|
||||
("m.room.member", f"@kermit_v{room_version}:test"),
|
||||
("m.room.power_levels", ""),
|
||||
("m.room.join_rules", ""),
|
||||
],
|
||||
)
|
||||
|
||||
# the room should show that the new user is a member
|
||||
r = self.get_success(self._storage_controllers.state.get_current_state(room_id))
|
||||
@@ -438,18 +456,12 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
@override_config({"use_frozen_dicts": True})
|
||||
def test_send_join_with_frozen_dicts(self, room_version: str) -> None:
|
||||
"""Test send_join with USE_FROZEN_DICTS=True"""
|
||||
if room_version == RoomVersions.MSC4242v12.identifier:
|
||||
# TODO: This room version doesn't work over federation in this PR.
|
||||
return
|
||||
self._test_send_join_common(room_version)
|
||||
|
||||
@parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()])
|
||||
@override_config({"use_frozen_dicts": False})
|
||||
def test_send_join_without_frozen_dicts(self, room_version: str) -> None:
|
||||
"""Test send_join with USE_FROZEN_DICTS=False"""
|
||||
if room_version == RoomVersions.MSC4242v12.identifier:
|
||||
# TODO: This room version doesn't work over federation in this PR.
|
||||
return
|
||||
self._test_send_join_common(room_version)
|
||||
|
||||
def test_send_join_partial_state(self) -> None:
|
||||
|
||||
@@ -652,6 +652,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
|
||||
],
|
||||
partial_state=True,
|
||||
servers_in_room={"example.com"},
|
||||
state_dag=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -18,19 +18,23 @@
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
from typing import Iterable
|
||||
from unittest import mock
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.api.errors import AuthError, StoreError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||
from synapse.event_auth import (
|
||||
check_state_dependent_auth_rules,
|
||||
check_state_independent_auth_rules,
|
||||
)
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events import EventBase, FrozenEventVMSC4242, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.federation.transport.client import StateRequestResponse
|
||||
from synapse.handlers.federation_event import (
|
||||
is_state_dag_connected,
|
||||
)
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
@@ -1184,3 +1188,298 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
bert_member_event.event_id,
|
||||
"Rejected kick event unexpectedly became part of room state.",
|
||||
)
|
||||
|
||||
|
||||
MSC4242_ROOM_ID = "!msc4242:example.com"
|
||||
counter = 1
|
||||
|
||||
|
||||
class FederationEventMSC4242AuthDAGTests(unittest.FederatingHomeserverTestCase):
|
||||
def test_is_state_dag_connected(self) -> None:
|
||||
linear_1 = msc4242_event([])
|
||||
linear_2 = msc4242_event([linear_1.event_id])
|
||||
linear_3 = msc4242_event([linear_2.event_id])
|
||||
self.assertEqual(is_state_dag_connected([linear_3, linear_1, linear_2]), True)
|
||||
|
||||
fork_1 = msc4242_event([])
|
||||
fork_2 = msc4242_event([fork_1.event_id])
|
||||
fork_3 = msc4242_event([fork_1.event_id])
|
||||
fork_4 = msc4242_event([fork_1.event_id])
|
||||
fork_5 = msc4242_event([fork_2.event_id, fork_4.event_id])
|
||||
self.assertEqual(
|
||||
is_state_dag_connected([fork_1, fork_2, fork_3, fork_4, fork_5]), True
|
||||
)
|
||||
|
||||
unconnected_1 = msc4242_event([])
|
||||
unconnected_2 = msc4242_event([unconnected_1.event_id])
|
||||
unconnected_3 = msc4242_event(["$unknown"])
|
||||
self.assertEqual(
|
||||
is_state_dag_connected([unconnected_1, unconnected_2, unconnected_3]), False
|
||||
)
|
||||
|
||||
def test_fetch_missing_state_dag_events_linear(self) -> None:
|
||||
linear = self.make_state_dag(
|
||||
{
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["B"],
|
||||
}
|
||||
)
|
||||
seen_events: set[str] = set()
|
||||
get_missing_events_req_resps = {
|
||||
("C",): ["A", "B"],
|
||||
}
|
||||
self.prepare_handler(seen_events, linear, get_missing_events_req_resps)
|
||||
self.assert_fetch_missing_state_dag_events(linear, "C", ["A", "B"])
|
||||
|
||||
def test_fetch_missing_state_dag_events_linear_seen(self) -> None:
|
||||
linear = self.make_state_dag(
|
||||
{
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["B"],
|
||||
"D": ["C"],
|
||||
}
|
||||
)
|
||||
seen_events: set[str] = {"A", "B"}
|
||||
get_missing_events_req_resps = {
|
||||
("D",): ["C"],
|
||||
("C",): ["B"],
|
||||
}
|
||||
self.prepare_handler(seen_events, linear, get_missing_events_req_resps)
|
||||
self.assert_fetch_missing_state_dag_events(linear, "D", ["C"])
|
||||
|
||||
def test_fetch_missing_state_dag_events_fork_merge(self) -> None:
|
||||
fork_merge = self.make_state_dag(
|
||||
{
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["A"],
|
||||
"D": ["C"],
|
||||
"E": ["B"],
|
||||
"F": ["D", "E"],
|
||||
}
|
||||
)
|
||||
seen_events: set[str] = set()
|
||||
get_missing_events_req_resps = {
|
||||
("F",): ["D", "E"],
|
||||
(
|
||||
"D",
|
||||
"E",
|
||||
): ["C", "A"],
|
||||
("E",): ["B", "A"],
|
||||
}
|
||||
self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps)
|
||||
self.assert_fetch_missing_state_dag_events(
|
||||
fork_merge, "F", ["A", "B", "C", "D", "E"]
|
||||
)
|
||||
|
||||
def test_fetch_missing_state_dag_events_give_up_no_forward_progress(self) -> None:
|
||||
fork_merge = self.make_state_dag(
|
||||
{
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["A"],
|
||||
"D": ["C", "B"],
|
||||
}
|
||||
)
|
||||
seen_events: set[str] = set()
|
||||
get_missing_events_req_resps = {
|
||||
("D",): ["C"], # never provide B
|
||||
}
|
||||
self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps)
|
||||
self.assert_fetch_missing_state_dag_events(fork_merge, "D", [])
|
||||
|
||||
def test_fetch_missing_state_dag_events_seen(self) -> None:
|
||||
fork_merge = self.make_state_dag(
|
||||
{
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["A"],
|
||||
"D": ["C"],
|
||||
"E": ["B"],
|
||||
"F": ["D", "E"],
|
||||
}
|
||||
)
|
||||
seen_events: set[str] = {"A", "B"}
|
||||
get_missing_events_req_resps = {
|
||||
("F",): ["D", "E"],
|
||||
("D",): ["C", "A"], # NB: not D,E as we've seen E's prev_state_events => B.
|
||||
("E",): ["B", "A"],
|
||||
}
|
||||
self.prepare_handler(seen_events, fork_merge, get_missing_events_req_resps)
|
||||
self.assert_fetch_missing_state_dag_events(fork_merge, "F", ["C", "D", "E"])
|
||||
|
||||
# test that we maintain a visited set so we don't needlessly make expensive /get_missing_events
|
||||
# calls in cases like:
|
||||
# A <- B <- C <- D <------------------------ I
|
||||
# `----- E <-- F <-- G <-- H <--`
|
||||
# In this scenario, there's two paths to A: via D and via EFGH. Both paths converge at C and then
|
||||
# go to A. The first path to reach A should be remembered so we don't make additional requests.
|
||||
def test_fetch_missing_state_dag_events_memoise(self) -> None:
|
||||
memoise_concurrent = self.make_state_dag(
|
||||
{
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["B"],
|
||||
"D": ["C"],
|
||||
"E": ["C"],
|
||||
"F": ["E"],
|
||||
"G": ["F"],
|
||||
"H": ["G"],
|
||||
"I": ["D", "H"],
|
||||
}
|
||||
)
|
||||
get_missing_events_req_resps = {
|
||||
("I",): ["D", "H"],
|
||||
(
|
||||
"D",
|
||||
"H",
|
||||
): ["C", "G"],
|
||||
("C", "G"): ["B", "F"],
|
||||
("B", "F"): ["A", "E"],
|
||||
# we should never request C on its own as we should remember we have visited C already.
|
||||
}
|
||||
self.prepare_handler(set(), memoise_concurrent, get_missing_events_req_resps)
|
||||
self.assert_fetch_missing_state_dag_events(
|
||||
memoise_concurrent,
|
||||
"I",
|
||||
[
|
||||
"A",
|
||||
"B",
|
||||
"C",
|
||||
"D",
|
||||
"E",
|
||||
"F",
|
||||
"G",
|
||||
"H",
|
||||
],
|
||||
)
|
||||
|
||||
def test_fetch_missing_state_dag_seen_all(self) -> None:
|
||||
seen_all = self.make_state_dag(
|
||||
{
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["A"],
|
||||
"D": ["C"],
|
||||
}
|
||||
)
|
||||
seen_events: set[str] = {"A", "B", "C", "D"}
|
||||
get_missing_events_req_resps: dict[tuple, list[str]] = {}
|
||||
self.prepare_handler(seen_events, seen_all, get_missing_events_req_resps)
|
||||
self.assert_fetch_missing_state_dag_events(seen_all, "D", [])
|
||||
|
||||
def assert_fetch_missing_state_dag_events(
|
||||
self,
|
||||
graph: dict[str, FrozenEventVMSC4242],
|
||||
start_event_id: str,
|
||||
want_event_ids: list[str],
|
||||
) -> None:
|
||||
"""Run _fetch_missing_state_dag_events with the param start_event_id on the graph provided
|
||||
and ensure the result matches want_event_ids."""
|
||||
got = self.get_success(
|
||||
self.hs.get_federation_event_handler()._fetch_missing_state_dag_events(
|
||||
"unknown", graph[start_event_id]
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
{ev.event_id for ev in got}, {graph[x].event_id for x in want_event_ids}
|
||||
)
|
||||
|
||||
def prepare_handler(
|
||||
self,
|
||||
seen_fake_events: set[str],
|
||||
graph: dict[str, FrozenEventVMSC4242],
|
||||
gme_req_resps: dict[tuple, list[str]],
|
||||
) -> None:
|
||||
"""Setup mocks on federation event handler to return the right data at the right time.
|
||||
Args:
|
||||
seen_fake_events: The events seen by this homeserver already (in the database)
|
||||
graph: The prepared state DAG graph mapping from fake event ID to real event.
|
||||
gme_req_resps: The /get_missing_events responses to return. The keys are the events
|
||||
provided as 'latest_events' and the values will be the events returned."""
|
||||
h = self.hs.get_federation_event_handler()
|
||||
h._federation_client = mock.Mock(
|
||||
spec=[
|
||||
"get_missing_events",
|
||||
]
|
||||
)
|
||||
for x in graph:
|
||||
print(f"{x} => {graph[x].event_id}")
|
||||
|
||||
async def get_missing_events(
|
||||
destination: str,
|
||||
room_id: str,
|
||||
earliest_events_ids: Iterable[str],
|
||||
latest_events: Iterable[EventBase],
|
||||
limit: int,
|
||||
min_depth: int,
|
||||
timeout: int,
|
||||
state_dag: bool = False,
|
||||
) -> list[EventBase]:
|
||||
assert state_dag
|
||||
assert room_id == MSC4242_ROOM_ID
|
||||
target_key = [ev.event_id for ev in latest_events]
|
||||
target_key.sort()
|
||||
assert target_key
|
||||
# test which response to return
|
||||
for req, resp in gme_req_resps.items():
|
||||
key = [graph[x].event_id for x in req]
|
||||
key.sort()
|
||||
if key == target_key:
|
||||
return [graph[x] for x in resp]
|
||||
raise AssertionError(
|
||||
f"get_missing_events with latest={target_key} but no matches found (tested {len(gme_req_resps)})"
|
||||
)
|
||||
|
||||
h._federation_client.get_missing_events.side_effect = get_missing_events
|
||||
|
||||
seen_events = {
|
||||
graph[fake_event_id].event_id for fake_event_id in seen_fake_events
|
||||
}
|
||||
|
||||
async def have_seen_events(room_id: str, event_ids: Iterable[str]) -> set[str]:
|
||||
return seen_events.intersection(set(event_ids))
|
||||
|
||||
main_store = self.hs.get_datastores().main
|
||||
main_store.have_seen_events = mock.AsyncMock() # type: ignore[method-assign]
|
||||
main_store.have_seen_events.side_effect = have_seen_events
|
||||
|
||||
def make_state_dag(
|
||||
self, graph: dict[str, list[str]]
|
||||
) -> dict[str, FrozenEventVMSC4242]:
|
||||
"""Create a state dag from a graph of event_id -> prev_state_events.
|
||||
Returns a map of fake ID to real event.
|
||||
The map must be sorted topologically else this raises an exception."""
|
||||
fake_to_real_id: dict[str, str] = {}
|
||||
result: dict[str, FrozenEventVMSC4242] = {}
|
||||
for fake_id, fake_prev_state_events in graph.items():
|
||||
# lookup fails if graph not sorted topologically
|
||||
paes = [fake_to_real_id[fpae] for fpae in fake_prev_state_events]
|
||||
ev = msc4242_event(paes)
|
||||
fake_to_real_id[fake_id] = ev.event_id
|
||||
result[fake_id] = ev
|
||||
return result
|
||||
|
||||
|
||||
def msc4242_event(prev_state_events: list[str]) -> FrozenEventVMSC4242:
|
||||
global counter
|
||||
counter += 1
|
||||
ev = make_event_from_dict(
|
||||
{
|
||||
"type": "m.room.member",
|
||||
"state_key": "@alice:example.com",
|
||||
"content": {
|
||||
"membership": "join",
|
||||
},
|
||||
"sender": "@alice:example.com",
|
||||
"origin_server_ts": counter, # ensure hashed event changes
|
||||
"room_id": MSC4242_ROOM_ID,
|
||||
"prev_state_events": prev_state_events,
|
||||
"prev_events": [], # we shouldn't look at this field
|
||||
},
|
||||
RoomVersions.MSC4242v12,
|
||||
)
|
||||
assert isinstance(ev, FrozenEventVMSC4242)
|
||||
return ev
|
||||
|
||||
@@ -172,6 +172,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
|
||||
auth_chain=[create_event],
|
||||
partial_state=False,
|
||||
servers_in_room=frozenset(),
|
||||
state_dag=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#
|
||||
|
||||
import datetime
|
||||
from collections import namedtuple
|
||||
from typing import (
|
||||
Collection,
|
||||
Iterable,
|
||||
@@ -38,8 +39,9 @@ from synapse.api.room_versions import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
RoomVersion,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.events import EventBase, FrozenEventVMSC4242
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
@@ -1414,6 +1416,174 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||
# elapsed past the backoff range so there is no events to backoff from.
|
||||
self.assertEqual(event_ids_with_backoff, {})
|
||||
|
||||
def test_get_state_dag(self) -> None:
|
||||
"""
|
||||
Test that MSC4242 state dag rooms can return the complete state dag on request.
|
||||
"""
|
||||
# Create the room
|
||||
user_id = self.register_user("alice", "test")
|
||||
tok = self.login("alice", "test")
|
||||
room_id = self.helper.create_room_as(
|
||||
room_creator=user_id,
|
||||
tok=tok,
|
||||
room_version=RoomVersions.MSC4242v12.identifier,
|
||||
)
|
||||
resp = self.helper.send_state(
|
||||
room_id,
|
||||
"m.room.join_rules",
|
||||
{"join_rule": "knock"},
|
||||
tok=tok,
|
||||
)
|
||||
latest = resp["event_id"]
|
||||
state_dag = self.get_success(
|
||||
self.store.get_state_dag(room_id, {latest}),
|
||||
)
|
||||
# create <- member <- pl <- join_rules <- his vis <- join_rules
|
||||
self.assertEquals(len(state_dag), 6)
|
||||
want_types = [
|
||||
EventTypes.Create,
|
||||
EventTypes.Member,
|
||||
EventTypes.PowerLevels,
|
||||
EventTypes.JoinRules,
|
||||
EventTypes.RoomHistoryVisibility,
|
||||
EventTypes.JoinRules,
|
||||
]
|
||||
curr = {latest}
|
||||
while len(curr) > 0:
|
||||
event_id = curr.pop()
|
||||
ev = state_dag[event_id]
|
||||
want_type = want_types.pop()
|
||||
self.assertEqual(ev.type, want_type)
|
||||
curr.update(ev.prev_state_events)
|
||||
|
||||
def test_get_missing_events_state_dag(self) -> None:
|
||||
# Primarily testing to make sure that we sort events
|
||||
# correctly when there are multiple prev_state_events
|
||||
# .- C -- D ---.
|
||||
# A <- B E
|
||||
# `- R -- W --`
|
||||
# `-- T -`
|
||||
graph = {
|
||||
"A": [],
|
||||
"B": ["A"],
|
||||
"C": ["B"],
|
||||
"R": ["B"],
|
||||
"D": ["C"],
|
||||
"W": ["R"],
|
||||
"T": ["R"],
|
||||
"E": ["W", "D", "T"],
|
||||
}
|
||||
room_id = "@state_dag:local"
|
||||
events = [
|
||||
cast(
|
||||
FrozenEventVMSC4242,
|
||||
StateDAGFakeEvent(
|
||||
event_id,
|
||||
room_id,
|
||||
EventTypes.Create if len(graph[event_id]) == 0 else "foo",
|
||||
graph[event_id],
|
||||
),
|
||||
)
|
||||
for event_id in graph
|
||||
]
|
||||
|
||||
def insert(txn: LoggingTransaction) -> None:
|
||||
# store these to satisfy fk constraints
|
||||
self.persist_events.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="events",
|
||||
keys=(
|
||||
"instance_name",
|
||||
"stream_ordering",
|
||||
"topological_ordering",
|
||||
"depth",
|
||||
"event_id",
|
||||
"room_id",
|
||||
"type",
|
||||
"processed",
|
||||
"outlier",
|
||||
"origin_server_ts",
|
||||
"received_ts",
|
||||
"sender",
|
||||
"contains_url",
|
||||
"state_key",
|
||||
"rejection_reason",
|
||||
),
|
||||
values=[
|
||||
(
|
||||
"test",
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.depth, # topological_ordering
|
||||
event.depth, # depth
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.type,
|
||||
True, # processed
|
||||
False, # outlier
|
||||
1337,
|
||||
1741622420,
|
||||
event.sender,
|
||||
False, # contains url
|
||||
event.state_key,
|
||||
False, # rejected
|
||||
)
|
||||
for event in events
|
||||
],
|
||||
)
|
||||
for ev in events:
|
||||
self.persist_events._store_state_dag_edges(
|
||||
txn,
|
||||
ev,
|
||||
)
|
||||
|
||||
# satisfy fk constraints
|
||||
self.get_success(
|
||||
self.store.store_room(room_id, "foo", False, RoomVersions.MSC4242v12)
|
||||
)
|
||||
self.get_success(
|
||||
self.store.db_pool.runInteraction(
|
||||
"_store_state_dag_edges",
|
||||
insert,
|
||||
)
|
||||
)
|
||||
|
||||
# .- C -- D ---.
|
||||
# A <- B E
|
||||
# `- R -- W --`
|
||||
# `-- T -`
|
||||
TestCase = namedtuple("TestCase", "latest want limit")
|
||||
test_cases = [
|
||||
TestCase(latest=["E"], want=["D", "T", "W"], limit=3),
|
||||
TestCase(latest=["W", "T", "D"], want=["C", "R"], limit=2),
|
||||
# breadth first and new entries are added to the end, sorted lexicographically
|
||||
TestCase(latest=["E"], want=["D", "T", "W", "C", "R", "B", "A"], limit=100),
|
||||
# we should sort the latest values initially
|
||||
TestCase(latest=["E", "C"], want=["B", "D", "T", "W"], limit=4),
|
||||
TestCase(latest=["C", "E"], want=["B", "D", "T", "W"], limit=4),
|
||||
# dupes are ignored
|
||||
TestCase(
|
||||
latest=["E", "E", "C", "C", "C"], want=["B", "D", "T", "W"], limit=4
|
||||
),
|
||||
# include latest events in response
|
||||
TestCase(latest=["W", "E"], want=["D", "T", "W", "R"], limit=4),
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
|
||||
def do_test(txn: LoggingTransaction) -> None:
|
||||
got = self.store._get_missing_events_state_dag(
|
||||
txn,
|
||||
room_id,
|
||||
[],
|
||||
test_case.latest,
|
||||
test_case.limit,
|
||||
)
|
||||
self.assertEquals(got, test_case.want)
|
||||
|
||||
self.get_success(
|
||||
self.store.db_pool.runInteraction("test_case", do_test),
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class FakeEvent:
|
||||
@@ -1431,3 +1601,19 @@ class FakeEvent:
|
||||
|
||||
def is_state(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class StateDAGFakeEvent:
|
||||
event_id: str
|
||||
room_id: str
|
||||
type: str
|
||||
prev_state_events: list[str]
|
||||
state_key = "foo"
|
||||
depth = 1
|
||||
sender = "foo"
|
||||
|
||||
internal_metadata = EventInternalMetadata({})
|
||||
|
||||
def is_state(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -1455,6 +1455,7 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase(
|
||||
auth_chain=[create_event, creator_join_event],
|
||||
partial_state=False,
|
||||
servers_in_room=frozenset(),
|
||||
state_dag=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user