Rejig when we persist sticky events

Persist inside persist_events to guarantee it is done. After that txn,
recheck soft failure.
This commit is contained in:
Kegan Dougal
2025-09-24 11:45:06 +01:00
parent 666e94b75a
commit 771692addd
2 changed files with 18 additions and 22 deletions
+3
View File
@@ -1186,6 +1186,9 @@ class PersistEventsStore:
sliding_sync_table_changes,
)
if self.msc4354_sticky_events:
self.store.insert_sticky_events_txn(txn, events_and_contexts)
# We only update the sliding sync tables for non-backfilled events.
self._update_sliding_sync_tables_with_new_persisted_events_txn(
txn, room_id, events_and_contexts
+15 -22
View File
@@ -57,6 +57,7 @@ logger = logging.getLogger(__name__)
# Consumers call 'get_sticky_events_in_rooms' which has `WHERE expires_at > ?`
# to filter out expired sticky events that have yet to be deleted.
DELETE_EXPIRED_STICKY_EVENTS_MS = 60 * 1000 * 60 # 1 hour
MAX_STICKY_DURATION_MS = 3600000 # 1 hour
class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
@@ -231,16 +232,18 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
state_delta_for_room: The changes to the current state, used to detect if we need to
re-evaluate soft-failed sticky events.
"""
if len(events_and_contexts) == 0:
return
assert self._can_write_to_sticky_events
# fetch soft failed sticky events to recheck now, before we insert new sticky events, else
# we could incorrectly re-evaluate new sticky events in events_and_contexts
# fetch soft failed sticky events to recheck
event_ids_to_check = await self._get_soft_failed_sticky_events_to_recheck(
room_id, state_delta_for_room
)
# filter out soft-failed events in events_and_contexts as we just inserted them, so the
# soft failure status won't have changed for them.
persisting_event_ids = {ev.event_id for ev, _ in events_and_contexts}
event_ids_to_check = [
item for item in event_ids_to_check if item not in persisting_event_ids
]
if event_ids_to_check:
logger.info(
"_get_soft_failed_sticky_events_to_recheck => %s", event_ids_to_check
@@ -248,18 +251,7 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
# recheck them and update any that now pass soft-fail checks.
await self._recheck_soft_failed_events(room_id, event_ids_to_check)
# insert brand new sticky events.
await self._insert_sticky_events(events_and_contexts)
async def _insert_sticky_events(
self,
events_and_contexts: List[EventPersistencePair],
) -> None:
await self.db_pool.runInteraction(
"_insert_sticky_events", self._insert_sticky_events_txn, events_and_contexts
)
def _insert_sticky_events_txn(
def insert_sticky_events_txn(
self,
txn: LoggingTransaction,
events_and_contexts: List[EventPersistencePair],
@@ -279,17 +271,18 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
if type(sticky_obj) is not dict:
continue
sticky_duration_ms = sticky_obj.get("duration_ms", None)
# MSC: Valid values are the integer range 0-3600000 (1 hour).
# MSC: Valid values are the integer range 0-MAX_STICKY_DURATION_MS
if (
type(sticky_duration_ms) is int
and sticky_duration_ms >= 0
and sticky_duration_ms <= 3600000
and sticky_duration_ms <= MAX_STICKY_DURATION_MS
):
# MSC: The start time is min(now, origin_server_ts).
# This ensures that malicious origin timestamps cannot specify start times in the future.
# Calculate the end time as start_time + min(sticky.duration_ms, 3600000).
# Calculate the end time as start_time + min(sticky.duration_ms, MAX_STICKY_DURATION_MS).
expires_at = min(ev.origin_server_ts, now_ms) + min(
ev.get_dict()[StickyEvent.FIELD_NAME]["duration_ms"], 3600000
ev.get_dict()[StickyEvent.FIELD_NAME]["duration_ms"],
MAX_STICKY_DURATION_MS,
)
# filter out already expired sticky events
if expires_at > now_ms:
@@ -449,7 +442,7 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
async def _recheck_soft_failed_events(
self,
room_id: str,
soft_failed_event_ids: List[str],
soft_failed_event_ids: Collection[str],
) -> None:
"""
Recheck authorised but soft-failed events. The provided event IDs must have already passed