mirror of
https://github.com/element-hq/synapse.git
synced 2026-05-24 15:15:22 +00:00
Replace wait_for_quarantined_media_stream_id(...) with standard wait_for_stream_token(...) (#19764)
In order to use `wait_for_stream_token(...)`, we have to add the `quarantined_media` stream to `StreamToken`. Even though we don't care about `/sync`'ing `quarantined_media`, this aligns with a future where all endpoints should probably use `StreamToken` (see https://github.com/element-hq/synapse/issues/19647). Follow-up to https://github.com/element-hq/synapse/pull/19558 and https://github.com/element-hq/synapse/pull/19644.
This commit is contained in:
@@ -0,0 +1 @@
|
||||
Replace the bespoke `quarantined_media` stream-waiting logic with the standard `wait_for_stream_token(...)`.
|
||||
@@ -43,7 +43,13 @@ from synapse.rest.admin._base import (
|
||||
from synapse.storage.databases.main.media_repository import (
|
||||
MediaSortOrder,
|
||||
)
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
MultiWriterStreamToken,
|
||||
StreamKeyType,
|
||||
StreamToken,
|
||||
UserID,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -243,6 +249,7 @@ class ListQuarantineChanges(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
self.server_name = hs.hostname
|
||||
self.replication = hs.get_replication_data_handler()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
@@ -256,8 +263,7 @@ class ListQuarantineChanges(RestServlet):
|
||||
# The caller is trying to get future data, which we don't allow because
|
||||
# we know it's an invalid state that should never happen. We could
|
||||
# wait until we reach the token but we might as well not waste our
|
||||
# resources on that which is why `wait_for_quarantined_media_stream_id(...)`
|
||||
# has assertions around this.
|
||||
# resources on that.
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"The `from` token is considered invalid because it includes stream positions "
|
||||
@@ -268,9 +274,16 @@ class ListQuarantineChanges(RestServlet):
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
# Create a `StreamToken` that's compatible with `wait_for_stream_token`.
|
||||
#
|
||||
# FIXME: Ideally, this endpoint would use a `StreamToken` to begin with
|
||||
from_token = StreamToken.START.copy_and_replace(
|
||||
StreamKeyType.QUARANTINED_MEDIA, MultiWriterStreamToken(stream=from_id)
|
||||
)
|
||||
|
||||
# We need to wait to ensure that our current worker is actually caught up with
|
||||
# the stream position, otherwise we might not return what we think we're returning.
|
||||
if not await self.store.wait_for_quarantined_media_stream_id(from_id):
|
||||
if not await self.notifier.wait_for_stream_token(from_token):
|
||||
raise SynapseError(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
"Timed out while waiting for the worker serving this request to catch up to the given "
|
||||
@@ -280,7 +293,7 @@ class ListQuarantineChanges(RestServlet):
|
||||
errcode=Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
to_id = await self.store.get_current_quarantined_media_stream_id()
|
||||
to_id = self.store.get_current_quarantined_media_stream_id()
|
||||
changes = await self.store.get_quarantined_media_changes(
|
||||
from_id=from_id,
|
||||
to_id=to_id,
|
||||
|
||||
@@ -62,12 +62,12 @@ from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
MultiWriterStreamToken,
|
||||
RetentionPolicy,
|
||||
StrCollection,
|
||||
ThirdPartyInstanceID,
|
||||
)
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.duration import Duration
|
||||
from synapse.util.json import json_encoder
|
||||
from synapse.util.stringutils import MXC_REGEX
|
||||
|
||||
@@ -1302,7 +1302,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
return local_media_ids
|
||||
|
||||
async def get_current_quarantined_media_stream_id(self) -> int:
|
||||
def get_quarantined_media_stream_token(self) -> MultiWriterStreamToken:
|
||||
return MultiWriterStreamToken.from_generator(
|
||||
self._quarantined_media_changes_id_gen
|
||||
)
|
||||
|
||||
def get_quarantined_media_stream_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._quarantined_media_changes_id_gen
|
||||
|
||||
def get_current_quarantined_media_stream_id(self) -> int:
|
||||
"""Gets the position of the quarantined media changes stream.
|
||||
|
||||
Returns:
|
||||
@@ -1318,74 +1326,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
"""
|
||||
return await self._quarantined_media_changes_id_gen.get_max_allocated_token()
|
||||
|
||||
async def wait_for_quarantined_media_stream_id(self, target_id: int) -> bool:
|
||||
"""Waits until the quarantined media changes stream reaches the given stream ID.
|
||||
|
||||
See https://github.com/element-hq/synapse/pull/19644 for more details.
|
||||
|
||||
TODO: Replace function and call sites with https://github.com/element-hq/synapse/pull/19644
|
||||
|
||||
Args:
|
||||
target_id: The stream ID to wait for.
|
||||
|
||||
Returns:
|
||||
True when caught up to the target stream ID.
|
||||
False when timing out while waiting.
|
||||
"""
|
||||
# We ideally would use something like `wait_for_stream_position` in the meantime,
|
||||
# but that short circuits if the instance name matches the current instance name.
|
||||
# Doing so means that if *another* writer is actually leading the to_id, then we'll
|
||||
# assume that we're caught up when we aren't.
|
||||
#
|
||||
# NOTE: Because this is implemented to wait for stream positions by integer ID,
|
||||
# we're technically waiting for *all* workers to catch up rather than just waiting
|
||||
# for *our* worker to catch up. This is okay for now because the quarantined media
|
||||
# stream should be pretty fast to update, and if it's not then the only thing we're
|
||||
# affecting is an admin API that probably has a tool automatically retrying requests
|
||||
# anyway. https://github.com/element-hq/synapse/pull/19644 does the waiting properly
|
||||
# so this should be replaced by that (or similar).
|
||||
|
||||
# Get the minimum shared position/ID across all workers
|
||||
current_id = self._quarantined_media_changes_id_gen.get_current_token()
|
||||
if current_id >= target_id:
|
||||
return True # nothing to wait for: we're already caught up.
|
||||
|
||||
# "This should never happen". Tokens we hand out via the API should exist. If they
|
||||
# don't, then we're in a bad state and need to explode.
|
||||
max_persisted_position = (
|
||||
await self._quarantined_media_changes_id_gen.get_max_allocated_token()
|
||||
)
|
||||
assert max_persisted_position >= target_id, (
|
||||
f"Unable to wait for invalid future token (token={target_id} has positions "
|
||||
f"ahead of our max persisted position={max_persisted_position})"
|
||||
)
|
||||
|
||||
# Start waiting until we've caught up to the `stream_token`
|
||||
start = self.clock.time_msec()
|
||||
logged = False
|
||||
while True:
|
||||
# Like above, get the minimum shared ID across all workers
|
||||
current_id = self._quarantined_media_changes_id_gen.get_current_token()
|
||||
if current_id >= target_id:
|
||||
return True
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
# Timed out
|
||||
if now - start > 10_000:
|
||||
return False
|
||||
|
||||
if not logged:
|
||||
logger.info(
|
||||
"Waiting for current token to reach %s; currently at %s",
|
||||
target_id,
|
||||
current_id,
|
||||
)
|
||||
logged = True
|
||||
|
||||
# TODO: be better
|
||||
await self.clock.sleep(Duration(milliseconds=500))
|
||||
|
||||
async def get_quarantined_media_changes(
|
||||
self, *, from_id: int, to_id: int, limit: int
|
||||
) -> list[QuarantinedMediaUpdate]:
|
||||
|
||||
@@ -85,6 +85,7 @@ class EventSources:
|
||||
)
|
||||
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
|
||||
sticky_events_key = self.store.get_max_sticky_events_stream_id()
|
||||
quarantined_media_key = self.store.get_quarantined_media_stream_token()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=self.sources.room.get_current_key(),
|
||||
@@ -100,6 +101,7 @@ class EventSources:
|
||||
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
|
||||
thread_subscriptions_key=thread_subscriptions_key,
|
||||
sticky_events_key=sticky_events_key,
|
||||
quarantined_media_key=quarantined_media_key,
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -128,6 +130,7 @@ class EventSources:
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
|
||||
StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(),
|
||||
StreamKeyType.QUARANTINED_MEDIA: self.store.get_quarantined_media_stream_id_generator(),
|
||||
}
|
||||
|
||||
for _, key in StreamKeyType.__members__.items():
|
||||
|
||||
@@ -1060,6 +1060,7 @@ class StreamKeyType(Enum):
|
||||
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
|
||||
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
|
||||
STICKY_EVENTS = "sticky_events_key"
|
||||
QUARANTINED_MEDIA = "quarantined_media_key"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
@@ -1067,7 +1068,7 @@ class StreamToken:
|
||||
"""A collection of keys joined together by underscores in the following
|
||||
order and which represent the position in their respective streams.
|
||||
|
||||
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242`
|
||||
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242_4141_4343`
|
||||
1. `room_key`: `s2633508` which is a `RoomStreamToken`
|
||||
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
|
||||
- See the docstring for `RoomStreamToken` for more details.
|
||||
@@ -1082,12 +1083,13 @@ class StreamToken:
|
||||
10. `un_partial_stated_rooms_key`: `379`
|
||||
11. `thread_subscriptions_key`: 4242
|
||||
12. `sticky_events_key`: 4141
|
||||
13. `quarantined_media_key`: 4343
|
||||
|
||||
You can see how many of these keys correspond to the various
|
||||
fields in a "/sync" response:
|
||||
```json
|
||||
{
|
||||
"next_batch": "s12_4_0_1_1_1_1_4_1_1",
|
||||
"next_batch": "s12_4_0_1_1_1_1_4_1_1_1_1_1",
|
||||
"presence": {
|
||||
"events": []
|
||||
},
|
||||
@@ -1099,7 +1101,7 @@ class StreamToken:
|
||||
"!QrZlfIDQLNLdZHqTnt:hs1": {
|
||||
"timeline": {
|
||||
"events": [],
|
||||
"prev_batch": "s10_4_0_1_1_1_1_4_1_1",
|
||||
"prev_batch": "s10_4_0_1_1_1_1_4_1_1_1_1_1",
|
||||
"limited": false
|
||||
},
|
||||
"state": {
|
||||
@@ -1142,6 +1144,9 @@ class StreamToken:
|
||||
un_partial_stated_rooms_key: int
|
||||
thread_subscriptions_key: int
|
||||
sticky_events_key: int
|
||||
quarantined_media_key: MultiWriterStreamToken = attr.ib(
|
||||
validator=attr.validators.instance_of(MultiWriterStreamToken)
|
||||
)
|
||||
|
||||
_SEPARATOR = "_"
|
||||
START: ClassVar["StreamToken"]
|
||||
@@ -1171,6 +1176,7 @@ class StreamToken:
|
||||
un_partial_stated_rooms_key,
|
||||
thread_subscriptions_key,
|
||||
sticky_events_key,
|
||||
quarantined_media_key,
|
||||
) = keys
|
||||
|
||||
return cls(
|
||||
@@ -1188,6 +1194,9 @@ class StreamToken:
|
||||
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
|
||||
thread_subscriptions_key=int(thread_subscriptions_key),
|
||||
sticky_events_key=int(sticky_events_key),
|
||||
quarantined_media_key=await MultiWriterStreamToken.parse(
|
||||
store, quarantined_media_key
|
||||
),
|
||||
)
|
||||
except CancelledError:
|
||||
raise
|
||||
@@ -1212,6 +1221,7 @@ class StreamToken:
|
||||
str(self.un_partial_stated_rooms_key),
|
||||
str(self.thread_subscriptions_key),
|
||||
str(self.sticky_events_key),
|
||||
await self.quarantined_media_key.to_string(store),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1241,6 +1251,12 @@ class StreamToken:
|
||||
self.device_list_key.copy_and_advance(new_value),
|
||||
)
|
||||
return new_token
|
||||
elif key == StreamKeyType.QUARANTINED_MEDIA:
|
||||
new_token = self.copy_and_replace(
|
||||
StreamKeyType.QUARANTINED_MEDIA,
|
||||
self.quarantined_media_key.copy_and_advance(new_value),
|
||||
)
|
||||
return new_token
|
||||
|
||||
new_token = self.copy_and_replace(key, new_value)
|
||||
new_id = new_token.get_field(key)
|
||||
@@ -1263,6 +1279,7 @@ class StreamToken:
|
||||
key: Literal[
|
||||
StreamKeyType.RECEIPT,
|
||||
StreamKeyType.DEVICE_LIST,
|
||||
StreamKeyType.QUARANTINED_MEDIA,
|
||||
],
|
||||
) -> MultiWriterStreamToken: ...
|
||||
|
||||
@@ -1334,7 +1351,8 @@ class StreamToken:
|
||||
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
|
||||
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
|
||||
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
|
||||
f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})"
|
||||
f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key}"
|
||||
f"quarantined_media: {self.quarantined_media_key})"
|
||||
)
|
||||
|
||||
|
||||
@@ -1351,6 +1369,7 @@ StreamToken.START = StreamToken(
|
||||
un_partial_stated_rooms_key=0,
|
||||
thread_subscriptions_key=0,
|
||||
sticky_events_key=0,
|
||||
quarantined_media_key=MultiWriterStreamToken(stream=0),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2549,7 +2549,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_topo_token_is_accepted(self) -> None:
|
||||
"""Test Topo Token is accepted."""
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
|
||||
@@ -2563,7 +2563,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||
"""Test that stream token is accepted for forward pagination."""
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
|
||||
|
||||
@@ -2248,7 +2248,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_topo_token_is_accepted(self) -> None:
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
)
|
||||
@@ -2259,7 +2259,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user