Replace wait_for_quarantined_media_stream_id(...) with standard wait_for_stream_token(...) (#19764)

To use `wait_for_stream_token(...)`, we have to add the
`quarantined_media` stream to `StreamToken`. Even though nothing needs
to `/sync` the `quarantined_media` stream, this moves us toward a future
where all endpoints should probably use `StreamToken`; see
https://github.com/element-hq/synapse/issues/19647

Follow-up to https://github.com/element-hq/synapse/pull/19558 and
https://github.com/element-hq/synapse/pull/19644
This commit is contained in:
Eric Eastwood
2026-05-15 13:51:03 -05:00
committed by GitHub
parent 19f636244c
commit 8eb220a5e2
7 changed files with 59 additions and 83 deletions
+1
View File
@@ -0,0 +1 @@
Replace the bespoke `quarantined_media` stream waiting logic with the standard `wait_for_stream_token(...)`.
+18 -5
View File
@@ -43,7 +43,13 @@ from synapse.rest.admin._base import (
from synapse.storage.databases.main.media_repository import (
MediaSortOrder,
)
from synapse.types import JsonDict, UserID
from synapse.types import (
JsonDict,
MultiWriterStreamToken,
StreamKeyType,
StreamToken,
UserID,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -243,6 +249,7 @@ class ListQuarantineChanges(RestServlet):
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.replication = hs.get_replication_data_handler()
self.notifier = hs.get_notifier()
async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -256,8 +263,7 @@ class ListQuarantineChanges(RestServlet):
# The caller is trying to get future data, which we don't allow because
# we know it's an invalid state that should never happen. We could
# wait until we reach the token but we might as well not waste our
# resources on that which is why `wait_for_quarantined_media_stream_id(...)`
# has assertions around this.
# resources on that.
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"The `from` token is considered invalid because it includes stream positions "
@@ -268,9 +274,16 @@ class ListQuarantineChanges(RestServlet):
errcode=Codes.INVALID_PARAM,
)
# Create a `StreamToken` that's compatible with `wait_for_stream_token`.
#
# FIXME: Ideally, this endpoint would use a `StreamToken` to begin with
from_token = StreamToken.START.copy_and_replace(
StreamKeyType.QUARANTINED_MEDIA, MultiWriterStreamToken(stream=from_id)
)
# We need to wait to ensure that our current worker is actually caught up with
# the stream position, otherwise we might not return what we think we're returning.
if not await self.store.wait_for_quarantined_media_stream_id(from_id):
if not await self.notifier.wait_for_stream_token(from_token):
raise SynapseError(
HTTPStatus.INTERNAL_SERVER_ERROR,
"Timed out while waiting for the worker serving this request to catch up to the given "
@@ -280,7 +293,7 @@ class ListQuarantineChanges(RestServlet):
errcode=Codes.UNKNOWN,
)
to_id = await self.store.get_current_quarantined_media_stream_id()
to_id = self.store.get_current_quarantined_media_stream_id()
changes = await self.store.get_quarantined_media_changes(
from_id=from_id,
to_id=to_id,
+10 -70
View File
@@ -62,12 +62,12 @@ from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.types import (
JsonDict,
MultiWriterStreamToken,
RetentionPolicy,
StrCollection,
ThirdPartyInstanceID,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.duration import Duration
from synapse.util.json import json_encoder
from synapse.util.stringutils import MXC_REGEX
@@ -1302,7 +1302,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return local_media_ids
async def get_current_quarantined_media_stream_id(self) -> int:
def get_quarantined_media_stream_token(self) -> MultiWriterStreamToken:
return MultiWriterStreamToken.from_generator(
self._quarantined_media_changes_id_gen
)
def get_quarantined_media_stream_id_generator(self) -> MultiWriterIdGenerator:
return self._quarantined_media_changes_id_gen
def get_current_quarantined_media_stream_id(self) -> int:
"""Gets the position of the quarantined media changes stream.
Returns:
@@ -1318,74 +1326,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
return await self._quarantined_media_changes_id_gen.get_max_allocated_token()
async def wait_for_quarantined_media_stream_id(self, target_id: int) -> bool:
"""Waits until the quarantined media changes stream reaches the given stream ID.
See https://github.com/element-hq/synapse/pull/19644 for more details.
TODO: Replace function and call sites with https://github.com/element-hq/synapse/pull/19644
Args:
target_id: The stream ID to wait for.
Returns:
True when caught up to the target stream ID.
False when timing out while waiting.
"""
# We ideally would use something like `wait_for_stream_position` in the meantime,
# but that short circuits if the instance name matches the current instance name.
# Doing so means that if *another* writer is actually leading the to_id, then we'll
# assume that we're caught up when we aren't.
#
# NOTE: Because this is implemented to wait for stream positions by integer ID,
# we're technically waiting for *all* workers to catch up rather than just waiting
# for *our* worker to catch up. This is okay for now because the quarantined media
# stream should be pretty fast to update, and if it's not then the only thing we're
# affecting is an admin API that probably has a tool automatically retrying requests
# anyway. https://github.com/element-hq/synapse/pull/19644 does the waiting properly
# so this should be replaced by that (or similar).
# Get the minimum shared position/ID across all workers
current_id = self._quarantined_media_changes_id_gen.get_current_token()
if current_id >= target_id:
return True # nothing to wait for: we're already caught up.
# "This should never happen". Tokens we hand out via the API should exist. If they
# don't, then we're in a bad state and need to explode.
max_persisted_position = (
await self._quarantined_media_changes_id_gen.get_max_allocated_token()
)
assert max_persisted_position >= target_id, (
f"Unable to wait for invalid future token (token={target_id} has positions "
f"ahead of our max persisted position={max_persisted_position})"
)
# Start waiting until we've caught up to the `stream_token`
start = self.clock.time_msec()
logged = False
while True:
# Like above, get the minimum shared ID across all workers
current_id = self._quarantined_media_changes_id_gen.get_current_token()
if current_id >= target_id:
return True
now = self.clock.time_msec()
# Timed out
if now - start > 10_000:
return False
if not logged:
logger.info(
"Waiting for current token to reach %s; currently at %s",
target_id,
current_id,
)
logged = True
# TODO: be better
await self.clock.sleep(Duration(milliseconds=500))
async def get_quarantined_media_changes(
self, *, from_id: int, to_id: int, limit: int
) -> list[QuarantinedMediaUpdate]:
+3
View File
@@ -85,6 +85,7 @@ class EventSources:
)
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
sticky_events_key = self.store.get_max_sticky_events_stream_id()
quarantined_media_key = self.store.get_quarantined_media_stream_token()
token = StreamToken(
room_key=self.sources.room.get_current_key(),
@@ -100,6 +101,7 @@ class EventSources:
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
thread_subscriptions_key=thread_subscriptions_key,
sticky_events_key=sticky_events_key,
quarantined_media_key=quarantined_media_key,
)
return token
@@ -128,6 +130,7 @@ class EventSources:
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(),
StreamKeyType.QUARANTINED_MEDIA: self.store.get_quarantined_media_stream_id_generator(),
}
for _, key in StreamKeyType.__members__.items():
+23 -4
View File
@@ -1060,6 +1060,7 @@ class StreamKeyType(Enum):
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
STICKY_EVENTS = "sticky_events_key"
QUARANTINED_MEDIA = "quarantined_media_key"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -1067,7 +1068,7 @@ class StreamToken:
"""A collection of keys joined together by underscores in the following
order, each representing a position in its respective stream.
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242`
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242_4141_4343`
1. `room_key`: `s2633508` which is a `RoomStreamToken`
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
- See the docstring for `RoomStreamToken` for more details.
@@ -1082,12 +1083,13 @@ class StreamToken:
10. `un_partial_stated_rooms_key`: `379`
11. `thread_subscriptions_key`: `4242`
12. `sticky_events_key`: `4141`
13. `quarantined_media_key`: `4343`
You can see how many of these keys correspond to the various
fields in a "/sync" response:
```json
{
"next_batch": "s12_4_0_1_1_1_1_4_1_1",
"next_batch": "s12_4_0_1_1_1_1_4_1_1_1_1_1",
"presence": {
"events": []
},
@@ -1099,7 +1101,7 @@ class StreamToken:
"!QrZlfIDQLNLdZHqTnt:hs1": {
"timeline": {
"events": [],
"prev_batch": "s10_4_0_1_1_1_1_4_1_1",
"prev_batch": "s10_4_0_1_1_1_1_4_1_1_1_1_1",
"limited": false
},
"state": {
@@ -1142,6 +1144,9 @@ class StreamToken:
un_partial_stated_rooms_key: int
thread_subscriptions_key: int
sticky_events_key: int
quarantined_media_key: MultiWriterStreamToken = attr.ib(
validator=attr.validators.instance_of(MultiWriterStreamToken)
)
_SEPARATOR = "_"
START: ClassVar["StreamToken"]
@@ -1171,6 +1176,7 @@ class StreamToken:
un_partial_stated_rooms_key,
thread_subscriptions_key,
sticky_events_key,
quarantined_media_key,
) = keys
return cls(
@@ -1188,6 +1194,9 @@ class StreamToken:
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
thread_subscriptions_key=int(thread_subscriptions_key),
sticky_events_key=int(sticky_events_key),
quarantined_media_key=await MultiWriterStreamToken.parse(
store, quarantined_media_key
),
)
except CancelledError:
raise
@@ -1212,6 +1221,7 @@ class StreamToken:
str(self.un_partial_stated_rooms_key),
str(self.thread_subscriptions_key),
str(self.sticky_events_key),
await self.quarantined_media_key.to_string(store),
]
)
@@ -1241,6 +1251,12 @@ class StreamToken:
self.device_list_key.copy_and_advance(new_value),
)
return new_token
elif key == StreamKeyType.QUARANTINED_MEDIA:
new_token = self.copy_and_replace(
StreamKeyType.QUARANTINED_MEDIA,
self.quarantined_media_key.copy_and_advance(new_value),
)
return new_token
new_token = self.copy_and_replace(key, new_value)
new_id = new_token.get_field(key)
@@ -1263,6 +1279,7 @@ class StreamToken:
key: Literal[
StreamKeyType.RECEIPT,
StreamKeyType.DEVICE_LIST,
StreamKeyType.QUARANTINED_MEDIA,
],
) -> MultiWriterStreamToken: ...
@@ -1334,7 +1351,8 @@ class StreamToken:
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})"
f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key}, "
f"quarantined_media: {self.quarantined_media_key})"
)
@@ -1351,6 +1369,7 @@ StreamToken.START = StreamToken(
un_partial_stated_rooms_key=0,
thread_subscriptions_key=0,
sticky_events_key=0,
quarantined_media_key=MultiWriterStreamToken(stream=0),
)
+2 -2
View File
@@ -2549,7 +2549,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_topo_token_is_accepted(self) -> None:
"""Test Topo Token is accepted."""
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
@@ -2563,7 +2563,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
"""Test that stream token is accepted for forward pagination."""
token = "s0_0_0_0_0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
+2 -2
View File
@@ -2248,7 +2248,7 @@ class RoomMessageListTestCase(RoomBase):
self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self) -> None:
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
@@ -2259,7 +2259,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("end" in channel.json_body)
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
token = "s0_0_0_0_0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)