diff --git a/changelog.d/17545.bugfix b/changelog.d/17545.bugfix new file mode 100644 index 0000000000..31e22d873e --- /dev/null +++ b/changelog.d/17545.bugfix @@ -0,0 +1 @@ +Handle lower-case http headers in `_Mulitpart_Parser_Protocol`. \ No newline at end of file diff --git a/changelog.d/17548.misc b/changelog.d/17548.misc new file mode 100644 index 0000000000..861b241dcd --- /dev/null +++ b/changelog.d/17548.misc @@ -0,0 +1 @@ +Fix performance of device lists in `/key/changes` and sliding sync. diff --git a/changelog.d/17562.misc b/changelog.d/17562.misc new file mode 100644 index 0000000000..a267df8b83 --- /dev/null +++ b/changelog.d/17562.misc @@ -0,0 +1 @@ +Test github token before running release script steps. diff --git a/changelog.d/17566.misc b/changelog.d/17566.misc new file mode 100644 index 0000000000..7210753fa3 --- /dev/null +++ b/changelog.d/17566.misc @@ -0,0 +1 @@ +Speed up responding to media requests. \ No newline at end of file diff --git a/changelog.d/17567.misc b/changelog.d/17567.misc new file mode 100644 index 0000000000..cfa8089a81 --- /dev/null +++ b/changelog.d/17567.misc @@ -0,0 +1 @@ +Speed up responding to media requests. diff --git a/changelog.d/17568.bugfix b/changelog.d/17568.bugfix new file mode 100644 index 0000000000..71a1f12915 --- /dev/null +++ b/changelog.d/17568.bugfix @@ -0,0 +1 @@ +Fix fetching federation signing keys from servers that omit `old_verify_keys`. Contributed by @tulir @ Beeper. diff --git a/changelog.d/17569.misc b/changelog.d/17569.misc new file mode 100644 index 0000000000..cfa8089a81 --- /dev/null +++ b/changelog.d/17569.misc @@ -0,0 +1 @@ +Speed up responding to media requests. diff --git a/changelog.d/17570.bugfix b/changelog.d/17570.bugfix new file mode 100644 index 0000000000..e2964168b1 --- /dev/null +++ b/changelog.d/17570.bugfix @@ -0,0 +1 @@ +Fix bug where we would respond with an error when a remote server asked for media that had a length of 0, using the new multipart federation media endpoint. diff --git a/changelog.d/17571.misc b/changelog.d/17571.misc new file mode 100644 index 0000000000..67182a4fcd --- /dev/null +++ b/changelog.d/17571.misc @@ -0,0 +1 @@ +Add a flag to `/versions`, `org.matrix.simplified_msc3575`, to indicate whether experimental sliding sync support has been enabled. diff --git a/changelog.d/17574.misc b/changelog.d/17574.misc new file mode 100644 index 0000000000..71020abec4 --- /dev/null +++ b/changelog.d/17574.misc @@ -0,0 +1 @@ +Refactor per-connection state in experimental sliding sync handler. diff --git a/changelog.d/17575.misc b/changelog.d/17575.misc new file mode 100644 index 0000000000..1b4a53ee17 --- /dev/null +++ b/changelog.d/17575.misc @@ -0,0 +1 @@ +Correctly track read receipts that should be sent down in experimental sliding sync. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 5e519bb758..1ace804682 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -324,6 +324,11 @@ def tag(gh_token: Optional[str]) -> None: def _tag(gh_token: Optional[str]) -> None: """Tags the release and generates a draft GitHub release""" + if gh_token: + # Test that the GH Token is valid before continuing. + gh = Github(gh_token) + gh.get_user() + # Make sure we're in a git repo. repo = get_repo_and_check_clean_checkout() @@ -418,6 +423,11 @@ def publish(gh_token: str) -> None: def _publish(gh_token: str) -> None: """Publish release on GitHub.""" + if gh_token: + # Test that the GH Token is valid before continuing. + gh = Github(gh_token) + gh.get_user() + # Make sure we're in a git repo. get_repo_and_check_clean_checkout() @@ -460,6 +470,11 @@ def upload(gh_token: Optional[str]) -> None: def _upload(gh_token: Optional[str]) -> None: """Upload release to pypi.""" + if gh_token: + # Test that the GH Token is valid before continuing. + gh = Github(gh_token) + gh.get_user() + current_version = get_package_version() tag_name = f"v{current_version}" @@ -555,6 +570,11 @@ def wait_for_actions(gh_token: Optional[str]) -> None: def _wait_for_actions(gh_token: Optional[str]) -> None: + if gh_token: + # Test that the GH Token is valid before continuing. + gh = Github(gh_token) + gh.get_user() + # Find out the version and tag name. current_version = get_package_version() tag_name = f"v{current_version}" @@ -711,6 +731,11 @@ Ask the designated people to do the blog and tweets.""" @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True) def full(gh_token: str) -> None: + if gh_token: + # Test that the GH Token is valid before continuing. + gh = Github(gh_token) + gh.get_user() + click.echo("1. If this is a security release, read the security wiki page.") click.echo("2. Check for any release blockers before proceeding.") click.echo(" https://github.com/element-hq/synapse/labels/X-Release-Blocker") diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 8c301e077c..643d2d4e66 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -589,7 +589,7 @@ class BaseV2KeyFetcher(KeyFetcher): % (server_name,) ) - for key_id, key_data in response_json["old_verify_keys"].items(): + for key_id, key_data in response_json.get("old_verify_keys", {}).items(): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ce26c91a7b..4f2a9f3a5b 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -267,31 +267,27 @@ class DeviceWorkerHandler: newly_left_rooms.add(change.room_id) # We now work out if any other users have since joined or left the rooms - # the user is currently in. First we filter out rooms that we know - # haven't changed recently. - rooms_changed = self.store.get_rooms_that_changed( - joined_room_ids, from_token.room_key - ) + # the user is currently in. # List of membership changes per room room_to_deltas: Dict[str, List[StateDelta]] = {} # The set of event IDs of membership events (so we can fetch their # associated membership). memberships_to_fetch: Set[str] = set() - for room_id in rooms_changed: - # TODO: Only pull out membership events? - state_changes = await self.store.get_current_state_deltas_for_room( - room_id, from_token=from_token.room_key, to_token=now_token.room_key - ) - for delta in state_changes: - if delta.event_type != EventTypes.Member: - continue - room_to_deltas.setdefault(room_id, []).append(delta) - if delta.event_id: - memberships_to_fetch.add(delta.event_id) - if delta.prev_event_id: - memberships_to_fetch.add(delta.prev_event_id) + # TODO: Only pull out membership events? + state_changes = await self.store.get_current_state_deltas_for_rooms( + joined_room_ids, from_token=from_token.room_key, to_token=now_token.room_key + ) + for delta in state_changes: + if delta.event_type != EventTypes.Member: + continue + + room_to_deltas.setdefault(delta.room_id, []).append(delta) + if delta.event_id: + memberships_to_fetch.add(delta.event_id) + if delta.prev_event_id: + memberships_to_fetch.add(delta.prev_event_id) # Fetch all the memberships for the membership events event_id_to_memberships = await self.store.get_membership_from_event_ids( diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 0859cd1463..0b0a040477 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -19,6 +19,8 @@ # import enum import logging +import typing +from collections import ChainMap from enum import Enum from itertools import chain from typing import ( @@ -27,14 +29,18 @@ from typing import ( Callable, Dict, Final, + Generic, List, Literal, Mapping, + MutableMapping, Optional, Sequence, Set, Tuple, + TypeVar, Union, + cast, ) import attr @@ -51,6 +57,7 @@ from synapse.api.constants import ( from synapse.api.errors import SlidingSyncUnknownPosition from synapse.events import EventBase, StrippedStateEvent from synapse.events.utils import parse_stripped_state_event, strip_event +from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.relations import BundledAggregations from synapse.logging.opentracing import ( SynapseTags, @@ -564,21 +571,21 @@ class SlidingSyncHandler: # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() - if from_token: - # Check that we recognize the connection position, if not tell the - # clients that they need to start again. - # - # If we don't do this and the client asks for the full range of - # rooms, we end up sending down all rooms and their state from - # scratch (which can be very slow). By expiring the connection we - # allow the client a chance to do an initial request with a smaller - # range of rooms to get them some results sooner but will end up - # taking the same amount of time (more with round-trips and - # re-processing) in the end to get everything again. - if not await self.connection_store.is_valid_token( - sync_config, from_token.connection_position - ): - raise SlidingSyncUnknownPosition() + # Get the per-connection state (if any). + # + # Raises an exception if there is a `connection_position` that we don't + # recognize. If we don't do this and the client asks for the full range + # of rooms, we end up sending down all rooms and their state from + # scratch (which can be very slow). By expiring the connection we allow + # the client a chance to do an initial request with a smaller range of + # rooms to get them some results sooner but will end up taking the same + # amount of time (more with round-trips and re-processing) in the end to + # get everything again. + previous_connection_state = ( + await self.connection_store.get_per_connection_state( + sync_config, from_token + ) + ) await self.connection_store.mark_token_seen( sync_config=sync_config, @@ -774,11 +781,7 @@ class SlidingSyncHandler: # we haven't sent the room down, or we have but there are missing # updates). for room_id in relevant_room_map: - status = await self.connection_store.have_sent_room( - sync_config, - from_token.connection_position, - room_id, - ) + status = previous_connection_state.rooms.have_sent_room(room_id) if ( # The room was never sent down before so the client needs to know # about it regardless of any updates. @@ -814,6 +817,7 @@ class SlidingSyncHandler: async def handle_room(room_id: str) -> None: room_sync_result = await self.get_room_sync_data( sync_config=sync_config, + previous_connection_state=previous_connection_state, room_id=room_id, room_sync_config=relevant_rooms_to_send_map[room_id], room_membership_for_user_at_to_token=room_membership_for_user_map[ @@ -831,9 +835,13 @@ class SlidingSyncHandler: with start_active_span("sliding_sync.generate_room_entries"): await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10) + new_connection_state = previous_connection_state.get_mutable() + extensions = await self.get_extensions_response( sync_config=sync_config, actual_lists=lists, + previous_connection_state=previous_connection_state, + new_connection_state=new_connection_state, # We're purposely using `relevant_room_map` instead of # `relevant_rooms_to_send_map` here. This needs to be all room_ids we could # send regardless of whether they have an event update or not. The @@ -875,11 +883,18 @@ class SlidingSyncHandler: ) unsent_room_ids = list(missing_event_map_by_room) - connection_position = await self.connection_store.record_rooms( + new_connection_state.rooms.record_unsent_rooms( + unsent_room_ids, from_token.stream_token.room_key + ) + + new_connection_state.rooms.record_sent_rooms( + relevant_rooms_to_send_map.keys() + ) + + connection_position = await self.connection_store.record_new_state( sync_config=sync_config, from_token=from_token, - sent_room_ids=relevant_rooms_to_send_map.keys(), - unsent_room_ids=unsent_room_ids, + new_connection_state=new_connection_state, ) elif from_token: connection_position = from_token.connection_position @@ -1932,6 +1947,7 @@ class SlidingSyncHandler: async def get_room_sync_data( self, sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", room_id: str, room_sync_config: RoomSyncConfig, room_membership_for_user_at_to_token: _RoomMembershipForUser, @@ -1979,11 +1995,7 @@ class SlidingSyncHandler: from_bound = None initial = True if from_token and not room_membership_for_user_at_to_token.newly_joined: - room_status = await self.connection_store.have_sent_room( - sync_config=sync_config, - connection_token=from_token.connection_position, - room_id=room_id, - ) + room_status = previous_connection_state.rooms.have_sent_room(room_id) if room_status.status == HaveSentRoomFlag.LIVE: from_bound = from_token.stream_token.room_key initial = False @@ -2464,6 +2476,8 @@ class SlidingSyncHandler: async def get_extensions_response( self, sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", + new_connection_state: "MutablePerConnectionState", actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], actual_room_ids: Set[str], actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult], @@ -2474,6 +2488,9 @@ class SlidingSyncHandler: Args: sync_config: Sync configuration + new_connection_state: Snapshot of the current per-connection state + new_per_connection_state: A mutable copy of the per-connection + state, used to record updates to the state during this request. actual_lists: Sliding window API. A map of list key to list results in the Sliding Sync response. actual_room_ids: The actual room IDs in the the Sliding Sync response. @@ -2518,6 +2535,8 @@ class SlidingSyncHandler: if sync_config.extensions.receipts is not None: receipts_response = await self.get_receipts_extension_response( sync_config=sync_config, + previous_connection_state=previous_connection_state, + new_connection_state=new_connection_state, actual_lists=actual_lists, actual_room_ids=actual_room_ids, actual_room_response_map=actual_room_response_map, @@ -2837,6 +2856,8 @@ class SlidingSyncHandler: async def get_receipts_extension_response( self, sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", + new_connection_state: "MutablePerConnectionState", actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], actual_room_ids: Set[str], actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult], @@ -2848,6 +2869,9 @@ class SlidingSyncHandler: Args: sync_config: Sync configuration + previous_connection_state: The current per-connection state + new_connection_state: A mutable copy of the per-connection + state, used to record updates to the state. actual_lists: Sliding window API. A map of list key to list results in the Sliding Sync response. actual_room_ids: The actual room IDs in the the Sliding Sync response. @@ -2870,50 +2894,145 @@ class SlidingSyncHandler: room_id_to_receipt_map: Dict[str, JsonMapping] = {} if len(relevant_room_ids) > 0: - # TODO: Take connection tracking into account so that when a room comes back - # into range we can send the receipts that were missed. - receipt_source = self.event_sources.sources.receipt - receipts, _ = await receipt_source.get_new_events( - user=sync_config.user, - from_key=( - from_token.stream_token.receipt_key - if from_token - else MultiWriterStreamToken(stream=0) - ), - to_key=to_token.receipt_key, - # This is a dummy value and isn't used in the function - limit=0, - room_ids=relevant_room_ids, - is_guest=False, + # We need to handle the different cases depending on if we have sent + # down receipts previously or not, so we split the relevant rooms + # up into different collections based on status. + live_rooms = set() + previously_rooms: Dict[str, MultiWriterStreamToken] = {} + initial_rooms = set() + + for room_id in relevant_room_ids: + if not from_token: + initial_rooms.add(room_id) + continue + + # If we're sending down the room from scratch again for some reason, we + # should always resend the receipts as well (regardless of if + # we've sent them down before). This is to mimic the behaviour + # of what happens on initial sync, where you get a chunk of + # timeline with all of the corresponding receipts for the events in the timeline. + room_result = actual_room_response_map.get(room_id) + if room_result is not None and room_result.initial: + initial_rooms.add(room_id) + continue + + room_status = previous_connection_state.receipts.have_sent_room(room_id) + if room_status.status == HaveSentRoomFlag.LIVE: + live_rooms.add(room_id) + elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: + assert room_status.last_token is not None + previously_rooms[room_id] = room_status.last_token + elif room_status.status == HaveSentRoomFlag.NEVER: + initial_rooms.add(room_id) + else: + assert_never(room_status.status) + + # The set of receipts that we fetched. Private receipts need to be + # filtered out before returning. + fetched_receipts = [] + + # For live rooms we just fetch all receipts in those rooms since the + # `since` token. + if live_rooms: + assert from_token is not None + receipts = await self.store.get_linearized_receipts_for_rooms( + room_ids=live_rooms, + from_key=from_token.stream_token.receipt_key, + to_key=to_token.receipt_key, + ) + fetched_receipts.extend(receipts) + + # For rooms we've previously sent down, but aren't up to date, we + # need to use the from token from the room status. + if previously_rooms: + for room_id, receipt_token in previously_rooms.items(): + # TODO: Limit the number of receipts we're about to send down + # for the room, if its too many we should TODO + previously_receipts = ( + await self.store.get_linearized_receipts_for_room( + room_id=room_id, + from_key=receipt_token, + to_key=to_token.receipt_key, + ) + ) + fetched_receipts.extend(previously_receipts) + + # For rooms we haven't previously sent down, we could send all receipts + # from that room but we only want to include receipts for events + # in the timeline to avoid bloating and blowing up the sync response + # as the number of users in the room increases. (this behavior is part of the spec) + for room_id in initial_rooms: + room_result = actual_room_response_map.get(room_id) + if room_result is None: + continue + + relevant_event_ids = [ + event.event_id for event in room_result.timeline_events + ] + + # TODO: In the future, it would be good to fetch less receipts + # out of the database in the first place but we would need to + # add a new `event_id` index to `receipts_linearized`. + initial_receipts = await self.store.get_linearized_receipts_for_room( + room_id=room_id, + to_key=to_token.receipt_key, + ) + + for receipt in initial_receipts: + content = { + event_id: content_value + for event_id, content_value in receipt["content"].items() + if event_id in relevant_event_ids + } + if content: + fetched_receipts.append( + { + "type": receipt["type"], + "room_id": receipt["room_id"], + "content": content, + } + ) + + fetched_receipts = ReceiptEventSource.filter_out_private_receipts( + fetched_receipts, sync_config.user.to_string() ) - for receipt in receipts: + for receipt in fetched_receipts: # These fields should exist for every receipt room_id = receipt["room_id"] type = receipt["type"] content = receipt["content"] - # For `inital: True` rooms, we only want to include receipts for events - # in the timeline. - room_result = actual_room_response_map.get(room_id) - if room_result is not None: - if room_result.initial: - # TODO: In the future, it would be good to fetch less receipts - # out of the database in the first place but we would need to - # add a new `event_id` index to `receipts_linearized`. - relevant_event_ids = [ - event.event_id for event in room_result.timeline_events - ] - - assert isinstance(content, dict) - content = { - event_id: content_value - for event_id, content_value in content.items() - if event_id in relevant_event_ids - } - room_id_to_receipt_map[room_id] = {"type": type, "content": content} + # Now we update the per-connection state to track which receipts we have + # and haven't sent down. + new_connection_state.receipts.record_sent_rooms(relevant_room_ids) + + if from_token: + # Now find the set of rooms that may have receipts that we're not sending + # down. We only need to check rooms that we have previously returned + # receipts for (in `previous_connection_state`) because we only care about + # updating `LIVE` rooms to `PREVIOUSLY`. The `PREVIOUSLY` rooms will just + # stay pointing at their previous position so we don't need to waste time + # checking those and since we default to `NEVER`, rooms that were `NEVER` + # sent before don't need to be recorded as we'll handle them correctly when + # they come into range for the first time. + rooms_no_receipts = [ + room_id + for room_id, room_status in previous_connection_state.receipts._statuses.items() + if room_status.status == HaveSentRoomFlag.LIVE + and room_id not in relevant_room_ids + ] + changed_rooms = await self.store.get_rooms_with_receipts_between( + rooms_no_receipts, + from_key=from_token.stream_token.receipt_key, + to_key=to_token.receipt_key, + ) + new_connection_state.receipts.record_unsent_rooms( + changed_rooms, from_token.stream_token.receipt_key + ) + return SlidingSyncResult.Extensions.ReceiptsExtension( room_id_to_receipt_map=room_id_to_receipt_map, ) @@ -3004,9 +3123,15 @@ class HaveSentRoomFlag(Enum): LIVE = 3 +T = TypeVar("T") + + @attr.s(auto_attribs=True, slots=True, frozen=True) -class HaveSentRoom: - """Whether we have sent the room down a sliding sync connection. +class HaveSentRoom(Generic[T]): + """Whether we have sent the room data down a sliding sync connection. + + We are generic over the type of token used, e.g. `RoomStreamToken` or + `MultiWriterStreamToken`. Attributes: status: Flag of if we have or haven't sent down the room @@ -3017,16 +3142,143 @@ class HaveSentRoom: """ status: HaveSentRoomFlag - last_token: Optional[RoomStreamToken] + last_token: Optional[T] @staticmethod - def previously(last_token: RoomStreamToken) -> "HaveSentRoom": + def live() -> "HaveSentRoom[T]": + return HaveSentRoom(HaveSentRoomFlag.LIVE, None) + + @staticmethod + def previously(last_token: T) -> "HaveSentRoom[T]": """Constructor for `PREVIOUSLY` flag.""" return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token) + @staticmethod + def never() -> "HaveSentRoom[T]": + return HaveSentRoom(HaveSentRoomFlag.NEVER, None) -HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None) -HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None) + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class RoomStatusMap(Generic[T]): + """For a given stream, e.g. events, records what we have or have not sent + down for that stream in a given room.""" + + # `room_id` -> `HaveSentRoom` + _statuses: Mapping[str, HaveSentRoom[T]] = attr.Factory(dict) + + def have_sent_room(self, room_id: str) -> HaveSentRoom[T]: + """Return whether we have previously sent the room down""" + return self._statuses.get(room_id, HaveSentRoom.never()) + + def get_mutable(self) -> "MutableRoomStatusMap[T]": + """Get a mutable copy of this state.""" + return MutableRoomStatusMap( + statuses=self._statuses, + ) + + def copy(self) -> "RoomStatusMap[T]": + """Make a copy of the class. Useful for converting from a mutable to + immutable version.""" + + return RoomStatusMap(statuses=dict(self._statuses)) + + +class MutableRoomStatusMap(RoomStatusMap[T]): + """A mutable version of `RoomStatusMap`""" + + # We use a ChainMap here so that we can easily track what has been updated + # and what hasn't. Note that when we persist the per connection state this + # will get flattened to a normal dict (via calling `.copy()`) + _statuses: typing.ChainMap[str, HaveSentRoom[T]] + + def __init__( + self, + statuses: Mapping[str, HaveSentRoom[T]], + ) -> None: + # ChainMap requires a mutable mapping, but we're not actually going to + # mutate it. + statuses = cast(MutableMapping, statuses) + + super().__init__( + statuses=ChainMap({}, statuses), + ) + + def get_updates(self) -> Mapping[str, HaveSentRoom[T]]: + """Return only the changes that were made""" + return self._statuses.maps[0] + + def record_sent_rooms(self, room_ids: StrCollection) -> None: + """Record that we have sent these rooms in the response""" + for room_id in room_ids: + current_status = self._statuses.get(room_id, HaveSentRoom.never()) + if current_status.status == HaveSentRoomFlag.LIVE: + continue + + self._statuses[room_id] = HaveSentRoom.live() + + def record_unsent_rooms(self, room_ids: StrCollection, from_token: T) -> None: + """Record that we have not sent these rooms in the response, but there + have been updates. + """ + # Whether we add/update the entries for unsent rooms depends on the + # existing entry: + # - LIVE: We have previously sent down everything up to + # `last_room_token, so we update the entry to be `PREVIOUSLY` with + # `last_room_token`. + # - PREVIOUSLY: We have previously sent down everything up to *a* + # given token, so we don't need to update the entry. + # - NEVER: We have never previously sent down the room, and we haven't + # sent anything down this time either so we leave it as NEVER. + + for room_id in room_ids: + current_status = self._statuses.get(room_id, HaveSentRoom.never()) + if current_status.status != HaveSentRoomFlag.LIVE: + continue + + self._statuses[room_id] = HaveSentRoom.previously(from_token) + + +@attr.s(auto_attribs=True) +class PerConnectionState: + """The per-connection state. A snapshot of what we've sent down the + connection before. + + Currently, we track whether we've sent down various aspects of a given room + before. + + We use the `rooms` field to store the position in the events stream for each + room that we've previously sent to the client before. On the next request + that includes the room, we can then send only what's changed since that + recorded position. + + Same goes for the `receipts` field so we only need to send the new receipts + since the last time you made a sync request. + + Attributes: + rooms: The status of each room for the events stream. + receipts: The status of each room for the receipts stream. + """ + + rooms: RoomStatusMap[RoomStreamToken] = attr.Factory(RoomStatusMap) + receipts: RoomStatusMap[MultiWriterStreamToken] = attr.Factory(RoomStatusMap) + + def get_mutable(self) -> "MutablePerConnectionState": + """Get a mutable copy of this state.""" + return MutablePerConnectionState( + rooms=self.rooms.get_mutable(), + receipts=self.receipts.get_mutable(), + ) + + +@attr.s(auto_attribs=True) +class MutablePerConnectionState(PerConnectionState): + """A mutable version of `PerConnectionState`""" + + rooms: MutableRoomStatusMap[RoomStreamToken] + receipts: MutableRoomStatusMap[MultiWriterStreamToken] + + def has_updates(self) -> bool: + return bool(self.rooms.get_updates()) or bool(self.receipts.get_updates()) @attr.s(auto_attribs=True) @@ -3058,9 +3310,9 @@ class SlidingSyncConnectionStore: to mapping of room ID to `HaveSentRoom`. """ - # `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom` - _connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = ( - attr.Factory(dict) + # `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState` + _connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory( + dict ) async def is_valid_token( @@ -3073,48 +3325,52 @@ class SlidingSyncConnectionStore: conn_key = self._get_connection_key(sync_config) return connection_token in self._connections.get(conn_key, {}) - async def have_sent_room( - self, sync_config: SlidingSyncConfig, connection_token: int, room_id: str - ) -> HaveSentRoom: - """For the given user_id/conn_id/token, return whether we have - previously sent the room down - """ - - conn_key = self._get_connection_key(sync_config) - sync_statuses = self._connections.setdefault(conn_key, {}) - room_status = sync_statuses.get(connection_token, {}).get( - room_id, HAVE_SENT_ROOM_NEVER - ) - - return room_status - - @trace - async def record_rooms( + async def get_per_connection_state( self, sync_config: SlidingSyncConfig, from_token: Optional[SlidingSyncStreamToken], - *, - sent_room_ids: StrCollection, - unsent_room_ids: StrCollection, - ) -> int: - """Record which rooms we have/haven't sent down in a new response + ) -> PerConnectionState: + """Fetch the per-connection state for the token. - Attributes: - sync_config - from_token: The since token from the request, if any - sent_room_ids: The set of room IDs that we have sent down as - part of this request (only needs to be ones we didn't - previously sent down). - unsent_room_ids: The set of room IDs that have had updates - since the `from_token`, but which were not included in - this request + Raises: + SlidingSyncUnknownPosition if the connection_token is unknown + """ + if from_token is None: + return PerConnectionState() + + connection_position = from_token.connection_position + if connection_position == 0: + # Initial sync (request without a `from_token`) starts at `0` so + # there is no existing per-connection state + return PerConnectionState() + + conn_key = self._get_connection_key(sync_config) + sync_statuses = self._connections.get(conn_key, {}) + connection_state = sync_statuses.get(connection_position) + + if connection_state is None: + raise SlidingSyncUnknownPosition() + + return connection_state + + @trace + async def record_new_state( + self, + sync_config: SlidingSyncConfig, + from_token: Optional[SlidingSyncStreamToken], + new_connection_state: MutablePerConnectionState, + ) -> int: + """Record updated per-connection state, returning the connection + position associated with the new state. + + If there are no changes to the state this may return the same token as + the existing per-connection state. """ prev_connection_token = 0 if from_token is not None: prev_connection_token = from_token.connection_position - # If there are no changes then this is a noop. - if not sent_room_ids and not unsent_room_ids: + if not new_connection_state.has_updates(): return prev_connection_token conn_key = self._get_connection_key(sync_config) @@ -3125,42 +3381,12 @@ class SlidingSyncConnectionStore: new_store_token = prev_connection_token + 1 sync_statuses.pop(new_store_token, None) - # Copy over and update the room mappings. - new_room_statuses = dict(sync_statuses.get(prev_connection_token, {})) - - # Whether we have updated the `new_room_statuses`, if we don't by the - # end we can treat this as a noop. - have_updated = False - for room_id in sent_room_ids: - new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE - have_updated = True - - # Whether we add/update the entries for unsent rooms depends on the - # existing entry: - # - LIVE: We have previously sent down everything up to - # `last_room_token, so we update the entry to be `PREVIOUSLY` with - # `last_room_token`. - # - PREVIOUSLY: We have previously sent down everything up to *a* - # given token, so we don't need to update the entry. - # - NEVER: We have never previously sent down the room, and we haven't - # sent anything down this time either so we leave it as NEVER. - - # Work out the new state for unsent rooms that were `LIVE`. - if from_token: - new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key) - else: - new_unsent_state = HAVE_SENT_ROOM_NEVER - - for room_id in unsent_room_ids: - prev_state = new_room_statuses.get(room_id) - if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE: - new_room_statuses[room_id] = new_unsent_state - have_updated = True - - if not have_updated: - return prev_connection_token - - sync_statuses[new_store_token] = new_room_statuses + # We copy the `MutablePerConnectionState` so that the inner `ChainMap`s + # don't grow forever. + sync_statuses[new_store_token] = PerConnectionState( + rooms=new_connection_state.rooms.copy(), + receipts=new_connection_state.receipts.copy(), + ) return new_store_token diff --git a/synapse/http/client.py b/synapse/http/client.py index daa5cc899b..cb4f72d771 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -1057,11 +1057,11 @@ class _MultipartParserProtocol(protocol.Protocol): if not self.parser: def on_header_field(data: bytes, start: int, end: int) -> None: - if data[start:end] == b"Location": + if data[start:end].lower() == b"location": self.has_redirect = True - if data[start:end] == b"Content-Disposition": + if data[start:end].lower() == b"content-disposition": self.in_disposition = True - if data[start:end] == b"Content-Type": + if data[start:end].lower() == b"content-type": self.in_content_type = True def on_header_value(data: bytes, start: int, end: int) -> None: diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 89dea39163..9341d4859e 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -22,7 +22,6 @@ import logging import os -import threading import urllib from abc import ABC, abstractmethod from types import TracebackType @@ -56,6 +55,7 @@ from synapse.logging.context import ( run_in_background, ) from synapse.util import Clock +from synapse.util.async_helpers import DeferredEvent from synapse.util.stringutils import is_ascii if TYPE_CHECKING: @@ -620,10 +620,13 @@ class ThreadedFileSender: A producer that sends the contents of a file to a consumer, reading from the file on a thread. - This works by spawning a loop in a threadpool that repeatedly reads from the - file and sends it to the consumer. The main thread communicates with the - loop via two `threading.Event`, which controls when to start/pause reading - and when to terminate. + This works by having a loop in a threadpool repeatedly reading from the + file, until the consumer pauses the producer. There is then a loop in the + main thread that waits until the consumer resumes the producer and then + starts reading in the threadpool again. + + This is done to ensure that we're never waiting in the threadpool, as + otherwise its easy to starve it of threads. """ # How much data to read in one go. @@ -643,12 +646,11 @@ class ThreadedFileSender: # Signals if the thread should keep reading/sending data. Set means # continue, clear means pause. - self.wakeup_event = threading.Event() + self.wakeup_event = DeferredEvent(self.reactor) # Signals if the thread should terminate, e.g. because the consumer has - # gone away. Both this and `wakeup_event` should be set to terminate the - # loop (otherwise the thread will block on `wakeup_event`). - self.stop_event = threading.Event() + # gone away. + self.stop_writing = False def beginFileTransfer( self, file: BinaryIO, consumer: interfaces.IConsumer @@ -663,12 +665,7 @@ class ThreadedFileSender: # We set the wakeup signal as we should start producing immediately. self.wakeup_event.set() - run_in_background( - defer_to_threadpool, - self.reactor, - self.thread_pool, - self._on_thread_read_loop, - ) + run_in_background(self.start_read_loop) return make_deferred_yieldable(self.deferred) @@ -684,44 +681,57 @@ class ThreadedFileSender: """interfaces.IPushProducer""" # Unregister the consumer so we don't try and interact with it again. + if self.consumer: + self.consumer.unregisterProducer() + self.consumer = None - # Terminate the thread loop. + # Terminate the loop. + self.stop_writing = True self.wakeup_event.set() - self.stop_event.set() if not self.deferred.called: self.deferred.errback(Exception("Consumer asked us to stop producing")) - def _on_thread_read_loop(self) -> None: - """This is the loop that happens on a thread.""" - + async def start_read_loop(self) -> None: + """This is the loop that drives reading/writing""" try: - while not self.stop_event.is_set(): - # We wait for the producer to signal that the consumer wants - # more data (or we should abort) + while not self.stop_writing: + # Start the loop in the threadpool to read data. + more_data = await defer_to_threadpool( + self.reactor, self.thread_pool, self._on_thread_read_loop + ) + if not more_data: + # Reached EOF, we can just return. + return + if not self.wakeup_event.is_set(): - ret = self.wakeup_event.wait(self.TIMEOUT_SECONDS) + ret = await self.wakeup_event.wait(self.TIMEOUT_SECONDS) if not ret: raise Exception("Timed out waiting to resume") - - # Check if we were woken up so that we abort the download - if self.stop_event.is_set(): - return - - # The file should always have been set before we get here. - assert self.file is not None - - chunk = self.file.read(self.CHUNK_SIZE) - if not chunk: - return - - self.reactor.callFromThread(self._write, chunk) - except Exception: - self.reactor.callFromThread(self._error, Failure()) + self._error(Failure()) finally: - self.reactor.callFromThread(self._finish) + self._finish() + + def _on_thread_read_loop(self) -> bool: + """This is the loop that happens on a thread. + + Returns: + Whether there is more data to send. + """ + + while not self.stop_writing and self.wakeup_event.is_set(): + # The file should always have been set before we get here. + assert self.file is not None + + chunk = self.file.read(self.CHUNK_SIZE) + if not chunk: + return False + + self.reactor.callFromThread(self._write, chunk) + + return True def _write(self, chunk: bytes) -> None: """Called from the thread to write a chunk of data""" @@ -729,7 +739,7 @@ class ThreadedFileSender: self.consumer.write(chunk) def _error(self, failure: Failure) -> None: - """Called from the thread when there was a fatal error""" + """Called when there was a fatal error""" if self.consumer: self.consumer.unregisterProducer() self.consumer = None @@ -738,7 +748,7 @@ class ThreadedFileSender: self.deferred.errback(failure) def _finish(self) -> None: - """Called from the thread when it finishes (either on success or + """Called when we have finished writing (either on success or failure).""" if self.file: self.file.close() diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index cf4208eb71..c25d1a9ba3 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -544,7 +544,7 @@ class MultipartFileConsumer: Calculate the content length of the multipart response in bytes. """ - if not self.length: + if self.length is None: return None # calculate length of json field and content-type, disposition headers json_field = json.dumps(self.json_field) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 75df684416..874869dc2d 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -64,6 +64,7 @@ class VersionsRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: msc3881_enabled = self.config.experimental.msc3881_enabled + msc3575_enabled = self.config.experimental.msc3575_enabled if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req( @@ -77,6 +78,9 @@ class VersionsRestServlet(RestServlet): msc3881_enabled = await self.store.is_feature_enabled( user_id, ExperimentalFeature.MSC3881 ) + msc3575_enabled = await self.store.is_feature_enabled( + user_id, ExperimentalFeature.MSC3575 + ) return ( 200, @@ -169,6 +173,8 @@ class VersionsRestServlet(RestServlet): ), # MSC4151: Report room API (Client-Server API) "org.matrix.msc4151": self.config.experimental.msc4151_enabled, + # Simplified sliding sync + "org.matrix.simplified_msc3575": msc3575_enabled, }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 8b07bb39a0..d6c9cbdac0 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -124,6 +124,7 @@ from synapse.http.client import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.media.media_repository import MediaRepository +from synapse.metrics import register_threadpool from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi from synapse.module_api.callbacks import ModuleApiCallbacks @@ -959,4 +960,7 @@ class HomeServer(metaclass=abc.ABCMeta): "during", "shutdown", media_threadpool.stop ) + # Register the threadpool with our metrics. + register_threadpool("media", media_threadpool) + return media_threadpool diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 3bde0ae0d4..e266cc2a20 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -51,10 +51,12 @@ from synapse.types import ( JsonMapping, MultiWriterStreamToken, PersistedPosition, + StrCollection, ) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -550,6 +552,46 @@ class ReceiptsWorkerStore(SQLBaseStore): return results + async def get_rooms_with_receipts_between( + self, + room_ids: StrCollection, + from_key: MultiWriterStreamToken, + to_key: MultiWriterStreamToken, + ) -> StrCollection: + """Given a set of room_ids, find out which ones (may) have receipts + between the two tokens (> `from_token` and <= `to_token`).""" + + room_ids = self._receipts_stream_cache.get_entities_changed( + room_ids, from_key.stream + ) + if not room_ids: + return [] + + def f(txn: LoggingTransaction, room_ids: StrCollection) -> StrCollection: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + sql = f""" + SELECT DISTINCT room_id FROM receipts_linearized + WHERE {clause} AND ? < stream_id AND stream_id <= ? + """ + args.append(from_key.stream) + args.append(to_key.get_max_stream_pos()) + + txn.execute(sql, args) + + return [room_id for room_id, in txn] + + results: List[str] = [] + for batch in batch_iter(room_ids, 1000): + batch_result = await self.db_pool.runInteraction( + "get_rooms_with_receipts_between", f, batch + ) + results.extend(batch_result) + + return results + async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 7d491d1728..eaa13da368 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -26,10 +26,11 @@ import attr from synapse.logging.opentracing import trace from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import _filter_results_by_stream -from synapse.types import RoomStreamToken +from synapse.types import RoomStreamToken, StrCollection from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) @@ -200,3 +201,62 @@ class StateDeltasStore(SQLBaseStore): return await self.db_pool.runInteraction( "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn ) + + @trace + async def get_current_state_deltas_for_rooms( + self, + room_ids: StrCollection, + from_token: RoomStreamToken, + to_token: RoomStreamToken, + ) -> List[StateDelta]: + """Get the state deltas between two tokens for the set of rooms.""" + + room_ids = self._curr_state_delta_stream_cache.get_entities_changed( + room_ids, from_token.stream + ) + if not room_ids: + return [] + + def get_current_state_deltas_for_rooms_txn( + txn: LoggingTransaction, + room_ids: StrCollection, + ) -> List[StateDelta]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE {clause} AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + args.append(from_token.stream) + args.append(to_token.get_max_stream_pos()) + + txn.execute(sql, args) + + return [ + StateDelta( + stream_id=row[1], + room_id=row[2], + event_type=row[3], + state_key=row[4], + event_id=row[5], + prev_event_id=row[6], + ) + for row in txn + if _filter_results_by_stream(from_token, to_token, row[0], row[1]) + ] + + results = [] + for batch in batch_iter(room_ids, 1000): + deltas = await self.db_pool.runInteraction( + "get_current_state_deltas_for_rooms", + get_current_state_deltas_for_rooms_txn, + batch, + ) + + results.extend(deltas) + + return results diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 70139beef2..8618bb0651 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -885,3 +885,46 @@ class AwakenableSleeper: # Cancel the sleep if we were woken up if call.active(): call.cancel() + + +class DeferredEvent: + """Like threading.Event but for async code""" + + def __init__(self, reactor: IReactorTime) -> None: + self._reactor = reactor + self._deferred: "defer.Deferred[None]" = defer.Deferred() + + def set(self) -> None: + if not self._deferred.called: + self._deferred.callback(None) + + def clear(self) -> None: + if self._deferred.called: + self._deferred = defer.Deferred() + + def is_set(self) -> bool: + return self._deferred.called + + async def wait(self, timeout_seconds: float) -> bool: + if self.is_set(): + return True + + # Create a deferred that gets called in N seconds + sleep_deferred: "defer.Deferred[None]" = defer.Deferred() + call = self._reactor.callLater(timeout_seconds, sleep_deferred.callback, None) + + try: + await make_deferred_yieldable( + defer.DeferredList( + [sleep_deferred, self._deferred], + fireOnOneCallback=True, + fireOnOneErrback=True, + consumeErrors=True, + ) + ) + finally: + # Cancel the sleep if we were woken up + if call.active(): + call.cancel() + + return self.is_set() diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 721917f957..f2abec190b 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -49,8 +49,11 @@ from tests.unittest import TestCase class ReadMultipartResponseTests(TestCase): - data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_" - data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" + multipart_response_data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_" + multipart_response_data2 = ( + b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" + ) + multipart_response_data_cased = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\ncOntEnt-type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-tyPe: text/plain\r\nconTent-dispOsition: inline; filename=test_upload\r\n\r\nfile_" redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" @@ -103,8 +106,31 @@ class ReadMultipartResponseTests(TestCase): result, deferred, protocol = self._build_multipart_response(249, 250) # Start sending data. - protocol.dataReceived(self.data1) - protocol.dataReceived(self.data2) + protocol.dataReceived(self.multipart_response_data1) + protocol.dataReceived(self.multipart_response_data2) + # Close the connection. + protocol.connectionLost(Failure(ResponseDone())) + + multipart_response: MultipartResponse = deferred.result # type: ignore[assignment] + + self.assertEqual(multipart_response.json, b"{}") + self.assertEqual(result.getvalue(), b"file_to_stream") + self.assertEqual(multipart_response.length, len(b"file_to_stream")) + self.assertEqual(multipart_response.content_type, b"text/plain") + self.assertEqual( + multipart_response.disposition, b"inline; filename=test_upload" + ) + + def test_parse_file_lowercase_headers(self) -> None: + """ + Check that a multipart response containing a file is properly parsed + into the json/file parts, and the json and file are properly captured if the http headers are lowercased + """ + result, deferred, protocol = self._build_multipart_response(249, 250) + + # Start sending data. + protocol.dataReceived(self.multipart_response_data_cased) + protocol.dataReceived(self.multipart_response_data2) # Close the connection. protocol.connectionLost(Failure(ResponseDone())) @@ -143,7 +169,7 @@ class ReadMultipartResponseTests(TestCase): result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180) # Start sending data. - protocol.dataReceived(self.data1) + protocol.dataReceived(self.multipart_response_data1) self.assertEqual(result.getvalue(), b"file_") self._assert_error(deferred, protocol) @@ -154,11 +180,11 @@ class ReadMultipartResponseTests(TestCase): result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180) # Start sending data. - protocol.dataReceived(self.data1) + protocol.dataReceived(self.multipart_response_data1) self._assert_error(deferred, protocol) # More data might have come in. - protocol.dataReceived(self.data2) + protocol.dataReceived(self.multipart_response_data2) self.assertEqual(result.getvalue(), b"file_") self._assert_error(deferred, protocol) @@ -172,7 +198,7 @@ class ReadMultipartResponseTests(TestCase): self.assertFalse(deferred.called) # Start sending data. - protocol.dataReceived(self.data1) + protocol.dataReceived(self.multipart_response_data1) self._assert_error(deferred, protocol) self._cleanup_error(deferred) diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py index 65fbac260e..39c51b367c 100644 --- a/tests/rest/client/sliding_sync/test_extension_receipts.py +++ b/tests/rest/client/sliding_sync/test_extension_receipts.py @@ -677,3 +677,108 @@ class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase): set(), exact=True, ) + + def test_receipts_incremental_sync_out_of_range(self) -> None: + """Tests that we don't return read receipts for rooms that fall out of + range, but then do send all read receipts once they're back in range. + """ + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id1, user1_id, tok=user1_tok) + room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id2, user1_id, tok=user1_tok) + + # Send a message and read receipt into room2 + event_response = self.helper.send(room_id2, body="new event", tok=user2_tok) + room2_event_id = event_response["event_id"] + + self.helper.send_read_receipt(room_id2, room2_event_id, tok=user1_tok) + + # Now send a message into room1 so that it is at the top of the list + self.helper.send(room_id1, body="new event", tok=user2_tok) + + # Make a SS request for only the top room. + sync_body = { + "lists": { + "main": { + "ranges": [[0, 0]], + "required_state": [], + "timeline_limit": 5, + } + }, + "extensions": { + "receipts": { + "enabled": True, + } + }, + } + response_body, from_token = self.do_sync(sync_body, tok=user1_tok) + + # The receipt is in room2, but only room1 is returned, so we don't + # expect to get the receipt. + self.assertIncludes( + response_body["extensions"]["receipts"].get("rooms").keys(), + set(), + exact=True, + ) + + # Move room2 into range. + self.helper.send(room_id2, body="new event", tok=user2_tok) + + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We expect to see the read receipt of room2, as that has the most + # recent update. + self.assertIncludes( + response_body["extensions"]["receipts"].get("rooms").keys(), + {room_id2}, + exact=True, + ) + receipt = response_body["extensions"]["receipts"]["rooms"][room_id2] + self.assertIncludes( + receipt["content"][room2_event_id][ReceiptTypes.READ].keys(), + {user1_id}, + exact=True, + ) + + # Send a message into room1 to bump it to the top, but also send a + # receipt in room2 + self.helper.send(room_id1, body="new event", tok=user2_tok) + self.helper.send_read_receipt(room_id2, room2_event_id, tok=user2_tok) + + # We don't expect to see the new read receipt. + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + self.assertIncludes( + response_body["extensions"]["receipts"].get("rooms").keys(), + set(), + exact=True, + ) + + # But if we send a new message into room2, we expect to get the missing receipts + self.helper.send(room_id2, body="new event", tok=user2_tok) + + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + self.assertIncludes( + response_body["extensions"]["receipts"].get("rooms").keys(), + {room_id2}, + exact=True, + ) + + # We should only see the new receipt + receipt = response_body["extensions"]["receipts"]["rooms"][room_id2] + self.assertIncludes( + receipt["content"][room2_event_id][ReceiptTypes.READ].keys(), + {user2_id}, + exact=True, + ) diff --git a/tests/rest/client/sliding_sync/test_extensions.py b/tests/rest/client/sliding_sync/test_extensions.py index 68f6661334..ae823d5415 100644 --- a/tests/rest/client/sliding_sync/test_extensions.py +++ b/tests/rest/client/sliding_sync/test_extensions.py @@ -120,19 +120,26 @@ class SlidingSyncExtensionsTestCase(SlidingSyncBase): "foo-list": { "ranges": [[0, 1]], "required_state": [], - "timeline_limit": 0, + # We set this to `1` because we're testing `receipts` which + # interact with the `timeline`. With receipts, when a room + # hasn't been sent down the connection before or it appears + # as `initial: true`, we only include receipts for events in + # the timeline to avoid bloating and blowing up the sync + # response as the number of users in the room increases. + # (this behavior is part of the spec) + "timeline_limit": 1, }, # We expect this list range to include room5, room4, room3 "bar-list": { "ranges": [[0, 2]], "required_state": [], - "timeline_limit": 0, + "timeline_limit": 1, }, }, "room_subscriptions": { room_id1: { "required_state": [], - "timeline_limit": 0, + "timeline_limit": 1, } }, } diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index e43140720d..9614cdd66a 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -45,7 +45,7 @@ from typing_extensions import Literal from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.server import Site -from synapse.api.constants import Membership +from synapse.api.constants import Membership, ReceiptTypes from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict @@ -944,3 +944,15 @@ class RestHelper: assert len(p.links) == 1, "not exactly one link in confirmation page" oauth_uri = p.links[0] return oauth_uri + + def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None: + """Send a read receipt into the room at the given event""" + channel = make_request( + self.reactor, + self.site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", + content={}, + access_token=tok, + ) + assert channel.code == HTTPStatus.OK, channel.text_body