Implement sliding sync extension for sticky events

This commit is contained in:
Olivier 'reivilibre
2025-12-22 14:21:18 +00:00
parent 89009dfdac
commit 0954bdae7d
2 changed files with 172 additions and 5 deletions

View File

@@ -11,7 +11,6 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import itertools
import logging
from collections import ChainMap
@@ -26,11 +25,13 @@ from typing import (
from typing_extensions import TypeAlias, assert_never
from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.api.constants import AccountDataTypes, EduTypes, StickyEvent
from synapse.events import EventBase
from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.types import (
Absent,
DeviceListUpdates,
JsonMapping,
MultiWriterStreamToken,
@@ -47,10 +48,12 @@ from synapse.types.handlers.sliding_sync import (
SlidingSyncConfig,
SlidingSyncResult,
)
from synapse.types.rest.client import SlidingSyncStickyEventsToken
from synapse.util.async_helpers import (
concurrently_execute,
gather_optional_coroutines,
)
from synapse.visibility import filter_and_transform_events_for_client
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
@@ -73,7 +76,10 @@ class SlidingSyncExtensionHandler:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self.clock = hs.get_clock()
self._storage_controllers = hs.get_storage_controllers()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
self._enable_sticky_events = hs.config.experimental.msc4354_enabled
@trace
async def get_extensions_response(
@@ -174,6 +180,19 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
sticky_events_coro = None
if (
sync_config.extensions.sticky_events is not Absent
and self._enable_sticky_events
):
sticky_events_coro = self.get_sticky_events_extension_response(
sync_config=sync_config,
sticky_events_request=sync_config.extensions.sticky_events,
actual_room_ids=actual_room_ids,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
@@ -181,6 +200,7 @@ class SlidingSyncExtensionHandler:
receipts_response,
typing_response,
thread_subs_response,
sticky_events_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
@@ -188,6 +208,7 @@ class SlidingSyncExtensionHandler:
receipts_coro,
typing_coro,
thread_subs_coro,
sticky_events_coro,
)
return SlidingSyncResult.Extensions(
@@ -197,6 +218,7 @@ class SlidingSyncExtensionHandler:
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
sticky_events=sticky_events_response,
)
def find_relevant_room_ids_for_extension(
@@ -967,3 +989,65 @@ class SlidingSyncExtensionHandler:
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)
async def get_sticky_events_extension_response(
self,
sync_config: SlidingSyncConfig,
sticky_events_request: SlidingSyncConfig.Extensions.StickyEventsExtension,
actual_room_ids: set[str],
to_token: StreamToken,
from_token: SlidingSyncStreamToken | None,
) -> SlidingSyncResult.Extensions.StickyEventsExtension | None:
if not sticky_events_request.enabled:
return None
now = self.clock.time_msec()
since_token = sticky_events_request.since or SlidingSyncStickyEventsToken(
sticky_events_stream_id=0
)
(
sticky_events_to_id,
room_to_event_ids,
) = await self.store.get_sticky_events_in_rooms(
actual_room_ids,
from_id=since_token.sticky_events_stream_id,
to_id=to_token.sticky_events_key,
now=now,
limit=min(sticky_events_request.limit, StickyEvent.MAX_EVENTS_IN_SYNC),
)
# No need to preserve sticky event order here because we will
# reassemble it in the right order after.
all_sticky_event_ids = {
ev_id for evs in room_to_event_ids.values() for ev_id in evs
}
unfiltered_events = await self.store.get_events_as_list(all_sticky_event_ids)
filtered_events = await filter_and_transform_events_for_client(
self._storage_controllers,
sync_config.user.to_string(),
unfiltered_events,
# As per MSC4354:
# > History visibility checks MUST NOT be applied to sticky events.
# > Any joined user is authorised to see sticky events for the duration they remain sticky.
always_include_ids=frozenset(all_sticky_event_ids),
)
filtered_event_map = {ev.event_id: ev for ev in filtered_events}
room_id_to_sticky_events: dict[str, list[EventBase]] = {}
for room_id, sticky_event_ids in room_to_event_ids.items():
filtered_events_for_room = [
filtered_event_map[event_id]
# This reintroduces the correct order
# (by the sticky events stream)
for event_id in sticky_event_ids
if event_id in filtered_event_map
]
if len(filtered_events_for_room) == 0:
continue
room_id_to_sticky_events[room_id] = filtered_events_for_room
return SlidingSyncResult.Extensions.StickyEventsExtension(
room_id_to_sticky_events=room_id_to_sticky_events,
next_batch=SlidingSyncStickyEventsToken(
sticky_events_stream_id=sticky_events_to_id
),
)

View File

@@ -21,7 +21,7 @@
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Mapping
from typing import TYPE_CHECKING, Any, Literal, Mapping
import attr
@@ -656,6 +656,7 @@ class SlidingSyncRestServlet(RestServlet):
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
- sticky events (MSC4354)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
@@ -879,7 +880,7 @@ class SlidingSyncRestServlet(RestServlet):
requester, sliding_sync_result.rooms
)
response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
requester, sliding_sync_result.extensions, sliding_sync_result.rooms
)
return response
@@ -1029,8 +1030,18 @@ class SlidingSyncRestServlet(RestServlet):
@trace_with_opname("sliding_sync.encode_extensions")
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
self,
requester: Requester,
extensions: SlidingSyncResult.Extensions,
ref_rooms_results: Mapping[str, SlidingSyncResult.RoomResult],
) -> JsonDict:
"""
Args:
ref_rooms_results:
Map of room ID -> RoomResult that was serialised as the `room` section
of the Sliding Sync response.
Will not be mutated, only used for reading.
"""
serialized_extensions: JsonDict = {}
if extensions.to_device is not None:
@@ -1099,8 +1110,80 @@ class SlidingSyncRestServlet(RestServlet):
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
if extensions.sticky_events:
serialized_extensions[
"org.matrix.msc4354.sticky_events"
] = await self._serialise_sticky_events(
requester, extensions.sticky_events, ref_rooms_results
)
return serialized_extensions
async def _serialise_sticky_events(
self,
requester: Requester,
sticky_events: SlidingSyncResult.Extensions.StickyEventsExtension,
ref_rooms_results: Mapping[str, SlidingSyncResult.RoomResult],
) -> JsonDict:
"""
Serialise the sticky events extension response.
This includes deduplicating by filtering out sticky events
from this extension that already appeared in the timeline
section.
Args:
ref_rooms_results:
Map of room ID -> RoomResult that was serialised as the `room` section
of the Sliding Sync response.
Will not be mutated, only used for reading.
"""
time_now = self.clock.time_msec()
# Same as SSS timelines.
#
serialize_options = SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id,
requester=requester,
)
rooms_out: dict[str, dict[Literal["events"], list[JsonDict]]] = {}
for (
room_id,
possibly_duplicated_sticky_events,
) in sticky_events.room_id_to_sticky_events.items():
# As per MSC4354:
# Remove sticky events that are already in the timeline, else we will needlessly duplicate
# events.
# There is no purpose in including sticky events in the sticky section if they're already in
# the timeline, as either way the client becomes aware of them.
# This is particularly important given the risk of sticky events spam since
# anyone can send sticky events, so halving the bandwidth on average for each sticky
# event is helpful.
room_result = ref_rooms_results.get(room_id)
if room_result is None:
# Nothing to deduplicate
sticky_events_to_write = possibly_duplicated_sticky_events
else:
sent_event_ids_in_room_section = {
ev.event_id for ev in room_result.timeline_events
}
sticky_events_to_write = [
ev
for ev in possibly_duplicated_sticky_events
if ev.event_id not in sent_event_ids_in_room_section
]
rooms_out[room_id] = {
"events": await self.event_serializer.serialize_events(
sticky_events_to_write, time_now, config=serialize_options
)
}
return {
"rooms": rooms_out,
"next_batch": sticky_events.next_batch.serialise(),
}
def _serialise_thread_subscriptions(
thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,