Use get_stripped_room_state_ids_from_event_context

This commit is contained in:
Eric Eastwood
2026-05-01 19:53:45 -05:00
parent e0eb224cfa
commit 3464ec8894
5 changed files with 52 additions and 36 deletions
+12 -26
View File
@@ -60,6 +60,7 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.events import EventBase, builder, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.events.utils import parse_stripped_state_event
from synapse.federation.federation_base import (
FederationBase,
@@ -74,10 +75,7 @@ from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args,
from synapse.metrics import SERVER_NAME_LABEL
from synapse.types import (
JsonDict,
PersistedEventPosition,
StrCollection,
StreamKeyType,
StreamToken,
UserID,
get_domain_from_id,
)
@@ -145,7 +143,6 @@ class FederationClient(FederationBase):
self._clock.looping_call(self._clear_tried_cache, Duration(minutes=1))
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
self.storage_controllers = hs.get_storage_controllers()
self.server_name = hs.hostname
self.signing_key = hs.signing_key
@@ -1320,12 +1317,12 @@ class FederationClient(FederationBase):
self,
destination: str,
room_id: str,
event_id: str,
pdu: EventBase,
context: EventContext,
) -> EventBase:
room_version = await self.store.get_room_version(room_id)
content = await self._do_send_invite(destination, pdu, room_version)
content = await self._do_send_invite(destination, pdu, context, room_version)
pdu_dict = content["event"]
@@ -1346,7 +1343,11 @@ class FederationClient(FederationBase):
return pdu
async def _do_send_invite(
self, destination: str, pdu: EventBase, room_version: RoomVersion
self,
destination: str,
pdu: EventBase,
context: EventContext,
room_version: RoomVersion,
) -> JsonDict:
"""Actually sends the invite, first trying v2 API and falling back to
v1 API if necessary.
@@ -1361,8 +1362,6 @@ class FederationClient(FederationBase):
"""
time_now = self._clock.time_msec()
# TODO: Adapt and use `get_stripped_room_state_from_event_context` instead
# MSC4311: For the federation API, format events in `invite_room_state` as full
# PDU's
#
@@ -1381,26 +1380,13 @@ class FederationClient(FederationBase):
(stripped_state_event.type, stripped_state_event.state_key)
)
# assert (
# pdu.internal_metadata.stream_ordering is not None
# and pdu.internal_metadata.instance_name is not None
# ), "Invite should be persisted by this point"
# Find the full events based on the state at the time of the invite
state_filter = StateFilter.from_types(stripped_state_types)
# XXX: Ideally, we'd use `get_state_ids_at(...)` but the invite event isn't
# persisted yet so there is no persisted position to look at specfically.
state_ids = await self.storage_controllers.state.get_current_state_ids(
pdu.room_id,
state_filter=state_filter,
# Partially-stated rooms should have all state events except for remote
# membership events. Since an invite will only possibly include the
# `m.room.membership` of the local sender, we're good to use partial state
# here.
await_full_state=False,
state_ids = await self.store.get_stripped_room_state_ids_from_event_context(
context, state_filter
)
state_events = await self.store.get_events(list(state_ids.values()))
assert set(state_ids.values()) == set(state_events.keys()), (
state_events = await self.store.get_events(state_ids)
assert set(state_ids) == set(state_events.keys()), (
"We should have all events available that were set as stripped state."
)
+1
View File
@@ -965,6 +965,7 @@ class FederationServer(FederationBase):
# server. This will allow the remote server's clients to display information
# related to the room while the knock request is pending.
stripped_room_state = (
# TODO: Implement MSC4311 and use full PDUs here
await self.store.get_stripped_room_state_from_event_context(
context, self._room_prejoin_state_types
)
+4 -2
View File
@@ -554,7 +554,9 @@ class FederationHandler:
return False
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
async def send_invite(
self, target_host: str, event: EventBase, context: EventContext
) -> EventBase:
"""Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
@@ -563,8 +565,8 @@ class FederationHandler:
pdu = await self.federation_client.send_invite(
destination=target_host,
room_id=event.room_id,
event_id=event.event_id,
pdu=event,
context=context,
)
except RequestSendFailed:
raise SynapseError(502, f"Can't connect to server {target_host}")
+1 -1
View File
@@ -2087,7 +2087,7 @@ class EventCreationHandler:
# to get them to sign the event.
returned_invite = await federation_handler.send_invite(
invitee.domain, event
invitee.domain, event, context
)
event.unsigned.pop("room_state", None)
@@ -1137,19 +1137,46 @@ class EventsWorkerStore(SQLBaseStore):
filter = StateFilter.from_types(types)
else:
filter = state_keys_to_include
selected_state_ids = await context.get_current_state_ids(filter)
selected_state_ids = await self.get_stripped_room_state_ids_from_event_context(
context, filter
)
state_to_include = await self.get_events(selected_state_ids)
return [strip_event(e) for e in state_to_include.values()]
async def get_stripped_room_state_ids_from_event_context(
self,
context: EventContext,
state_keys_to_include: StateFilter,
) -> list[str]:
"""
Retrieve the stripped state IDs for an event, given an event context to retrieve state
from as well as the state types to include. Optionally, include the membership
events from a specific user.
"Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
are included from each state event.
Args:
context: The event context to retrieve state of the room from.
state_keys_to_include: The state events to include, for each event type.
Returns:
A list of event_ids, each representing the stripped state event to include for this event
"""
selected_state_ids = await context.get_current_state_ids(state_keys_to_include)
# We know this event is not an outlier, so this must be
# non-None.
assert selected_state_ids is not None
# Confusingly, get_current_state_events may return events that are discarded by
# the filter, if they're in context._state_delta_due_to_event. Strip these away.
selected_state_ids = filter.filter_state(selected_state_ids)
# Confusingly, `get_current_state_ids` may return events that are discarded by
# the filter, if they're in `context._state_delta_due_to_event`. Strip these away.
selected_state_ids = state_keys_to_include.filter_state(selected_state_ids)
state_to_include = await self.get_events(selected_state_ids.values())
return [strip_event(e) for e in state_to_include.values()]
return list(selected_state_ids.values())
def _maybe_start_fetch_thread(self) -> None:
"""Starts an event fetch thread if we are not yet at the maximum number."""