diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index a81db3cfbf..7b18b2b88c 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -137,6 +137,7 @@ BOOLEAN_COLUMNS = { "has_known_state", "is_encrypted", ], + "sticky_events": ["soft_failed"], "thread_subscriptions": ["subscribed", "automatic"], "users": ["shadow_banned", "approved", "locked", "suspended"], "un_partial_stated_event_stream": ["rejection_status_changed"], diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 742b2af081..f9b86f99d3 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -23,7 +23,6 @@ import logging import sys from typing import Dict, List -from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore from twisted.web.resource import Resource import synapse @@ -102,6 +101,7 @@ from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore diff --git a/synapse/notifier.py b/synapse/notifier.py index e684df4866..136e766d68 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -533,6 +533,7 @@ class Notifier: StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, + StreamKeyType.STICKY_EVENTS, ], new_token: int, users: Optional[Collection[Union[str, UserID]]] = None, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7a86b2e65e..21b0fda58b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -44,7 +44,10 @@ from synapse.replication.tcp.streams import ( UnPartialStatedEventStream, UnPartialStatedRoomStream, ) -from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream +from synapse.replication.tcp.streams._base import ( + StickyEventsStream, + ThreadSubscriptionsStream, +) from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -262,6 +265,13 @@ class ReplicationDataHandler: token, users=[row.user_id for row in rows], ) + elif stream_name == StickyEventsStream.NAME: + print(f"STICKY_EVENTS on_rdata {token} => {rows}") + self.notifier.on_new_event( + StreamKeyType.STICKY_EVENTS, + token, + rooms=[row.room_id for row in rows], + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/synapse/storage/databases/main/sticky_events.py b/synapse/storage/databases/main/sticky_events.py index 3f08a8bde9..d6cdefa9f4 100644 --- a/synapse/storage/databases/main/sticky_events.py +++ b/synapse/storage/databases/main/sticky_events.py @@ -15,10 +15,13 @@ import time from typing import ( TYPE_CHECKING, Any, + Dict, Iterable, List, Optional, + Set, Tuple, + cast, ) from synapse.api.constants import EventTypes, StickyEvent @@ -29,6 +32,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events import DeltaState @@ -94,6 +98,54 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore): def get_sticky_events_stream_id_generator(self) -> MultiWriterIdGenerator: return self._sticky_events_id_gen + async def get_sticky_events_in_rooms( + self, + room_ids: List[str], + from_id: int, + ) -> Tuple[int, Dict[str, Set[str]]]: + """ + Fetch all the sticky events in the given rooms, from the given sticky stream ID. + + Args: + room_ids: The room IDs to return sticky events in. + from_id: The sticky stream ID that sticky events should be returned from. + Returns: + A tuple of (to_id, map[room_id, event_ids]) + """ + sticky_events_rows = await self.db_pool.runInteraction( + "get_sticky_events_in_rooms", + self._get_sticky_events_in_rooms_txn, + room_ids, + from_id, + ) + to_id = from_id + room_to_events: Dict[str, Set[str]] = {} + for stream_id, room_id, event_id in sticky_events_rows: + to_id = max(to_id, stream_id) + events = room_to_events.get(room_id, set()) + events.add(event_id) + room_to_events[room_id] = events + return (to_id, room_to_events) + + def _get_sticky_events_in_rooms_txn( + self, + txn: LoggingTransaction, + room_ids: List[str], + from_id: int, + ) -> List[Tuple[int, str, str]]: + if len(room_ids) == 0: + return [] + clause, room_id_values = make_in_list_sql_clause( + txn.database_engine, "room_id", room_ids + ) + txn.execute( + f""" + SELECT stream_id, room_id, event_id FROM sticky_events WHERE stream_id > ? AND {clause} + """, + (from_id, room_id_values), + ) + return cast(List[Tuple[int, str, str]], txn.fetchall()) + async def get_updated_sticky_events( self, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str]]: @@ -107,7 +159,24 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore): Returns: list of (stream_id, room_id, event_id) tuples """ - return [] # TODO + return await self.db_pool.runInteraction( + "get_updated_sticky_events", + self._get_updated_sticky_events_txn, + from_id, + to_id, + limit, + ) + + def _get_updated_sticky_events_txn( + self, txn: LoggingTransaction, from_id: int, to_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + txn.execute( + """ + SELECT stream_id, room_id, event_id FROM sticky_events WHERE stream_id > ? AND stream_id <= ? LIMIT ? + """, + (from_id, to_id, limit), + ) + return cast(List[Tuple[int, str, str]], txn.fetchall()) def handle_sticky_events_txn( self, @@ -137,6 +206,8 @@ class StickyEventsWorkerStore(CacheInvalidationWorkerStore): if len(events_and_contexts) == 0: return + assert self._can_write_to_sticky_events + # TODO: finish the impl # fetch soft failed sticky events to recheck now, before we insert new sticky events, else # we could incorrectly re-evaluate new sticky events diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 1e4bebe46d..52f1451724 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -84,6 +84,7 @@ class EventSources: self._instance_name ) thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() + sticky_events_key = self.store.get_max_sticky_events_stream_id() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -98,6 +99,7 @@ class EventSources: groups_key=0, un_partial_stated_rooms_key=un_partial_stated_rooms_key, thread_subscriptions_key=thread_subscriptions_key, + sticky_events_key=sticky_events_key, ) return token @@ -125,6 +127,7 @@ class EventSources: StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), + StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 2d5b07ab8f..2ea8986511 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -997,6 +997,7 @@ class StreamKeyType(Enum): DEVICE_LIST = "device_list_key" UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" + STICKY_EVENTS = "sticky_events_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1018,6 +1019,7 @@ class StreamToken: 9. `groups_key`: `1` (note that this key is now unused) 10. `un_partial_stated_rooms_key`: `379` 11. `thread_subscriptions_key`: 4242 + 12. `sticky_events_key`: 4141 You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -1077,6 +1079,7 @@ class StreamToken: groups_key: int un_partial_stated_rooms_key: int thread_subscriptions_key: int + sticky_events_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1105,6 +1108,7 @@ class StreamToken: groups_key, un_partial_stated_rooms_key, thread_subscriptions_key, + sticky_events_key, ) = keys return cls( @@ -1121,6 +1125,7 @@ class StreamToken: groups_key=int(groups_key), un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), thread_subscriptions_key=int(thread_subscriptions_key), + sticky_events_key=int(sticky_events_key), ) except CancelledError: raise @@ -1144,6 +1149,7 @@ class StreamToken: str(self.groups_key), str(self.un_partial_stated_rooms_key), str(self.thread_subscriptions_key), + str(self.sticky_events_key), ] ) @@ -1209,6 +1215,7 @@ class StreamToken: StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, + StreamKeyType.STICKY_EVENTS, ], ) -> int: ... @@ -1265,7 +1272,7 @@ class StreamToken: f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," - f"thread_subscriptions: {self.thread_subscriptions_key})" + f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})" ) @@ -1281,6 +1288,7 @@ StreamToken.START = StreamToken( groups_key=0, un_partial_stated_rooms_key=0, thread_subscriptions_key=0, + sticky_events_key=0, )