diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index ad2945e82f..e8347324d1 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -60,6 +60,7 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.events import EventBase, builder, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.events.utils import parse_stripped_state_event from synapse.federation.federation_base import ( FederationBase, @@ -74,10 +75,7 @@ from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, from synapse.metrics import SERVER_NAME_LABEL from synapse.types import ( JsonDict, - PersistedEventPosition, StrCollection, - StreamKeyType, - StreamToken, UserID, get_domain_from_id, ) @@ -145,7 +143,6 @@ class FederationClient(FederationBase): self._clock.looping_call(self._clear_tried_cache, Duration(minutes=1)) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() - self.storage_controllers = hs.get_storage_controllers() self.server_name = hs.hostname self.signing_key = hs.signing_key @@ -1320,12 +1317,12 @@ class FederationClient(FederationBase): self, destination: str, room_id: str, - event_id: str, pdu: EventBase, + context: EventContext, ) -> EventBase: room_version = await self.store.get_room_version(room_id) - content = await self._do_send_invite(destination, pdu, room_version) + content = await self._do_send_invite(destination, pdu, context, room_version) pdu_dict = content["event"] @@ -1346,7 +1343,11 @@ class FederationClient(FederationBase): return pdu async def _do_send_invite( - self, destination: str, pdu: EventBase, room_version: RoomVersion + self, + destination: str, + pdu: EventBase, + context: EventContext, + room_version: RoomVersion, ) -> JsonDict: """Actually sends the invite, first trying v2 API and falling back to v1 API if necessary. @@ -1361,8 +1362,6 @@ class FederationClient(FederationBase): """ time_now = self._clock.time_msec() - # TODO: Adapt and use `get_stripped_room_state_from_event_context` instead - # MSC4311: For the federation API, format events in `invite_room_state` as full # PDU's # @@ -1381,26 +1380,13 @@ class FederationClient(FederationBase): (stripped_state_event.type, stripped_state_event.state_key) ) - # assert ( - # pdu.internal_metadata.stream_ordering is not None - # and pdu.internal_metadata.instance_name is not None - # ), "Invite should be persisted by this point" - # Find the full events based on the state at the time of the invite state_filter = StateFilter.from_types(stripped_state_types) - # XXX: Ideally, we'd use `get_state_ids_at(...)` but the invite event isn't - # persisted yet so there is no persisted position to look at specfically. - state_ids = await self.storage_controllers.state.get_current_state_ids( - pdu.room_id, - state_filter=state_filter, - # Partially-stated rooms should have all state events except for remote - # membership events. Since an invite will only possibly include the - # `m.room.membership` of the local sender, we're good to use partial state - # here. - await_full_state=False, + state_ids = await self.store.get_stripped_room_state_ids_from_event_context( + context, state_filter ) - state_events = await self.store.get_events(list(state_ids.values())) - assert set(state_ids.values()) == set(state_events.keys()), ( + state_events = await self.store.get_events(state_ids) + assert set(state_ids) == set(state_events.keys()), ( "We should have all events available that were set as stripped state." ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 1bbe144422..7018eea1ad 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -965,6 +965,7 @@ class FederationServer(FederationBase): # server. This will allow the remote server's clients to display information # related to the room while the knock request is pending. stripped_room_state = ( + # TODO: Implement MSC4311 and use full PDUs here await self.store.get_stripped_room_state_from_event_context( context, self._room_prejoin_state_types ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 166a02d7c7..530b3f33e2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -554,7 +554,9 @@ class FederationHandler: return False - async def send_invite(self, target_host: str, event: EventBase) -> EventBase: + async def send_invite( + self, target_host: str, event: EventBase, context: EventContext + ) -> EventBase: """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. @@ -563,8 +565,8 @@ class FederationHandler: pdu = await self.federation_client.send_invite( destination=target_host, room_id=event.room_id, - event_id=event.event_id, pdu=event, + context=context, ) except RequestSendFailed: raise SynapseError(502, f"Can't connect to server {target_host}") diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4032c7eca9..7eaa8e8532 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -2087,7 +2087,7 @@ class EventCreationHandler: # to get them to sign the event. returned_invite = await federation_handler.send_invite( - invitee.domain, event + invitee.domain, event, context ) event.unsigned.pop("room_state", None) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cc79b8042b..f6ba0f27b0 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1137,19 +1137,46 @@ class EventsWorkerStore(SQLBaseStore): filter = StateFilter.from_types(types) else: filter = state_keys_to_include - selected_state_ids = await context.get_current_state_ids(filter) + + selected_state_ids = await self.get_stripped_room_state_ids_from_event_context( + context, filter + ) + + state_to_include = await self.get_events(selected_state_ids) + + return [strip_event(e) for e in state_to_include.values()] + + async def get_stripped_room_state_ids_from_event_context( + self, + context: EventContext, + state_keys_to_include: StateFilter, + ) -> list[str]: + """ + Retrieve the stripped state IDs for an event, given an event context to retrieve state + from as well as the state types to include. Optionally, include the membership + events from a specific user. + + "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys + are included from each state event. + + Args: + context: The event context to retrieve state of the room from. + state_keys_to_include: The state events to include, for each event type. + + Returns: + A list of event_ids, each representing the stripped state event to include for this event + """ + selected_state_ids = await context.get_current_state_ids(state_keys_to_include) # We know this event is not an outlier, so this must be # non-None. assert selected_state_ids is not None - # Confusingly, get_current_state_events may return events that are discarded by - # the filter, if they're in context._state_delta_due_to_event. Strip these away. - selected_state_ids = filter.filter_state(selected_state_ids) + # Confusingly, `get_current_state_ids` may return events that are discarded by + # the filter, if they're in `context._state_delta_due_to_event`. Strip these away. + selected_state_ids = state_keys_to_include.filter_state(selected_state_ids) - state_to_include = await self.get_events(selected_state_ids.values()) - - return [strip_event(e) for e in state_to_include.values()] + return list(selected_state_ids.values()) def _maybe_start_fetch_thread(self) -> None: """Starts an event fetch thread if we are not yet at the maximum number."""