From 63cfbe7e2dd6f3aa4bb9031a8eee1f394db71d63 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 13 Mar 2026 16:51:21 -0400 Subject: [PATCH] Add profile updates stream Add a new stream type for profile updates. This allows sync processes to determine which profile updates a given user has or hasn't seen yet. --- synapse/config/workers.py | 25 +++- synapse/handlers/profile.py | 45 ++++++ synapse/notifier.py | 1 + synapse/storage/databases/main/profile.py | 174 +++++++++++++++++++++- synapse/streams/events.py | 3 + synapse/types/__init__.py | 13 +- 6 files changed, 254 insertions(+), 7 deletions(-) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 996be88cb2..f4d5d687a0 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -142,6 +142,9 @@ class WriterLocations: push_rules: The instances that write to the push stream. Currently can only be a single instance. device_lists: The instances that write to the device list stream. + thread_subscriptions: The instances that write to the thread subscriptions + stream. + profile_updates: The instances that write to the profile updates stream. """ events: list[str] = attr.ib( @@ -177,7 +180,11 @@ class WriterLocations: converter=_instance_to_list_converter, ) thread_subscriptions: list[str] = attr.ib( - default=["master"], + default=[MAIN_PROCESS_INSTANCE_NAME], + converter=_instance_to_list_converter, + ) + profile_updates: list[str] = attr.ib( + default=[MAIN_PROCESS_INSTANCE_NAME], converter=_instance_to_list_converter, ) @@ -355,8 +362,7 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writers for events and typing also appears in - # `instance_map`. + # Check that the configured writers also appear in `instance_map`. for stream in ( "events", "typing", @@ -365,6 +371,9 @@ class WorkerConfig(Config): "receipts", "presence", "push_rules", + "device_lists", + "thread_subscriptions", + "profile_updates", ): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: @@ -415,6 +424,16 @@ class WorkerConfig(Config): "Must specify at least one instance to handle `device_lists` messages." ) + if len(self.writers.thread_subscriptions) == 0: + raise ConfigError( + "Must specify at least one instance to handle `thread_subscriptions` messages." + ) + + if len(self.writers.profile_updates) == 0: + raise ConfigError( + "Must specify at least one instance to handle `profile_updates` messages." + ) + self.events_shard_config = RoutableShardedWorkerHandlingConfig( self.writers.events ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d123bcdd36..b3bffb0cc2 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -42,6 +42,7 @@ from synapse.types import ( JsonValue, Requester, ScheduledTask, + StreamKeyType, TaskStatus, UserID, create_requester, @@ -75,6 +76,8 @@ class ProfileHandler: self.clock = hs.get_clock() # nb must be called this for @cached self.store = hs.get_datastores().main self.hs = hs + self._notifier = hs.get_notifier() + self._msc4429_enabled = hs.config.experimental.msc4429_enabled self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -99,6 +102,24 @@ class ProfileHandler: ) self._worker_locks = hs.get_worker_locks_handler() + async def _notify_profile_update(self, user_id: UserID, stream_id: int) -> None: + room_ids = await self.store.get_rooms_for_user(user_id.to_string()) + if not room_ids: + return + + self._notifier.on_new_event( + StreamKeyType.PROFILE_UPDATES, stream_id, rooms=room_ids + ) + + async def _record_profile_updates( + self, user_id: UserID, updates: list[tuple[str, JsonValue | None]] + ) -> None: + if not self._msc4429_enabled or not updates: + return + + stream_id = await self.store.add_profile_updates(user_id, updates) + await self._notify_profile_update(user_id, stream_id) + async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict: """ Get a user's profile as a JSON dictionary. @@ -253,6 +274,9 @@ class ProfileHandler: ) await self.store.set_profile_displayname(target_user, displayname_to_set) + await self._record_profile_updates( + target_user, [(ProfileFields.DISPLAYNAME, displayname_to_set)] + ) profile = await self.store.get_profileinfo(target_user) @@ -362,6 +386,9 @@ class ProfileHandler: ) await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) + await self._record_profile_updates( + target_user, [(ProfileFields.AVATAR_URL, avatar_url_to_set)] + ) profile = await self.store.get_profileinfo(target_user) @@ -406,6 +433,8 @@ class ProfileHandler: # have it. raise AuthError(400, "Cannot remove another user's profile") + profile_updates: list[tuple[str, JsonValue | None]] = [] + current_profile: ProfileInfo | None = None if not by_admin: current_profile = await self.store.get_profileinfo(target_user) if not self.hs.config.registration.enable_set_displayname: @@ -428,7 +457,21 @@ class ProfileHandler: Codes.FORBIDDEN, ) + if self._msc4429_enabled: + if current_profile is None: + current_profile = await self.store.get_profileinfo(target_user) + + if current_profile.display_name is not None: + profile_updates.append((ProfileFields.DISPLAYNAME, None)) + if current_profile.avatar_url is not None: + profile_updates.append((ProfileFields.AVATAR_URL, None)) + + custom_fields = await self.store.get_profile_fields(target_user) + for field_name in custom_fields.keys(): + profile_updates.append((field_name, None)) + await self.store.delete_profile(target_user) + await self._record_profile_updates(target_user, profile_updates) await self._third_party_rules.on_profile_update( target_user.to_string(), @@ -582,6 +625,7 @@ class ProfileHandler: raise AuthError(403, "Cannot set another user's profile") await self.store.set_profile_field(target_user, field_name, new_value) + await self._record_profile_updates(target_user, [(field_name, new_value)]) # Custom fields do not propagate into the user directory *or* rooms. profile = await self.store.get_profileinfo(target_user) @@ -617,6 +661,7 @@ class ProfileHandler: raise AuthError(400, "Cannot set another user's profile") await self.store.delete_profile_field(target_user, field_name) + await self._record_profile_updates(target_user, [(field_name, None)]) # Custom fields do not propagate into the user directory *or* rooms. profile = await self.store.get_profileinfo(target_user) diff --git a/synapse/notifier.py b/synapse/notifier.py index 93d438def7..a1fa432dfb 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -527,6 +527,7 @@ class Notifier: StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, StreamKeyType.STICKY_EVENTS, + StreamKeyType.PROFILE_UPDATES, ], new_token: int, users: Collection[str | UserID] | None = None, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 9b787e19a3..a79d6eab63 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -19,13 +19,15 @@ # # import json -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Collection, Iterable, Sequence, cast +import attr from canonicaljson import encode_canonical_json from synapse.api.constants import ProfileFields from synapse.api.errors import Codes, StoreError -from synapse.storage._base import SQLBaseStore +from synapse.replication.tcp.streams._base import ProfileUpdatesStream +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -33,6 +35,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonValue, UserID if TYPE_CHECKING: @@ -43,6 +46,15 @@ if TYPE_CHECKING: MAX_PROFILE_SIZE = 65536 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ProfileUpdate: + """An update to a user's profile.""" + + stream_id: int + user_id: str + field_name: str + + class ProfileWorkerStore(SQLBaseStore): def __init__( self, @@ -52,6 +64,7 @@ class ProfileWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) self.server_name: str = hs.hostname + self._instance_name: str = hs.get_instance_name() self.database_engine = database.engine self.db_pool.updates.register_background_index_update( "profiles_full_user_id_key_idx", @@ -65,6 +78,23 @@ class ProfileWorkerStore(SQLBaseStore): "populate_full_user_id_profiles", self.populate_full_user_id_profiles ) + self._can_write_to_profile_updates = ( + self._instance_name in hs.config.worker.writers.profile_updates + ) + self._profile_updates_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="profile_updates", + server_name=self.server_name, + instance_name=self._instance_name, + tables=[ + ("profile_updates", "instance_name", "stream_id"), + ], + sequence_name="profile_updates_sequence", + writers=hs.config.worker.writers.profile_updates, + ) + async def populate_full_user_id_profiles( self, progress: JsonDict, batch_size: int ) -> int: @@ -291,6 +321,146 @@ class ProfileWorkerStore(SQLBaseStore): result = json.loads(result) return result or {} + def get_max_profile_updates_stream_id(self) -> int: + """Get the current maximum stream_id for profile updates.""" + return self._profile_updates_id_gen.get_current_token() + + def get_profile_updates_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._profile_updates_id_gen + + async def get_updated_profile_updates( + self, from_id: int, to_id: int, limit: int + ) -> list[tuple[int, str, str]]: + """Get profile updates that have changed, for the profile_updates stream.""" + if from_id == to_id: + return [] + + def _get_updated_profile_updates_txn( + txn: LoggingTransaction, + ) -> list[tuple[int, str, str]]: + sql = ( + "SELECT stream_id, user_id, field_name" + " FROM profile_updates" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (from_id, to_id, limit)) + return cast(list[tuple[int, str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_updated_profile_updates", _get_updated_profile_updates_txn + ) + + async def get_profile_updates_for_fields( + self, + *, + from_id: int, + to_id: int, + field_names: Iterable[str], + ) -> list[ProfileUpdate]: + """Get profile update markers for the given fields in a stream range.""" + if from_id == to_id: + return [] + + field_names = list(field_names) + if not field_names: + return [] + + def _get_profile_updates_for_fields_txn( + txn: LoggingTransaction, + ) -> list[ProfileUpdate]: + clause, args = make_in_list_sql_clause( + txn.database_engine, "field_name", field_names + ) + sql = ( + "SELECT stream_id, user_id, field_name" + " FROM profile_updates" + f" WHERE ? < stream_id AND stream_id <= ? AND {clause}" + " ORDER BY stream_id ASC" + ) + txn.execute(sql, (from_id, to_id, *args)) + rows = cast(list[tuple[int, str, str]], txn.fetchall()) + + updates: list[ProfileUpdate] = [] + for stream_id, user_id, field_name in rows: + updates.append( + ProfileUpdate( + stream_id=stream_id, + user_id=user_id, + field_name=field_name, + ) + ) + + return updates + + return await self.db_pool.runInteraction( + "get_profile_updates_for_fields", _get_profile_updates_for_fields_txn + ) + + async def get_profile_data_for_users( + self, user_ids: Collection[str] + ) -> dict[str, tuple[str | None, str | None, JsonDict]]: + """Fetch displayname/avatar_url/custom fields for a list of users.""" + if not user_ids: + return {} + + rows = cast( + list[tuple[str, str | None, str | None, object | None]], + await self.db_pool.simple_select_many_batch( + table="profiles", + column="full_user_id", + iterable=user_ids, + retcols=("full_user_id", "displayname", "avatar_url", "fields"), + desc="get_profile_data_for_users", + ), + ) + + results: dict[str, tuple[str | None, str | None, JsonDict]] = {} + for full_user_id, displayname, avatar_url, fields in rows: + if fields is None: + fields_dict: JsonDict = {} + elif isinstance(fields, (str, bytes, bytearray, memoryview)): + fields_dict = cast(JsonDict, db_to_json(fields)) + else: + fields_dict = cast(JsonDict, fields) + + results[full_user_id] = (displayname, avatar_url, fields_dict) + + return results + + async def add_profile_updates( + self, user_id: UserID, updates: Sequence[tuple[str, JsonValue | None]] + ) -> int: + """Persist profile update markers and return the last stream ID.""" + assert self._can_write_to_profile_updates + + if not updates: + return self._profile_updates_id_gen.get_current_token() + + user_id_str = user_id.to_string() + + def _add_profile_updates_txn(txn: LoggingTransaction) -> int: + stream_ids = self._profile_updates_id_gen.get_next_mult_txn( + txn, len(updates) + ) + for stream_id, (field_name, _value) in zip(stream_ids, updates): + self.db_pool.simple_insert_txn( + txn, + table="profile_updates", + values={ + "stream_id": stream_id, + "instance_name": self._instance_name, + "user_id": user_id_str, + "field_name": field_name, + }, + ) + + return stream_ids[-1] + + return await self.db_pool.runInteraction( + "add_profile_updates", _add_profile_updates_txn + ) + async def create_profile(self, user_id: UserID) -> None: """ Create a blank profile for a user. diff --git a/synapse/streams/events.py b/synapse/streams/events.py index d2720fb959..90b74c75c7 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -85,6 +85,7 @@ class EventSources: ) thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() sticky_events_key = self.store.get_max_sticky_events_stream_id() + profile_updates_key = self.store.get_max_profile_updates_stream_id() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -100,6 +101,7 @@ class EventSources: un_partial_stated_rooms_key=un_partial_stated_rooms_key, thread_subscriptions_key=thread_subscriptions_key, sticky_events_key=sticky_events_key, + profile_updates_key=profile_updates_key, ) return token @@ -128,6 +130,7 @@ class EventSources: StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(), + StreamKeyType.PROFILE_UPDATES: self.store.get_profile_updates_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index fb1f1192b7..915188fe1e 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1007,6 +1007,7 @@ class StreamKeyType(Enum): UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" STICKY_EVENTS = "sticky_events_key" + PROFILE_UPDATES = "profile_updates_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1014,7 +1015,7 @@ class StreamToken: """A collection of keys joined together by underscores in the following order and which represent the position in their respective streams. - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242` + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242_4141_101` 1. `room_key`: `s2633508` which is a `RoomStreamToken` - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - See the docstring for `RoomStreamToken` for more details. @@ -1029,6 +1030,7 @@ class StreamToken: 10. `un_partial_stated_rooms_key`: `379` 11. `thread_subscriptions_key`: 4242 12. `sticky_events_key`: 4141 + 13. `profile_updates_key`: 101 You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -1089,6 +1091,7 @@ class StreamToken: un_partial_stated_rooms_key: int thread_subscriptions_key: int sticky_events_key: int + profile_updates_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1118,6 +1121,7 @@ class StreamToken: un_partial_stated_rooms_key, thread_subscriptions_key, sticky_events_key, + profile_updates_key, ) = keys return cls( @@ -1135,6 +1139,7 @@ class StreamToken: un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), thread_subscriptions_key=int(thread_subscriptions_key), sticky_events_key=int(sticky_events_key), + profile_updates_key=int(profile_updates_key), ) except CancelledError: raise @@ -1159,6 +1164,7 @@ class StreamToken: str(self.un_partial_stated_rooms_key), str(self.thread_subscriptions_key), str(self.sticky_events_key), + str(self.profile_updates_key), ] ) @@ -1225,6 +1231,7 @@ class StreamToken: StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, StreamKeyType.STICKY_EVENTS, + StreamKeyType.PROFILE_UPDATES, ], ) -> int: ... @@ -1281,7 +1288,8 @@ class StreamToken: f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," - f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})" + f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key}, " + f"profile_updates: {self.profile_updates_key})" ) @@ -1298,6 +1306,7 @@ StreamToken.START = StreamToken( un_partial_stated_rooms_key=0, thread_subscriptions_key=0, sticky_events_key=0, + profile_updates_key=0, )