Update the rest of the code

This commit is contained in:
Erik Johnston
2026-04-01 16:47:16 +01:00
parent b3af1e733e
commit 01d52699dc
28 changed files with 314 additions and 225 deletions
+6 -5
View File
@@ -34,6 +34,7 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.events import EventBase
from synapse.events.utils import FilteredEvent
from synapse.handlers.admin import ExfiltrationWriter
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -150,14 +151,14 @@ class FileExfiltrationWriter(ExfiltrationWriter):
if list(os.listdir(self.base_directory)):
raise Exception("Directory must be empty")
def write_events(self, room_id: str, events: list[EventBase]) -> None:
def write_events(self, room_id: str, filtered_events: list[FilteredEvent]) -> None:
room_directory = os.path.join(self.base_directory, "rooms", room_id)
os.makedirs(room_directory, exist_ok=True)
events_file = os.path.join(room_directory, "events")
with open(events_file, "a") as f:
for event in events:
json.dump(event.get_pdu_json(), fp=f)
for filtered_event in filtered_events:
json.dump(filtered_event.event.get_pdu_json(), fp=f)
def write_state(
self, room_id: str, event_id: str, state: StateMap[EventBase]
@@ -175,7 +176,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
def write_invite(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
) -> None:
self.write_events(room_id, [event])
self.write_events(room_id, [FilteredEvent.state(event)])
# We write the invite state somewhere else as they aren't full events
# and are only a subset of the state at the event.
@@ -191,7 +192,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
def write_knock(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
) -> None:
self.write_events(room_id, [event])
self.write_events(room_id, [FilteredEvent.state(event)])
# We write the knock state somewhere else as they aren't full events
# and are only a subset of the state at the event.
+2 -2
View File
@@ -40,7 +40,7 @@ from synapse.appservice import (
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
from synapse.events.utils import FilteredEvent, SerializeEventConfig
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
from synapse.logging import opentracing
from synapse.metrics import SERVER_NAME_LABEL
@@ -545,7 +545,7 @@ class ApplicationServiceApi(SimpleHttpClient):
) -> list[JsonDict]:
time_now = self.clock.time_msec()
return await self._event_serializer.serialize_events(
list(events),
[FilteredEvent(event=e, membership=None) for e in events],
time_now,
config=SerializeEventConfig(
as_client_event=True,
+19 -10
View File
@@ -33,6 +33,7 @@ import attr
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import FilteredEvent
from synapse.types import (
JsonMapping,
Requester,
@@ -251,32 +252,40 @@ class AdminHandler:
topological=last_event.depth,
)
events = await filter_and_transform_events_for_client(
filtered_events = await filter_and_transform_events_for_client(
self._storage_controllers,
user_id,
events,
)
writer.write_events(room_id, events)
writer.write_events(room_id, filtered_events)
# Update the extremity tracking dicts
for event in events:
for filtered_event in filtered_events:
# Check if we have any prev events that haven't been
# processed yet, and add those to the appropriate dicts.
unseen_events = set(event.prev_event_ids()) - written_events
unseen_events = (
set(filtered_event.event.prev_event_ids()) - written_events
)
if unseen_events:
event_to_unseen_prevs[event.event_id] = unseen_events
event_to_unseen_prevs[filtered_event.event.event_id] = (
unseen_events
)
for unseen in unseen_events:
unseen_to_child_events.setdefault(unseen, set()).add(
event.event_id
filtered_event.event.event_id
)
# Now check if this event is an unseen prev event, if so
# then we remove this event from the appropriate dicts.
for child_id in unseen_to_child_events.pop(event.event_id, []):
event_to_unseen_prevs[child_id].discard(event.event_id)
for child_id in unseen_to_child_events.pop(
filtered_event.event.event_id, []
):
event_to_unseen_prevs[child_id].discard(
filtered_event.event.event_id
)
written_events.add(event.event_id)
written_events.add(filtered_event.event.event_id)
logger.info(
"Written %d events in room %s", len(written_events), room_id
@@ -511,7 +520,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data."""
@abc.abstractmethod
def write_events(self, room_id: str, events: list[EventBase]) -> None:
def write_events(self, room_id: str, events: list[FilteredEvent]) -> None:
"""Write a batch of events for a room."""
raise NotImplementedError()
+8 -9
View File
@@ -25,8 +25,7 @@ from typing import TYPE_CHECKING, Iterable
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
from synapse.events.utils import FilteredEvent, SerializeEventConfig
from synapse.handlers.presence import format_user_presence_state
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.streams.config import PaginationConfig
@@ -102,19 +101,19 @@ class EventStreamHandler:
# joined room, we need to send down presence for those users.
to_add: list[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
if not isinstance(event, FilteredEvent):
continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
if event.event.type == EventTypes.Member:
if event.event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == requester.user.to_string():
if event.event.state_key == requester.user.to_string():
# Send down presence for everyone in the room.
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
event.event.room_id
)
else:
users = [event.state_key]
users = [event.event.state_key]
states = await presence_handler.get_states(users)
to_add.extend(
@@ -155,7 +154,7 @@ class EventHandler:
room_id: str | None,
event_id: str,
show_redacted: bool = False,
) -> EventBase | None:
) -> FilteredEvent | None:
"""Retrieve a single specified event on behalf of a user.
The event will be transformed in a user-specific and time-specific way,
e.g. having unsigned metadata added or being erased depending on who is accessing.
+17 -11
View File
@@ -30,7 +30,7 @@ from synapse.api.constants import (
Membership,
)
from synapse.api.errors import SynapseError
from synapse.events.utils import SerializeEventConfig
from synapse.events.utils import FilteredEvent, SerializeEventConfig
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.receipts import ReceiptEventSource
@@ -186,7 +186,7 @@ class InitialSyncHandler:
invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event(
invite_event,
FilteredEvent.state(event=invite_event),
time_now,
config=serializer_options,
)
@@ -225,7 +225,7 @@ class InitialSyncHandler:
)
).addErrback(unwrapFirstError)
messages = await filter_and_transform_events_for_client(
filtered_messages = await filter_and_transform_events_for_client(
self._storage_controllers,
user_id,
messages,
@@ -240,7 +240,7 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
await self._event_serializer.serialize_events(
messages,
filtered_messages,
time_now=time_now,
config=serializer_options,
)
@@ -250,7 +250,7 @@ class InitialSyncHandler:
}
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
[FilteredEvent.state(e) for e in current_state.values()],
time_now=time_now,
config=serializer_options,
)
@@ -382,7 +382,9 @@ class InitialSyncHandler:
room_id, limit=pagin_config.limit, end_token=stream_token
)
messages = await filter_and_transform_events_for_client(
filtered_messages: list[
FilteredEvent
] = await filter_and_transform_events_for_client(
self._storage_controllers,
requester.user.to_string(),
messages,
@@ -402,7 +404,7 @@ class InitialSyncHandler:
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
filtered_messages, time_now, config=serialize_options
)
),
"start": await start_token.to_string(self.store),
@@ -411,7 +413,9 @@ class InitialSyncHandler:
"state": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
room_state.values(), time_now, config=serialize_options
[FilteredEvent.state(e) for e in room_state.values()],
time_now,
config=serialize_options,
)
),
"presence": [],
@@ -435,7 +439,7 @@ class InitialSyncHandler:
serialize_options = SerializeEventConfig(requester=requester)
# Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events(
current_state.values(),
[FilteredEvent.state(e) for e in current_state.values()],
time_now,
config=serialize_options,
)
@@ -496,7 +500,9 @@ class InitialSyncHandler:
).addErrback(unwrapFirstError)
)
messages = await filter_and_transform_events_for_client(
filtered_messages: list[
FilteredEvent
] = await filter_and_transform_events_for_client(
self._storage_controllers,
requester.user.to_string(),
messages,
@@ -512,7 +518,7 @@ class InitialSyncHandler:
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
filtered_messages, time_now, config=serialize_options
)
),
"start": await start_token.to_string(self.store),
+6 -2
View File
@@ -61,7 +61,11 @@ from synapse.events.snapshot import (
UnpersistedEventContext,
UnpersistedEventContextBase,
)
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.utils import (
FilteredEvent,
SerializeEventConfig,
maybe_upsert_event_field,
)
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
@@ -261,7 +265,7 @@ class MessageHandler:
room_state = room_state_events[membership_event_id]
events = await self._event_serializer.serialize_events(
room_state.values(),
[FilteredEvent.state(e) for e in room_state.values()],
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
)
+11 -8
View File
@@ -29,6 +29,7 @@ from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.events.utils import FilteredEvent
from synapse.handlers.relations import BundledAggregations
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.opentracing import trace
@@ -79,7 +80,7 @@ class GetMessagesResult:
Everything needed to serialize a `/messages` response.
"""
messages_chunk: list[EventBase]
messages_chunk: list[FilteredEvent]
"""
A list of room events.
@@ -684,16 +685,18 @@ class PaginationHandler:
events = await event_filter.filter(events)
if not use_admin_priviledge:
events = await filter_and_transform_events_for_client(
filtered_events = await filter_and_transform_events_for_client(
self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
)
else:
filtered_events = [FilteredEvent.admin_override(e) for e in events]
# if after the filter applied there are no more events
# return immediately - but there might be more in next_token batch
if not events:
if not filtered_events:
return GetMessagesResult(
messages_chunk=[],
bundled_aggregations={},
@@ -703,16 +706,16 @@ class PaginationHandler:
)
state = None
if event_filter and event_filter.lazy_load_members and len(events) > 0:
if event_filter and event_filter.lazy_load_members and len(filtered_events) > 0:
# TODO: remove redundant members
# FIXME: we also care about invite targets etc.
state_filter = StateFilter.from_types(
(EventTypes.Member, event.sender) for event in events
(EventTypes.Member, event.event.sender) for event in filtered_events
)
state_ids = await self._state_storage_controller.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
filtered_events[0].event.event_id, state_filter=state_filter
)
if state_ids:
@@ -720,11 +723,11 @@ class PaginationHandler:
state = list(state_dict.values())
aggregations = await self._relations_handler.get_bundled_aggregations(
events, user_id
filtered_events, user_id
)
return GetMessagesResult(
messages_chunk=events,
messages_chunk=filtered_events,
bundled_aggregations=aggregations,
state=state,
start_token=from_token,
+19 -13
View File
@@ -33,7 +33,7 @@ import attr
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.events.utils import SerializeEventConfig
from synapse.events.utils import FilteredEvent, SerializeEventConfig
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
@@ -139,7 +139,7 @@ class RelationsHandler:
# not passing them in here we should get a better cache hit rate).
related_events, next_token = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
event=event.event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -154,7 +154,9 @@ class RelationsHandler:
[e.event_id for e in related_events]
)
events = await filter_and_transform_events_for_client(
filtered_events: list[
FilteredEvent
] = await filter_and_transform_events_for_client(
self._storage_controllers,
user_id,
events,
@@ -164,14 +166,14 @@ class RelationsHandler:
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
filtered_events, requester.user.to_string()
)
now = self._clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
return_value: JsonDict = {
"chunk": await self._event_serializer.serialize_events(
events,
filtered_events,
now,
bundle_aggregations=aggregations,
config=serialize_options,
@@ -389,7 +391,7 @@ class RelationsHandler:
potential_events, _ = await self._main_store.get_relations_for_event(
room_id,
event_id,
event,
event.event,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)
@@ -417,7 +419,7 @@ class RelationsHandler:
potential_events[-1].event_id,
)
continue
latest_thread_event = event
latest_thread_event = event.event
results[event_id] = _ThreadAggregation(
latest_event=latest_thread_event,
@@ -432,12 +434,12 @@ class RelationsHandler:
@trace
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
self, filtered_events: Iterable[FilteredEvent], user_id: str
) -> dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
filtered_events: The iterable of filtered events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
@@ -453,7 +455,9 @@ class RelationsHandler:
events_by_id = {}
# A map of event ID to the relation in that event, if there is one.
relations_by_id: dict[str, str] = {}
for event in events:
for filtered_event in filtered_events:
event = filtered_event.event
# State events do not get bundled aggregations.
if event.is_state():
continue
@@ -599,7 +603,9 @@ class RelationsHandler:
# Limit the returned threads to those the user has participated in.
events = [event for event in events if participated[event.event_id]]
events = await filter_and_transform_events_for_client(
filtered_events: list[
FilteredEvent
] = await filter_and_transform_events_for_client(
self._storage_controllers,
user_id,
events,
@@ -607,12 +613,12 @@ class RelationsHandler:
)
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
filtered_events, requester.user.to_string()
)
now = self._clock.time_msec()
serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
filtered_events, now, bundle_aggregations=aggregations
)
return_value: JsonDict = {"chunk": serialized_events}
+21 -19
View File
@@ -67,7 +67,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext
from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.events.utils import FilteredEvent, copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
from synapse.streams import EventSource
@@ -109,9 +109,9 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventContext:
events_before: list[EventBase]
event: EventBase
events_after: list[EventBase]
events_before: list[FilteredEvent]
event: FilteredEvent
events_after: list[FilteredEvent]
state: list[EventBase]
aggregations: dict[str, BundledAggregations]
start: str
@@ -1916,9 +1916,9 @@ class RoomContextHandler:
# The user is peeking if they aren't in the room already
is_peeking = not is_user_in_room
async def filter_evts(events: list[EventBase]) -> list[EventBase]:
async def filter_evts(events: list[EventBase]) -> list[FilteredEvent]:
if use_admin_priviledge:
return events
return [FilteredEvent.admin_override(e) for e in events]
return await filter_and_transform_events_for_client(
self._storage_controllers,
user.to_string(),
@@ -1946,31 +1946,33 @@ class RoomContextHandler:
events_before = await event_filter.filter(events_before)
events_after = await event_filter.filter(events_after)
events_before = await filter_evts(events_before)
events_after = await filter_evts(events_after)
filtered_events_before = await filter_evts(events_before)
filtered_events_after = await filter_evts(events_after)
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
event = filtered[0]
filtered_event = filtered[0]
# Fetch the aggregations.
aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
itertools.chain(
filtered_events_before, (filtered_event,), filtered_events_after
),
user.to_string(),
)
if events_after:
last_event_id = events_after[-1].event_id
if filtered_events_after:
last_event_id = filtered_events_after[-1].event.event_id
else:
last_event_id = event_id
if event_filter and event_filter.lazy_load_members:
state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
ev.event.sender
for ev in itertools.chain(
events_before,
(event,),
events_after,
filtered_events_before,
(filtered_event,),
filtered_events_after,
)
)
else:
@@ -1993,9 +1995,9 @@ class RoomContextHandler:
token = StreamToken.START
return EventContext(
events_before=events_before,
event=event,
events_after=events_after,
events_before=filtered_events_before,
event=filtered_event,
events_after=filtered_events_after,
state=state_events,
aggregations=aggregations,
start=await token.copy_and_replace(
+24 -23
View File
@@ -29,8 +29,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
from synapse.events.utils import FilteredEvent, SerializeEventConfig
from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
from synapse.types.state import StateFilter
from synapse.visibility import filter_and_transform_events_for_client
@@ -48,7 +47,7 @@ class _SearchResult:
# A mapping of event ID to the rank of that event.
rank_map: dict[str, int]
# A list of the resulting events.
allowed_events: list[EventBase]
allowed_events: list[FilteredEvent]
# A map of room ID to results.
room_groups: dict[str, JsonDict]
# A set of event IDs to highlight.
@@ -355,12 +354,12 @@ class SearchHandler:
state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
for room_id in {e.event.room_id for e in search_result.allowed_events}:
state = await self._storage_controllers.state.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# Generate an iterable of FilteredEvent for all the events that will be
# returned, including contextual events.
itertools.chain(
# The events_before and events_after for each context.
@@ -396,14 +395,14 @@ class SearchHandler:
results = [
{
"rank": search_result.rank_map[e.event_id],
"rank": search_result.rank_map[e.event.event_id],
"result": await self._event_serializer.serialize_event(
e,
time_now,
bundle_aggregations=aggregations,
config=serialize_options,
),
"context": contexts.get(e.event_id, {}),
"context": contexts.get(e.event.event_id, {}),
}
for e in search_result.allowed_events
]
@@ -417,7 +416,9 @@ class SearchHandler:
if state_results:
rooms_cat_res["state"] = {
room_id: await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
[FilteredEvent.state(e) for e in state_events],
time_now,
config=serialize_options,
)
for room_id, state_events in state_results.items()
}
@@ -485,19 +486,19 @@ class SearchHandler:
filtered_events,
)
events.sort(key=lambda e: -rank_map[e.event_id])
events.sort(key=lambda e: -rank_map[e.event.event_id])
allowed_events = events[: search_filter.limit]
for e in allowed_events:
rm = room_groups.setdefault(
e.room_id, {"results": [], "order": rank_map[e.event_id]}
e.event.room_id, {"results": [], "order": rank_map[e.event.event_id]}
)
rm["results"].append(e.event_id)
rm["results"].append(e.event.event_id)
s = sender_group.setdefault(
e.sender, {"results": [], "order": rank_map[e.event_id]}
e.event.sender, {"results": [], "order": rank_map[e.event.event_id]}
)
s["results"].append(e.event_id)
s["results"].append(e.event.event_id)
return (
_SearchResult(
@@ -549,7 +550,7 @@ class SearchHandler:
highlights = set()
room_events: list[EventBase] = []
room_events: list[FilteredEvent] = []
i = 0
pagination_token = batch_token
@@ -595,11 +596,11 @@ class SearchHandler:
pagination_token = results[-1]["pagination_token"]
for event in room_events:
group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)
group = room_groups.setdefault(event.event.room_id, {"results": []})
group["results"].append(event.event.event_id)
if room_events and len(room_events) >= search_filter.limit:
last_event_id = room_events[-1].event_id
last_event_id = room_events[-1].event.event_id
pagination_token = results_map[last_event_id]["pagination_token"]
# We want to respect the given batch group and group keys so
@@ -632,7 +633,7 @@ class SearchHandler:
async def _calculate_event_contexts(
self,
user: UserID,
allowed_events: list[EventBase],
allowed_events: list[FilteredEvent],
before_limit: int,
after_limit: int,
include_profile: bool,
@@ -658,7 +659,7 @@ class SearchHandler:
contexts = {}
for event in allowed_events:
res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
event.event.room_id, event.event.event_id, before_limit, after_limit
)
logger.info(
@@ -692,14 +693,14 @@ class SearchHandler:
if include_profile:
senders = {
ev.sender
ev.event.sender
for ev in itertools.chain(events_before, [event], events_after)
}
if events_after:
last_event_id = events_after[-1].event_id
last_event_id = events_after[-1].event.event_id
else:
last_event_id = event.event_id
last_event_id = event.event.event_id
state_filter = StateFilter.from_types(
[(EventTypes.Member, sender) for sender in senders]
@@ -718,6 +719,6 @@ class SearchHandler:
if s.type == EventTypes.Member and s.state_key in senders
}
contexts[event.event_id] = context
contexts[event.event.event_id] = context
return contexts
+20 -15
View File
@@ -23,7 +23,7 @@ from typing_extensions import assert_never
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.events import EventBase
from synapse.events.utils import strip_event
from synapse.events.utils import FilteredEvent, strip_event
from synapse.handlers.relations import BundledAggregations
from synapse.handlers.sliding_sync.extensions import SlidingSyncExtensionHandler
from synapse.handlers.sliding_sync.room_lists import (
@@ -679,7 +679,7 @@ class SlidingSyncHandler:
# membership. Currently, we have to make all of these optional because
# `invite`/`knock` rooms only have `stripped_state`. See
# https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
timeline_events: list[EventBase] = []
timeline_events: list[FilteredEvent] = []
bundled_aggregations: dict[str, BundledAggregations] | None = None
limited: bool | None = None
prev_batch_token: StreamToken | None = None
@@ -739,7 +739,7 @@ class SlidingSyncHandler:
# Use `stream_ordering` for updates
else paginate_room_events_by_stream_ordering
)
timeline_events, new_room_key, limited = await pagination_method(
raw_timeline_events, new_room_key, limited = await pagination_method(
room_id=room_id,
# The bounds are reversed so we can paginate backwards
# (from newer to older events) starting at to_bound.
@@ -752,13 +752,13 @@ class SlidingSyncHandler:
# We want to return the events in ascending order (the last event is the
# most recent).
timeline_events.reverse()
raw_timeline_events.reverse()
# Make sure we don't expose any events that the client shouldn't see
timeline_events = await filter_and_transform_events_for_client(
self.storage_controllers,
user.to_string(),
timeline_events,
raw_timeline_events,
is_peeking=room_membership_for_user_at_to_token.membership
!= Membership.JOIN,
filter_send_to_client=True,
@@ -778,12 +778,17 @@ class SlidingSyncHandler:
if from_token is not None:
for timeline_event in reversed(timeline_events):
# This fields should be present for all persisted events
assert timeline_event.internal_metadata.stream_ordering is not None
assert timeline_event.internal_metadata.instance_name is not None
assert (
timeline_event.event.internal_metadata.stream_ordering
is not None
)
assert (
timeline_event.event.internal_metadata.instance_name is not None
)
persisted_position = PersistedEventPosition(
instance_name=timeline_event.internal_metadata.instance_name,
stream=timeline_event.internal_metadata.stream_ordering,
instance_name=timeline_event.event.internal_metadata.instance_name,
stream=timeline_event.event.internal_metadata.stream_ordering,
)
if persisted_position.persisted_after(
from_token.stream_token.room_key
@@ -1061,13 +1066,13 @@ class SlidingSyncHandler:
if timeline_events is not None:
for timeline_event in timeline_events:
# Anyone who sent a message is relevant
timeline_membership.add(timeline_event.sender)
timeline_membership.add(timeline_event.event.sender)
# We also care about invite, ban, kick, targets,
# etc.
if timeline_event.type == EventTypes.Member:
if timeline_event.event.type == EventTypes.Member:
timeline_membership.add(
timeline_event.state_key
timeline_event.event.state_key
)
# The client needs to know the membership of everyone in
@@ -1480,7 +1485,7 @@ class SlidingSyncHandler:
self,
room_id: str,
to_token: StreamToken,
timeline: list[EventBase],
timeline: list[FilteredEvent],
check_outside_timeline: bool,
) -> int | None:
"""Get a bump stamp for the room, if we have a bump event and it has
@@ -1500,8 +1505,8 @@ class SlidingSyncHandler:
# those matches. We iterate backwards and take the stream ordering
# of the first event that matches the bump event types.
for timeline_event in reversed(timeline):
if timeline_event.type in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES:
new_bump_stamp = timeline_event.internal_metadata.stream_ordering
if timeline_event.event.type in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES:
new_bump_stamp = timeline_event.event.internal_metadata.stream_ordering
# All persisted events have a stream ordering
assert new_bump_stamp is not None
+1 -1
View File
@@ -761,7 +761,7 @@ class SlidingSyncExtensionHandler:
# in the timeline to avoid bloating and blowing up the sync response
# as the number of users in the room increases. (this behavior is part of the spec)
initial_rooms_and_event_ids = [
(room_id, event.event_id)
(room_id, event.event.event_id)
for room_id in initial_rooms
if room_id in actual_room_response_map
for event in actual_room_response_map[room_id].timeline_events
+57 -40
View File
@@ -43,6 +43,7 @@ from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.utils import FilteredEvent
from synapse.handlers.relations import BundledAggregations
from synapse.logging import issue9533_logger
from synapse.logging.context import current_context
@@ -123,7 +124,7 @@ class SyncConfig:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TimelineBatch:
prev_batch: StreamToken
events: Sequence[EventBase]
events: Sequence[FilteredEvent]
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
@@ -148,7 +149,7 @@ class JoinedSyncResult:
state: StateMap[EventBase]
ephemeral: list[JsonDict]
account_data: list[JsonDict]
sticky: list[EventBase]
sticky: list[FilteredEvent]
unread_notifications: JsonDict
unread_thread_notifications: JsonDict
summary: JsonDict | None
@@ -699,6 +700,7 @@ class SyncHandler:
log_kv({"limited": limited})
filtered_recents: list[FilteredEvent]
if potential_recents:
recents = await sync_config.filter_collection.filter_room_timeline(
potential_recents
@@ -725,29 +727,32 @@ class SyncHandler:
)
)
recents = await filter_and_transform_events_for_client(
filtered_recents = await filter_and_transform_events_for_client(
self._storage_controllers,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
)
log_kv({"recents_after_visibility_filtering": len(recents)})
log_kv({"recents_after_visibility_filtering": len(filtered_recents)})
else:
recents = []
filtered_recents = []
if not limited or block_all_timeline:
prev_batch_token = upto_token
if recents:
assert recents[0].internal_metadata.stream_ordering
if filtered_recents:
assert filtered_recents[0].event.internal_metadata.stream_ordering
room_key = RoomStreamToken(
stream=recents[0].internal_metadata.stream_ordering - 1
stream=filtered_recents[
0
].event.internal_metadata.stream_ordering
- 1
)
prev_batch_token = upto_token.copy_and_replace(
StreamKeyType.ROOM, room_key
)
return TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=False
events=filtered_recents, prev_batch=prev_batch_token, limited=False
)
filtering_factor = 2
@@ -764,7 +769,7 @@ class SyncHandler:
elif since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
while limited and len(filtered_recents) < timeline_limit and max_repeat:
# For initial `/sync`, we want to view a historical section of the
# timeline; to fetch events by `topological_ordering` (best
# representation of the room DAG as others were seeing it at the time).
@@ -835,26 +840,35 @@ class SyncHandler:
)
)
loaded_recents = await filter_and_transform_events_for_client(
loaded_filtered_recents: list[
FilteredEvent
] = await filter_and_transform_events_for_client(
self._storage_controllers,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
)
log_kv({"loaded_recents_after_client_filtering": len(loaded_recents)})
log_kv(
{
"loaded_recents_after_client_filtering": len(
loaded_filtered_recents
)
}
)
loaded_recents.extend(recents)
recents = loaded_recents
loaded_filtered_recents.extend(filtered_recents)
filtered_recents = loaded_filtered_recents
max_repeat -= 1
if len(recents) > timeline_limit:
if len(filtered_recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
assert recents[0].internal_metadata.stream_ordering
filtered_recents = filtered_recents[-timeline_limit:]
assert filtered_recents[0].event.internal_metadata.stream_ordering
room_key = RoomStreamToken(
stream=recents[0].internal_metadata.stream_ordering - 1
stream=filtered_recents[0].event.internal_metadata.stream_ordering
- 1
)
prev_batch_token = upto_token.copy_and_replace(StreamKeyType.ROOM, room_key)
@@ -865,12 +879,12 @@ class SyncHandler:
if limited or newly_joined_room:
bundled_aggregations = (
await self._relations_handler.get_bundled_aggregations(
recents, sync_config.user.to_string()
filtered_recents, sync_config.user.to_string()
)
)
return TimelineBatch(
events=recents,
events=filtered_recents,
prev_batch=prev_batch_token,
# Also mark as limited if this is a new room or there has been a gap
# (to force client to paginate the gap).
@@ -976,8 +990,8 @@ class SyncHandler:
# ...or ones which are in the timeline...
for ev in batch.events:
if ev.type == EventTypes.Member:
existing_members.add(ev.state_key)
if ev.event.type == EventTypes.Member:
existing_members.add(ev.event.state_key)
# ...and then ensure any missing ones get included in state.
missing_hero_event_ids = [
@@ -1084,32 +1098,34 @@ class SyncHandler:
first_event_by_sender_map = {}
for event in batch.events:
# Build the map from user IDs to the first timeline event they sent.
if event.sender not in first_event_by_sender_map:
first_event_by_sender_map[event.sender] = event
if event.event.sender not in first_event_by_sender_map:
first_event_by_sender_map[event.event.sender] = event.event
# When using `state_after`, there is no special treatment with
# regards to state also being in the `timeline`. Always fetch
# relevant membership regardless of whether the state event is in
# the `timeline`.
if sync_config.use_state_after:
members_to_fetch.add(event.sender)
members_to_fetch.add(event.event.sender)
# For `state`, the client is supposed to do a flawed re-construction
# of state over time by starting with the given `state` and layering
# on state from the `timeline` as you go (flawed because state
# resolution). In this case, we only need their membership in
# `state` when their membership isn't already in the `timeline`.
elif (EventTypes.Member, event.sender) not in timeline_state:
members_to_fetch.add(event.sender)
elif (EventTypes.Member, event.event.sender) not in timeline_state:
members_to_fetch.add(event.event.sender)
# FIXME: we also care about invite targets etc.
if event.is_state():
timeline_state[(event.type, event.state_key)] = event.event_id
if event.event.is_state():
timeline_state[(event.event.type, event.event.state_key)] = (
event.event.event_id
)
else:
timeline_state = {
(event.type, event.state_key): event.event_id
(event.event.type, event.event.state_key): event.event.event_id
for event in batch.events
if event.is_state()
if event.event.is_state()
}
# Now calculate the state to return in the sync response for the room.
@@ -1340,7 +1356,7 @@ class SyncHandler:
# timeline, but that is good enough here.
state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
batch.events[0].event.event_id,
state_filter=state_filter,
await_full_state=await_full_state,
)
@@ -1470,10 +1486,10 @@ class SyncHandler:
prev_event_id = last_event_id_prev_batch
for e in batch.events:
if e.prev_event_ids() != [prev_event_id]:
if e.event.prev_event_ids() != [prev_event_id]:
is_linear_timeline = False
break
prev_event_id = e.event_id
prev_event_id = e.event.event_id
if is_linear_timeline and not batch.limited:
state_ids: StateMap[str] = {}
@@ -1487,7 +1503,7 @@ class SyncHandler:
state_ids = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
batch.events[0].event.event_id,
# we only want members!
state_filter=StateFilter.from_types(
(EventTypes.Member, member)
@@ -1501,7 +1517,7 @@ class SyncHandler:
if batch:
state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
batch.events[0].event.event_id,
state_filter=state_filter,
await_full_state=await_full_state,
)
@@ -2854,7 +2870,7 @@ class SyncHandler:
# if there are membership changes in the timeline, or
# if membership has changed during a gappy sync, or
# if this is an initial sync.
any(ev.type == EventTypes.Member for ev in batch.events)
any(ev.event.type == EventTypes.Member for ev in batch.events)
or (
# XXX: this may include false positives in the form of LL
# members which have snuck into state
@@ -2870,7 +2886,7 @@ class SyncHandler:
if room_builder.rtype == "joined":
unread_notifications: dict[str, int] = {}
sticky_events: list[EventBase] = []
sticky_events: list[FilteredEvent] = []
if sticky_event_ids:
# As per MSC4354:
# Remove sticky events that are already in the timeline, else we will needlessly duplicate
@@ -2880,7 +2896,7 @@ class SyncHandler:
# This is particularly important given the risk of sticky events spam since
# anyone can send sticky events, so halving the bandwidth on average for each sticky
# event is helpful.
timeline_event_id_set = {ev.event_id for ev in batch.events}
timeline_event_id_set = {ev.event.event_id for ev in batch.events}
# Must preserve sticky event stream order
sticky_event_ids = [
e for e in sticky_event_ids if e not in timeline_event_id_set
@@ -3144,7 +3160,8 @@ class SyncResultBuilder:
if self.since_token:
for joined_sync in self.joined:
it = itertools.chain(
joined_sync.state.values(), joined_sync.timeline.events
joined_sync.state.values(),
(e.event for e in joined_sync.timeline.events),
)
for event in it:
if event.type == EventTypes.Member:
+6 -4
View File
@@ -53,7 +53,7 @@ class ThreadSubscriptionsHandler:
raise NotFoundError("No such thread root")
return await self.store.get_subscription_for_thread(
user_id.to_string(), event.room_id, thread_root_event_id
user_id.to_string(), event.event.room_id, thread_root_event_id
)
async def subscribe_user_to_thread(
@@ -103,7 +103,7 @@ class ThreadSubscriptionsHandler:
)
if autosub_cause_event is None:
raise NotFoundError("Automatic subscription event not found")
relation = relation_from_event(autosub_cause_event)
relation = relation_from_event(autosub_cause_event.event)
if (
relation is None
or relation.rel_type != RelationTypes.THREAD
@@ -115,7 +115,9 @@ class ThreadSubscriptionsHandler:
errcode=Codes.MSC4306_NOT_IN_THREAD,
)
automatic_event_orderings = EventOrderings.from_event(autosub_cause_event)
automatic_event_orderings = EventOrderings.from_event(
autosub_cause_event.event
)
else:
automatic_event_orderings = None
@@ -174,7 +176,7 @@ class ThreadSubscriptionsHandler:
outcome = await self.store.unsubscribe_user_from_thread(
user_id.to_string(),
event.room_id,
event.event.room_id,
thread_root_event_id,
)
+3 -2
View File
@@ -41,6 +41,7 @@ from twisted.internet.defer import Deferred
from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.events.utils import FilteredEvent
from synapse.handlers.presence import format_user_presence_state
from synapse.logging import issue9533_logger
from synapse.logging.context import PreserveLoggingContext
@@ -210,7 +211,7 @@ class _NotifierUserStream:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventStreamResult:
events: list[JsonDict | EventBase]
events: list[JsonDict | FilteredEvent]
start_token: StreamToken
end_token: StreamToken
@@ -765,7 +766,7 @@ class Notifier:
# The events fetched from each source are a JsonDict, EventBase, or
# UserPresenceState, but see below for UserPresenceState being
# converted to JsonDict.
events: list[JsonDict | EventBase] = []
events: list[JsonDict | FilteredEvent] = []
end_token = from_token
for keyname, source in self.event_sources.sources.get_sources():
+4 -2
View File
@@ -543,8 +543,10 @@ class Mailer:
results.events_before + [notif_event],
)
for event in the_events:
messagevars = await self._get_message_vars(notif, event, room_state_ids)
for filtered_event in the_events:
messagevars = await self._get_message_vars(
notif, filtered_event.event, room_state_ids
)
if messagevars is not None:
ret["messages"].append(messagevars)
+4 -1
View File
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
from synapse.api.errors import NotFoundError
from synapse.events.utils import (
FilteredEvent,
SerializeEventConfig,
format_event_raw,
)
@@ -66,7 +67,9 @@ class EventRestServlet(RestServlet):
)
res = {
"event": await self._event_serializer.serialize_event(
event, self._clock.time_msec(), config=config
FilteredEvent.admin_override(event),
self._clock.time_msec(),
config=config,
)
}
+6 -2
View File
@@ -29,6 +29,7 @@ from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import (
FilteredEvent,
SerializeEventConfig,
)
from synapse.handlers.pagination import (
@@ -529,7 +530,9 @@ class RoomStateRestServlet(RestServlet):
)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = await self._event_serializer.serialize_events(events.values(), now)
room_state = await self._event_serializer.serialize_events(
[FilteredEvent.state(e) for e in events.values()], now
)
ret = {"state": room_state}
return HTTPStatus.OK, ret
@@ -897,7 +900,8 @@ class RoomEventContextServlet(RestServlet):
bundle_aggregations=event_context.aggregations,
),
"state": await self._event_serializer.serialize_events(
event_context.state, time_now
[FilteredEvent.state(e) for e in event_context.state],
time_now,
),
"start": event_context.start,
"end": event_context.end,
+2 -1
View File
@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING
from synapse.api.constants import ReceiptTypes
from synapse.events.utils import (
FilteredEvent,
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
)
@@ -111,7 +112,7 @@ class NotificationsServlet(RestServlet):
"ts": pa.received_ts,
"event": (
await self._event_serializer.serialize_event(
notif_events[pa.event_id],
FilteredEvent(event=notif_events[pa.event_id], membership=None),
now,
config=serialize_options,
)
+6 -3
View File
@@ -53,6 +53,7 @@ from synapse.api.errors import (
from synapse.api.filtering import Filter
from synapse.events.utils import (
EventClientSerializer,
FilteredEvent,
SerializeEventConfig,
format_event_for_client_v2,
)
@@ -286,7 +287,7 @@ class RoomStateEventRestServlet(RestServlet):
if format == "event":
event = await self._event_serializer.serialize_event(
data,
FilteredEvent.state(data),
self.clock.time_msec(),
config=SerializeEventConfig(
event_format=format_event_for_client_v2,
@@ -866,7 +867,9 @@ async def encode_messages_response(
serialized_result[
"state"
] = await serialize_deps.event_serializer.serialize_events(
get_messages_result.state, time_now, config=serialize_options
[FilteredEvent.state(e) for e in get_messages_result.state],
time_now,
config=serialize_options,
)
return serialized_result
@@ -1172,7 +1175,7 @@ class RoomEventContextServlet(RestServlet):
config=serializer_options,
),
"state": await self._event_serializer.serialize_events(
event_context.state,
[FilteredEvent.state(e) for e in event_context.state],
time_now,
config=serializer_options,
),
+22 -6
View File
@@ -18,7 +18,6 @@
# [This file includes modifications made by New Vector Limited]
#
#
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Mapping
@@ -31,6 +30,7 @@ from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.ratelimiting import Ratelimiter
from synapse.events.utils import (
FilteredEvent,
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
format_event_raw,
@@ -448,7 +448,9 @@ class SyncRestServlet(RestServlet):
invited = {}
for room in rooms:
invite = await self._event_serializer.serialize_event(
room.invite, time_now, config=serialize_options
FilteredEvent.state(event=room.invite),
time_now,
config=serialize_options,
)
unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
@@ -484,7 +486,9 @@ class SyncRestServlet(RestServlet):
knocked = {}
for room in rooms:
knock = await self._event_serializer.serialize_event(
room.knock, time_now, config=serialize_options
FilteredEvent.state(event=room.knock),
time_now,
config=serialize_options,
)
# Extract the `unsigned` key from the knock event.
@@ -574,7 +578,7 @@ class SyncRestServlet(RestServlet):
state_events = state_dict.values()
for event in itertools.chain(state_events, timeline_events):
for event in state_events:
# We've had bug reports that events were coming down under the
# wrong room.
if event.room_id != room.room_id:
@@ -584,9 +588,21 @@ class SyncRestServlet(RestServlet):
room.room_id,
event.room_id,
)
for filtered_event in timeline_events:
# We've had bug reports that events were coming down under the
# wrong room.
if filtered_event.event.room_id != room.room_id:
logger.warning(
"Event %r is under room %r instead of %r",
filtered_event.event.event_id,
room.room_id,
filtered_event.event.room_id,
)
serialized_state = await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
[FilteredEvent.state(e) for e in state_events],
time_now,
config=serialize_options,
)
serialized_timeline = await self._event_serializer.serialize_events(
timeline_events,
@@ -974,7 +990,7 @@ class SlidingSyncRestServlet(RestServlet):
):
serialized_required_state = (
await self.event_serializer.serialize_events(
room_result.required_state,
[FilteredEvent.state(e) for e in room_result.required_state],
time_now,
config=serialize_options,
)
+2 -1
View File
@@ -34,6 +34,7 @@ from pydantic import ConfigDict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.utils import FilteredEvent
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -185,7 +186,7 @@ class SlidingSyncResult:
# Should be empty for invite/knock rooms with `stripped_state`
required_state: list[EventBase]
# Should be empty for invite/knock rooms with `stripped_state`
timeline_events: list[EventBase]
timeline_events: list[FilteredEvent]
bundled_aggregations: dict[str, "BundledAggregations"] | None
# Optional because it's only relevant to invite/knock rooms
stripped_state: list[JsonDict]
+4 -4
View File
@@ -380,7 +380,7 @@ class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
join_updates, _ = sync_join(self, inviting_user_id)
# Assert that the last event in the room was not a member event for the target user.
self.assertEqual(
join_updates[0].timeline.events[-1].content["membership"], "invite"
join_updates[0].timeline.events[-1].event.content["membership"], "invite"
)
@override_config(
@@ -423,7 +423,7 @@ class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
join_updates, b = sync_join(self, inviting_user_id)
# Assert that the last event in the room was not a member event for the target user.
self.assertEqual(
join_updates[0].timeline.events[-1].content["membership"], "invite"
join_updates[0].timeline.events[-1].event.content["membership"], "invite"
)
@override_config(
@@ -466,7 +466,7 @@ class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
join_updates, b = sync_join(self, inviting_user_id)
# Assert that the last event in the room was not a member event for the target user.
self.assertEqual(
join_updates[0].timeline.events[-1].content["membership"], "invite"
join_updates[0].timeline.events[-1].event.content["membership"], "invite"
)
@override_config(
@@ -509,7 +509,7 @@ class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
join_updates, b = sync_join(self, inviting_user_id)
# Assert that the last event in the room was not a member event for the target user.
self.assertEqual(
join_updates[0].timeline.events[-1].content["membership"], "invite"
join_updates[0].timeline.events[-1].event.content["membership"], "invite"
)
+2 -1
View File
@@ -28,6 +28,7 @@ from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import (
FilteredEvent,
PowerLevelsContent,
SerializeEventConfig,
_split_field,
@@ -655,7 +656,7 @@ class SerializeEventTestCase(HomeserverTestCase):
) -> JsonDict:
return self.get_success(
self._event_serializer.serialize_event(
ev,
FilteredEvent(event=ev, membership=None),
1479807801915,
config=SerializeEventConfig(
only_event_fields=fields,
+8 -4
View File
@@ -81,7 +81,8 @@ class ExfiltrateData(unittest.HomeserverTestCase):
# Check that the right number of events were written
counter = Counter(
(event.type, getattr(event, "state_key", None)) for event in written_events
(event.event.type, getattr(event.event, "state_key", None))
for event in written_events
)
self.assertEqual(counter[(EventTypes.Message, None)], 2)
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
@@ -119,7 +120,8 @@ class ExfiltrateData(unittest.HomeserverTestCase):
# Check that the right number of events were written
counter = Counter(
(event.type, getattr(event, "state_key", None)) for event in written_events
(event.event.type, getattr(event.event, "state_key", None))
for event in written_events
)
self.assertEqual(counter[(EventTypes.Message, None)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
@@ -151,7 +153,8 @@ class ExfiltrateData(unittest.HomeserverTestCase):
# Check that the right number of events were written
counter = Counter(
(event.type, getattr(event, "state_key", None)) for event in written_events
(event.event.type, getattr(event.event, "state_key", None))
for event in written_events
)
self.assertEqual(counter[(EventTypes.Message, None)], 2)
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
@@ -192,7 +195,8 @@ class ExfiltrateData(unittest.HomeserverTestCase):
# Check that the right number of events were written
counter = Counter(
(event.type, getattr(event, "state_key", None)) for event in written_events
(event.event.type, getattr(event.event, "state_key", None))
for event in written_events
)
self.assertEqual(counter[(EventTypes.Message, None)], 2)
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+14 -14
View File
@@ -307,7 +307,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(len(alice_sync_result.joined), 1)
self.assertEqual(alice_sync_result.joined[0].room_id, room_id)
last_room_creation_event_id = (
alice_sync_result.joined[0].timeline.events[-1].event_id
alice_sync_result.joined[0].timeline.events[-1].event.event_id
)
# Eve, a ne'er-do-well, registers.
@@ -402,7 +402,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
initial_sync_result.joined[0].timeline.events[-1].event.event_id
)
# Send a state event, and a regular event, both using the same prev ID
@@ -437,7 +437,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(room_sync.room_id, room_id)
self.assertTrue(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e.event.event_id for e in room_sync.timeline.events],
[e3_event, e4_event],
)
self.assertEqual(
@@ -476,7 +476,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
initial_sync_result.joined[0].timeline.events[-1].event.event_id
)
# Send a state event, and a regular event, both using the same prev ID
@@ -521,7 +521,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(room_sync.room_id, room_id)
self.assertTrue(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e.event.event_id for e in room_sync.timeline.events],
[e3_event],
)
self.assertEqual(
@@ -563,7 +563,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
initial_sync_result.joined[0].timeline.events[-1].event.event_id
)
# Send a state event, and a regular event, both using the same prev ID
@@ -593,7 +593,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(room_sync.room_id, room_id)
self.assertTrue(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e.event.event_id for e in room_sync.timeline.events],
[e3_event],
)
@@ -632,7 +632,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(room_sync.room_id, room_id)
self.assertFalse(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e.event.event_id for e in room_sync.timeline.events],
[e4_event],
)
@@ -701,7 +701,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
initial_sync_result.joined[0].timeline.events[-1].event.event_id
)
# Send a state event, and a regular event, both using the same prev ID
@@ -728,7 +728,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
room_sync = initial_sync_result.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e.event.event_id for e in room_sync.timeline.events],
[e3_event],
)
if self.use_state_after:
@@ -757,7 +757,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(room_sync.room_id, room_id)
self.assertFalse(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e.event.event_id for e in room_sync.timeline.events],
[e4_event, e5_event],
)
@@ -855,7 +855,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# The last three events in the timeline should be those leading up to the
# leave
self.assertEqual(
[e.event_id for e in sync_room_result.timeline.events[-3:]],
[e.event.event_id for e in sync_room_result.timeline.events[-3:]],
[before_message_event, before_state_event, leave_event],
)
@@ -947,7 +947,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
event_ids = []
for event in sync_result.joined[0].timeline.events:
event_ids.append(event.event_id)
event_ids.append(event.event.event_id)
self.assertNotIn(call_event.event_id, event_ids)
# it will come down in a private room, though
@@ -995,7 +995,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
priv_event_ids = []
for event in private_sync_result.joined[0].timeline.events:
priv_event_ids.append(event.event_id)
priv_event_ids.append(event.event.event_id)
self.assertIn(private_call_event.event_id, priv_event_ids)
+9 -2
View File
@@ -23,6 +23,7 @@ from unittest.mock import Mock
from twisted.internet.testing import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events.utils import FilteredEvent
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -173,7 +174,9 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# We should only get one event back.
self.assertEqual(len(filtered_events), 1, filtered_events)
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
self.assertEqual(
filtered_events[0].event.event_id, valid_event_id, filtered_events
)
def _test_retention_event_purged(self, room_id: str, increment: float) -> None:
"""Run the following test scenario to test the message retention policy support:
@@ -253,7 +256,11 @@ class RetentionTestCase(unittest.HomeserverTestCase):
assert event is not None
time_now = self.clock.time_msec()
serialized = self.get_success(self.serializer.serialize_event(event, time_now))
serialized = self.get_success(
self.serializer.serialize_event(
FilteredEvent(event=event, membership=None), time_now
)
)
return serialized
+11 -20
View File
@@ -22,7 +22,7 @@ from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes, EventUnsignedContentFields
from synapse.api.constants import AccountDataTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
@@ -341,7 +341,7 @@ class FilterEventsForServerAdminsTestCase(HomeserverTestCase):
)
self.assertEqual(
[e.event_id for e in [self.regular_event]],
[e.event_id for e in filtered_events],
[e.event.event_id for e in filtered_events],
)
def test_see_soft_failed_events(self) -> None:
@@ -380,7 +380,7 @@ class FilterEventsForServerAdminsTestCase(HomeserverTestCase):
)
self.assertEqual(
[e.event_id for e in [self.regular_event, self.soft_failed_event]],
[e.event_id for e in filtered_events],
[e.event.event_id for e in filtered_events],
)
def test_see_policy_server_spammy_events(self) -> None:
@@ -427,7 +427,7 @@ class FilterEventsForServerAdminsTestCase(HomeserverTestCase):
)
self.assertEqual(
[e.event_id for e in [self.regular_event, self.spammy_event]],
[e.event_id for e in filtered_events],
[e.event.event_id for e in filtered_events],
)
def test_see_soft_failed_and_policy_server_spammy_events(self) -> None:
@@ -477,7 +477,7 @@ class FilterEventsForServerAdminsTestCase(HomeserverTestCase):
e.event_id
for e in [self.regular_event, self.soft_failed_event, self.spammy_event]
],
[e.event_id for e in filtered_events],
[e.event.event_id for e in filtered_events],
)
@@ -559,14 +559,11 @@ class FilterEventsForClientTestCase(HomeserverTestCase):
# and messages sent between the two, but not before or after.
self.assertEqual(
[e.event_id for e in [join_event, during_event, leave_event]],
[e.event_id for e in joiner_filtered_events],
[e.event.event_id for e in joiner_filtered_events],
)
self.assertEqual(
["join", "join", "leave"],
[
e.unsigned[EventUnsignedContentFields.MEMBERSHIP]
for e in joiner_filtered_events
],
[e.membership for e in joiner_filtered_events],
)
# The resident user should see all the events.
@@ -581,14 +578,11 @@ class FilterEventsForClientTestCase(HomeserverTestCase):
after_event,
]
],
[e.event_id for e in resident_filtered_events],
[e.event.event_id for e in resident_filtered_events],
)
self.assertEqual(
["join", "join", "join", "join", "join"],
[
e.unsigned[EventUnsignedContentFields.MEMBERSHIP]
for e in resident_filtered_events
],
[e.membership for e in resident_filtered_events],
)
@@ -651,15 +645,12 @@ class FilterEventsOutOfBandEventsForClientTestCase(
)
)
self.assertEqual(
[e.event_id for e in filtered_events],
[e.event.event_id for e in filtered_events],
[e.event_id for e in [invite_event, reject_event]],
)
self.assertEqual(
["invite", "leave"],
[
e.unsigned[EventUnsignedContentFields.MEMBERSHIP]
for e in filtered_events
],
[e.membership for e in filtered_events],
)
# other users should see neither