Add profile updates stream

Add a new stream type for profile updates. This allows sync processes to
determine which profile updates a given user has or hasn't seen yet.
This commit is contained in:
Andrew Morgan
2026-03-13 16:51:21 -04:00
parent d0b616ddce
commit 63cfbe7e2d
6 changed files with 254 additions and 7 deletions

View File

@@ -142,6 +142,9 @@ class WriterLocations:
push_rules: The instances that write to the push stream. Currently
can only be a single instance.
device_lists: The instances that write to the device list stream.
thread_subscriptions: The instances that write to the thread subscriptions
stream.
profile_updates: The instances that write to the profile updates stream.
"""
events: list[str] = attr.ib(
@@ -177,7 +180,11 @@ class WriterLocations:
converter=_instance_to_list_converter,
)
thread_subscriptions: list[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
profile_updates: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
@@ -355,8 +362,7 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writers for events and typing also appears in
# `instance_map`.
# Check that the configured writers also appear in `instance_map`.
for stream in (
"events",
"typing",
@@ -365,6 +371,9 @@ class WorkerConfig(Config):
"receipts",
"presence",
"push_rules",
"device_lists",
"thread_subscriptions",
"profile_updates",
):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
@@ -415,6 +424,16 @@ class WorkerConfig(Config):
"Must specify at least one instance to handle `device_lists` messages."
)
if len(self.writers.thread_subscriptions) == 0:
raise ConfigError(
"Must specify at least one instance to handle `thread_subscriptions` messages."
)
if len(self.writers.profile_updates) == 0:
raise ConfigError(
"Must specify at least one instance to handle `profile_updates` messages."
)
self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events
)

View File

@@ -42,6 +42,7 @@ from synapse.types import (
JsonValue,
Requester,
ScheduledTask,
StreamKeyType,
TaskStatus,
UserID,
create_requester,
@@ -75,6 +76,8 @@ class ProfileHandler:
self.clock = hs.get_clock() # nb must be called this for @cached
self.store = hs.get_datastores().main
self.hs = hs
self._notifier = hs.get_notifier()
self._msc4429_enabled = hs.config.experimental.msc4429_enabled
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -99,6 +102,24 @@ class ProfileHandler:
)
self._worker_locks = hs.get_worker_locks_handler()
async def _notify_profile_update(self, user_id: UserID, stream_id: int) -> None:
room_ids = await self.store.get_rooms_for_user(user_id.to_string())
if not room_ids:
return
self._notifier.on_new_event(
StreamKeyType.PROFILE_UPDATES, stream_id, rooms=room_ids
)
async def _record_profile_updates(
self, user_id: UserID, updates: list[tuple[str, JsonValue | None]]
) -> None:
if not self._msc4429_enabled or not updates:
return
stream_id = await self.store.add_profile_updates(user_id, updates)
await self._notify_profile_update(user_id, stream_id)
async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict:
"""
Get a user's profile as a JSON dictionary.
@@ -253,6 +274,9 @@ class ProfileHandler:
)
await self.store.set_profile_displayname(target_user, displayname_to_set)
await self._record_profile_updates(
target_user, [(ProfileFields.DISPLAYNAME, displayname_to_set)]
)
profile = await self.store.get_profileinfo(target_user)
@@ -362,6 +386,9 @@ class ProfileHandler:
)
await self.store.set_profile_avatar_url(target_user, avatar_url_to_set)
await self._record_profile_updates(
target_user, [(ProfileFields.AVATAR_URL, avatar_url_to_set)]
)
profile = await self.store.get_profileinfo(target_user)
@@ -406,6 +433,8 @@ class ProfileHandler:
# have it.
raise AuthError(400, "Cannot remove another user's profile")
profile_updates: list[tuple[str, JsonValue | None]] = []
current_profile: ProfileInfo | None = None
if not by_admin:
current_profile = await self.store.get_profileinfo(target_user)
if not self.hs.config.registration.enable_set_displayname:
@@ -428,7 +457,21 @@ class ProfileHandler:
Codes.FORBIDDEN,
)
if self._msc4429_enabled:
if current_profile is None:
current_profile = await self.store.get_profileinfo(target_user)
if current_profile.display_name is not None:
profile_updates.append((ProfileFields.DISPLAYNAME, None))
if current_profile.avatar_url is not None:
profile_updates.append((ProfileFields.AVATAR_URL, None))
custom_fields = await self.store.get_profile_fields(target_user)
for field_name in custom_fields.keys():
profile_updates.append((field_name, None))
await self.store.delete_profile(target_user)
await self._record_profile_updates(target_user, profile_updates)
await self._third_party_rules.on_profile_update(
target_user.to_string(),
@@ -582,6 +625,7 @@ class ProfileHandler:
raise AuthError(403, "Cannot set another user's profile")
await self.store.set_profile_field(target_user, field_name, new_value)
await self._record_profile_updates(target_user, [(field_name, new_value)])
# Custom fields do not propagate into the user directory *or* rooms.
profile = await self.store.get_profileinfo(target_user)
@@ -617,6 +661,7 @@ class ProfileHandler:
raise AuthError(400, "Cannot set another user's profile")
await self.store.delete_profile_field(target_user, field_name)
await self._record_profile_updates(target_user, [(field_name, None)])
# Custom fields do not propagate into the user directory *or* rooms.
profile = await self.store.get_profileinfo(target_user)

View File

@@ -527,6 +527,7 @@ class Notifier:
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
StreamKeyType.STICKY_EVENTS,
StreamKeyType.PROFILE_UPDATES,
],
new_token: int,
users: Collection[str | UserID] | None = None,

View File

@@ -19,13 +19,15 @@
#
#
import json
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Collection, Iterable, Sequence, cast
import attr
from canonicaljson import encode_canonical_json
from synapse.api.constants import ProfileFields
from synapse.api.errors import Codes, StoreError
from synapse.storage._base import SQLBaseStore
from synapse.replication.tcp.streams._base import ProfileUpdatesStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -33,6 +35,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonValue, UserID
if TYPE_CHECKING:
@@ -43,6 +46,15 @@ if TYPE_CHECKING:
MAX_PROFILE_SIZE = 65536
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ProfileUpdate:
"""An update to a user's profile."""
stream_id: int
user_id: str
field_name: str
class ProfileWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -52,6 +64,7 @@ class ProfileWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
self._instance_name: str = hs.get_instance_name()
self.database_engine = database.engine
self.db_pool.updates.register_background_index_update(
"profiles_full_user_id_key_idx",
@@ -65,6 +78,23 @@ class ProfileWorkerStore(SQLBaseStore):
"populate_full_user_id_profiles", self.populate_full_user_id_profiles
)
self._can_write_to_profile_updates = (
self._instance_name in hs.config.worker.writers.profile_updates
)
self._profile_updates_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="profile_updates",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("profile_updates", "instance_name", "stream_id"),
],
sequence_name="profile_updates_sequence",
writers=hs.config.worker.writers.profile_updates,
)
async def populate_full_user_id_profiles(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -291,6 +321,146 @@ class ProfileWorkerStore(SQLBaseStore):
result = json.loads(result)
return result or {}
def get_max_profile_updates_stream_id(self) -> int:
"""Get the current maximum stream_id for profile updates."""
return self._profile_updates_id_gen.get_current_token()
def get_profile_updates_stream_id_generator(self) -> MultiWriterIdGenerator:
return self._profile_updates_id_gen
async def get_updated_profile_updates(
self, from_id: int, to_id: int, limit: int
) -> list[tuple[int, str, str]]:
"""Get profile updates that have changed, for the profile_updates stream."""
if from_id == to_id:
return []
def _get_updated_profile_updates_txn(
txn: LoggingTransaction,
) -> list[tuple[int, str, str]]:
sql = (
"SELECT stream_id, user_id, field_name"
" FROM profile_updates"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (from_id, to_id, limit))
return cast(list[tuple[int, str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_updated_profile_updates", _get_updated_profile_updates_txn
)
async def get_profile_updates_for_fields(
self,
*,
from_id: int,
to_id: int,
field_names: Iterable[str],
) -> list[ProfileUpdate]:
"""Get profile update markers for the given fields in a stream range."""
if from_id == to_id:
return []
field_names = list(field_names)
if not field_names:
return []
def _get_profile_updates_for_fields_txn(
txn: LoggingTransaction,
) -> list[ProfileUpdate]:
clause, args = make_in_list_sql_clause(
txn.database_engine, "field_name", field_names
)
sql = (
"SELECT stream_id, user_id, field_name"
" FROM profile_updates"
f" WHERE ? < stream_id AND stream_id <= ? AND {clause}"
" ORDER BY stream_id ASC"
)
txn.execute(sql, (from_id, to_id, *args))
rows = cast(list[tuple[int, str, str]], txn.fetchall())
updates: list[ProfileUpdate] = []
for stream_id, user_id, field_name in rows:
updates.append(
ProfileUpdate(
stream_id=stream_id,
user_id=user_id,
field_name=field_name,
)
)
return updates
return await self.db_pool.runInteraction(
"get_profile_updates_for_fields", _get_profile_updates_for_fields_txn
)
async def get_profile_data_for_users(
self, user_ids: Collection[str]
) -> dict[str, tuple[str | None, str | None, JsonDict]]:
"""Fetch displayname/avatar_url/custom fields for a list of users."""
if not user_ids:
return {}
rows = cast(
list[tuple[str, str | None, str | None, object | None]],
await self.db_pool.simple_select_many_batch(
table="profiles",
column="full_user_id",
iterable=user_ids,
retcols=("full_user_id", "displayname", "avatar_url", "fields"),
desc="get_profile_data_for_users",
),
)
results: dict[str, tuple[str | None, str | None, JsonDict]] = {}
for full_user_id, displayname, avatar_url, fields in rows:
if fields is None:
fields_dict: JsonDict = {}
elif isinstance(fields, (str, bytes, bytearray, memoryview)):
fields_dict = cast(JsonDict, db_to_json(fields))
else:
fields_dict = cast(JsonDict, fields)
results[full_user_id] = (displayname, avatar_url, fields_dict)
return results
async def add_profile_updates(
self, user_id: UserID, updates: Sequence[tuple[str, JsonValue | None]]
) -> int:
"""Persist profile update markers and return the last stream ID."""
assert self._can_write_to_profile_updates
if not updates:
return self._profile_updates_id_gen.get_current_token()
user_id_str = user_id.to_string()
def _add_profile_updates_txn(txn: LoggingTransaction) -> int:
stream_ids = self._profile_updates_id_gen.get_next_mult_txn(
txn, len(updates)
)
for stream_id, (field_name, _value) in zip(stream_ids, updates):
self.db_pool.simple_insert_txn(
txn,
table="profile_updates",
values={
"stream_id": stream_id,
"instance_name": self._instance_name,
"user_id": user_id_str,
"field_name": field_name,
},
)
return stream_ids[-1]
return await self.db_pool.runInteraction(
"add_profile_updates", _add_profile_updates_txn
)
async def create_profile(self, user_id: UserID) -> None:
"""
Create a blank profile for a user.

View File

@@ -85,6 +85,7 @@ class EventSources:
)
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
sticky_events_key = self.store.get_max_sticky_events_stream_id()
profile_updates_key = self.store.get_max_profile_updates_stream_id()
token = StreamToken(
room_key=self.sources.room.get_current_key(),
@@ -100,6 +101,7 @@ class EventSources:
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
thread_subscriptions_key=thread_subscriptions_key,
sticky_events_key=sticky_events_key,
profile_updates_key=profile_updates_key,
)
return token
@@ -128,6 +130,7 @@ class EventSources:
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(),
StreamKeyType.PROFILE_UPDATES: self.store.get_profile_updates_stream_id_generator(),
}
for _, key in StreamKeyType.__members__.items():

View File

@@ -1007,6 +1007,7 @@ class StreamKeyType(Enum):
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
STICKY_EVENTS = "sticky_events_key"
PROFILE_UPDATES = "profile_updates_key"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -1014,7 +1015,7 @@ class StreamToken:
"""A collection of keys joined together by underscores in the following
order and which represent the position in their respective streams.
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242`
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242_4141_101`
1. `room_key`: `s2633508` which is a `RoomStreamToken`
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
- See the docstring for `RoomStreamToken` for more details.
@@ -1029,6 +1030,7 @@ class StreamToken:
10. `un_partial_stated_rooms_key`: `379`
11. `thread_subscriptions_key`: 4242
12. `sticky_events_key`: 4141
13. `profile_updates_key`: 101
You can see how many of these keys correspond to the various
fields in a "/sync" response:
@@ -1089,6 +1091,7 @@ class StreamToken:
un_partial_stated_rooms_key: int
thread_subscriptions_key: int
sticky_events_key: int
profile_updates_key: int
_SEPARATOR = "_"
START: ClassVar["StreamToken"]
@@ -1118,6 +1121,7 @@ class StreamToken:
un_partial_stated_rooms_key,
thread_subscriptions_key,
sticky_events_key,
profile_updates_key,
) = keys
return cls(
@@ -1135,6 +1139,7 @@ class StreamToken:
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
thread_subscriptions_key=int(thread_subscriptions_key),
sticky_events_key=int(sticky_events_key),
profile_updates_key=int(profile_updates_key),
)
except CancelledError:
raise
@@ -1159,6 +1164,7 @@ class StreamToken:
str(self.un_partial_stated_rooms_key),
str(self.thread_subscriptions_key),
str(self.sticky_events_key),
str(self.profile_updates_key),
]
)
@@ -1225,6 +1231,7 @@ class StreamToken:
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
StreamKeyType.STICKY_EVENTS,
StreamKeyType.PROFILE_UPDATES,
],
) -> int: ...
@@ -1281,7 +1288,8 @@ class StreamToken:
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})"
f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key}, "
f"profile_updates: {self.profile_updates_key})"
)
@@ -1298,6 +1306,7 @@ StreamToken.START = StreamToken(
un_partial_stated_rooms_key=0,
thread_subscriptions_key=0,
sticky_events_key=0,
profile_updates_key=0,
)