From 771692addda05253a423535a704f30afed0fb62f Mon Sep 17 00:00:00 2001 From: Kegan Dougal <7190048+kegsay@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:45:06 +0100 Subject: [PATCH] Rejig when we persist sticky events Persist inside persist_events to guarantee it is done. After that txn, recheck soft failure. --- synapse/storage/databases/main/events.py | 3 ++ .../storage/databases/main/sticky_events.py | 37 ++++++++----------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 824a77e456..ab9fbc4a1c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1186,6 +1186,9 @@ class PersistEventsStore: sliding_sync_table_changes, ) + if self.msc4354_sticky_events: + self.store.insert_sticky_events_txn(txn, events_and_contexts) + # We only update the sliding sync tables for non-backfilled events. self._update_sliding_sync_tables_with_new_persisted_events_txn( txn, room_id, events_and_contexts diff --git a/synapse/storage/databases/main/sticky_events.py b/synapse/storage/databases/main/sticky_events.py index 80474e229b..e9a45bcf4e 100644 --- a/synapse/storage/databases/main/sticky_events.py +++ b/synapse/storage/databases/main/sticky_events.py @@ -57,6 +57,7 @@ logger = logging.getLogger(__name__) # Consumers call 'get_sticky_events_in_rooms' which has `WHERE expires_at > ?` # to filter out expired sticky events that have yet to be deleted. DELETE_EXPIRED_STICKY_EVENTS_MS = 60 * 1000 * 60 # 1 hour +MAX_STICKY_DURATION_MS = 3600000 # 1 hour class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): @@ -231,16 +232,18 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor state_delta_for_room: The changes to the current state, used to detect if we need to re-evaluate soft-failed sticky events. """ - if len(events_and_contexts) == 0: - return - assert self._can_write_to_sticky_events - # fetch soft failed sticky events to recheck now, before we insert new sticky events, else - # we could incorrectly re-evaluate new sticky events in events_and_contexts + # fetch soft failed sticky events to recheck event_ids_to_check = await self._get_soft_failed_sticky_events_to_recheck( room_id, state_delta_for_room ) + # filter out soft-failed events in events_and_contexts as we just inserted them, so the + # soft failure status won't have changed for them. + persisting_event_ids = {ev.event_id for ev, _ in events_and_contexts} + event_ids_to_check = [ + item for item in event_ids_to_check if item not in persisting_event_ids + ] if event_ids_to_check: logger.info( "_get_soft_failed_sticky_events_to_recheck => %s", event_ids_to_check @@ -248,18 +251,7 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor # recheck them and update any that now pass soft-fail checks. await self._recheck_soft_failed_events(room_id, event_ids_to_check) - # insert brand new sticky events. - await self._insert_sticky_events(events_and_contexts) - - async def _insert_sticky_events( - self, - events_and_contexts: List[EventPersistencePair], - ) -> None: - await self.db_pool.runInteraction( - "_insert_sticky_events", self._insert_sticky_events_txn, events_and_contexts - ) - - def _insert_sticky_events_txn( + def insert_sticky_events_txn( self, txn: LoggingTransaction, events_and_contexts: List[EventPersistencePair], @@ -279,17 +271,18 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor if type(sticky_obj) is not dict: continue sticky_duration_ms = sticky_obj.get("duration_ms", None) - # MSC: Valid values are the integer range 0-3600000 (1 hour). + # MSC: Valid values are the integer range 0-MAX_STICKY_DURATION_MS if ( type(sticky_duration_ms) is int and sticky_duration_ms >= 0 - and sticky_duration_ms <= 3600000 + and sticky_duration_ms <= MAX_STICKY_DURATION_MS ): # MSC: The start time is min(now, origin_server_ts). # This ensures that malicious origin timestamps cannot specify start times in the future. - # Calculate the end time as start_time + min(sticky.duration_ms, 3600000). + # Calculate the end time as start_time + min(sticky.duration_ms, MAX_STICKY_DURATION_MS). expires_at = min(ev.origin_server_ts, now_ms) + min( - ev.get_dict()[StickyEvent.FIELD_NAME]["duration_ms"], 3600000 + ev.get_dict()[StickyEvent.FIELD_NAME]["duration_ms"], + MAX_STICKY_DURATION_MS, ) # filter out already expired sticky events if expires_at > now_ms: @@ -449,7 +442,7 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor async def _recheck_soft_failed_events( self, room_id: str, - soft_failed_event_ids: List[str], + soft_failed_event_ids: Collection[str], ) -> None: """ Recheck authorised but soft-failed events. The provided event IDs must have already passed