diff --git a/changelog.d/19764.misc b/changelog.d/19764.misc new file mode 100644 index 0000000000..8704e3eed6 --- /dev/null +++ b/changelog.d/19764.misc @@ -0,0 +1 @@ +Replace unique `quarantined_media` waiting patterns with standard `wait_for_stream_token(...)`. diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 1633cca884..35454c1522 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -43,7 +43,13 @@ from synapse.rest.admin._base import ( from synapse.storage.databases.main.media_repository import ( MediaSortOrder, ) -from synapse.types import JsonDict, UserID +from synapse.types import ( + JsonDict, + MultiWriterStreamToken, + StreamKeyType, + StreamToken, + UserID, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -243,6 +249,7 @@ class ListQuarantineChanges(RestServlet): self.auth = hs.get_auth() self.server_name = hs.hostname self.replication = hs.get_replication_data_handler() + self.notifier = hs.get_notifier() async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -256,8 +263,7 @@ class ListQuarantineChanges(RestServlet): # The caller is trying to get future data, which we don't allow because # we know it's an invalid state that should never happen. We could # wait until we reach the token but we might as well not waste our - # resources on that which is why `wait_for_quarantined_media_stream_id(...)` - # has assertions around this. + # resources on that. raise SynapseError( HTTPStatus.BAD_REQUEST, "The `from` token is considered invalid because it includes stream positions " @@ -268,9 +274,16 @@ class ListQuarantineChanges(RestServlet): errcode=Codes.INVALID_PARAM, ) + # Create a `StreamToken` that's compatible with `wait_for_stream_token`. 
+ # + # FIXME: Ideally, this endpoint would use a `StreamToken` to begin with + from_token = StreamToken.START.copy_and_replace( + StreamKeyType.QUARANTINED_MEDIA, MultiWriterStreamToken(stream=from_id) + ) + # We need to wait to ensure that our current worker is actually caught up with # the stream position, otherwise we might not return what we think we're returning. - if not await self.store.wait_for_quarantined_media_stream_id(from_id): + if not await self.notifier.wait_for_stream_token(from_token): raise SynapseError( HTTPStatus.INTERNAL_SERVER_ERROR, "Timed out while waiting for the worker serving this request to catch up to the given " @@ -280,7 +293,7 @@ class ListQuarantineChanges(RestServlet): errcode=Codes.UNKNOWN, ) - to_id = await self.store.get_current_quarantined_media_stream_id() + to_id = self.store.get_current_quarantined_media_stream_id() changes = await self.store.get_quarantined_media_changes( from_id=from_id, to_id=to_id, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a0c42082f0..95aa2cb7dc 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -62,12 +62,12 @@ from synapse.storage.types import Cursor from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator from synapse.types import ( JsonDict, + MultiWriterStreamToken, RetentionPolicy, StrCollection, ThirdPartyInstanceID, ) from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.duration import Duration from synapse.util.json import json_encoder from synapse.util.stringutils import MXC_REGEX @@ -1302,7 +1302,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return local_media_ids - async def get_current_quarantined_media_stream_id(self) -> int: + def get_quarantined_media_stream_token(self) -> MultiWriterStreamToken: + return MultiWriterStreamToken.from_generator( + self._quarantined_media_changes_id_gen + ) + + def 
get_quarantined_media_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._quarantined_media_changes_id_gen + + def get_current_quarantined_media_stream_id(self) -> int: """Gets the position of the quarantined media changes stream. Returns: @@ -1318,74 +1326,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ return await self._quarantined_media_changes_id_gen.get_max_allocated_token() - async def wait_for_quarantined_media_stream_id(self, target_id: int) -> bool: - """Waits until the quarantined media changes stream reaches the given stream ID. - - See https://github.com/element-hq/synapse/pull/19644 for more details. - - TODO: Replace function and call sites with https://github.com/element-hq/synapse/pull/19644 - - Args: - target_id: The stream ID to wait for. - - Returns: - True when caught up to the target stream ID. - False when timing out while waiting. - """ - # We ideally would use something like `wait_for_stream_position` in the meantime, - # but that short circuits if the instance name matches the current instance name. - # Doing so means that if *another* writer is actually leading the to_id, then we'll - # assume that we're caught up when we aren't. - # - # NOTE: Because this is implemented to wait for stream positions by integer ID, - # we're technically waiting for *all* workers to catch up rather than just waiting - # for *our* worker to catch up. This is okay for now because the quarantined media - # stream should be pretty fast to update, and if it's not then the only thing we're - # affecting is an admin API that probably has a tool automatically retrying requests - # anyway. https://github.com/element-hq/synapse/pull/19644 does the waiting properly - # so this should be replaced by that (or similar). - - # Get the minimum shared position/ID across all workers - current_id = self._quarantined_media_changes_id_gen.get_current_token() - if current_id >= target_id: - return True # nothing to wait for: we're already caught up. 
- - # "This should never happen". Tokens we hand out via the API should exist. If they - # don't, then we're in a bad state and need to explode. - max_persisted_position = ( - await self._quarantined_media_changes_id_gen.get_max_allocated_token() - ) - assert max_persisted_position >= target_id, ( - f"Unable to wait for invalid future token (token={target_id} has positions " - f"ahead of our max persisted position={max_persisted_position})" - ) - - # Start waiting until we've caught up to the `stream_token` - start = self.clock.time_msec() - logged = False - while True: - # Like above, get the minimum shared ID across all workers - current_id = self._quarantined_media_changes_id_gen.get_current_token() - if current_id >= target_id: - return True - - now = self.clock.time_msec() - - # Timed out - if now - start > 10_000: - return False - - if not logged: - logger.info( - "Waiting for current token to reach %s; currently at %s", - target_id, - current_id, - ) - logged = True - - # TODO: be better - await self.clock.sleep(Duration(milliseconds=500)) - async def get_quarantined_media_changes( self, *, from_id: int, to_id: int, limit: int ) -> list[QuarantinedMediaUpdate]: diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f5677a2082..36490fcb35 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -85,6 +85,7 @@ class EventSources: ) thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() sticky_events_key = self.store.get_max_sticky_events_stream_id() + quarantined_media_key = self.store.get_quarantined_media_stream_token() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -100,6 +101,7 @@ class EventSources: un_partial_stated_rooms_key=un_partial_stated_rooms_key, thread_subscriptions_key=thread_subscriptions_key, sticky_events_key=sticky_events_key, + quarantined_media_key=quarantined_media_key, ) return token @@ -128,6 +130,7 @@ class EventSources: 
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(), + StreamKeyType.QUARANTINED_MEDIA: self.store.get_quarantined_media_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 8b005ef84d..8537a63bde 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1060,6 +1060,7 @@ class StreamKeyType(Enum): UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" STICKY_EVENTS = "sticky_events_key" + QUARANTINED_MEDIA = "quarantined_media_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1067,7 +1068,7 @@ class StreamToken: """A collection of keys joined together by underscores in the following order and which represent the position in their respective streams. - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242` + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242_4141_4343` 1. `room_key`: `s2633508` which is a `RoomStreamToken` - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - See the docstring for `RoomStreamToken` for more details. @@ -1082,12 +1083,13 @@ class StreamToken: 10. `un_partial_stated_rooms_key`: `379` 11. `thread_subscriptions_key`: 4242 12. `sticky_events_key`: 4141 + 13. 
`quarantined_media_key`: 4343 You can see how many of these keys correspond to the various fields in a "/sync" response: ```json { - "next_batch": "s12_4_0_1_1_1_1_4_1_1", + "next_batch": "s12_4_0_1_1_1_1_4_1_1_1_1_1", "presence": { "events": [] }, @@ -1099,7 +1101,7 @@ class StreamToken: "!QrZlfIDQLNLdZHqTnt:hs1": { "timeline": { "events": [], - "prev_batch": "s10_4_0_1_1_1_1_4_1_1", + "prev_batch": "s10_4_0_1_1_1_1_4_1_1_1_1_1", "limited": false }, "state": { @@ -1142,6 +1144,9 @@ class StreamToken: un_partial_stated_rooms_key: int thread_subscriptions_key: int sticky_events_key: int + quarantined_media_key: MultiWriterStreamToken = attr.ib( + validator=attr.validators.instance_of(MultiWriterStreamToken) + ) _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1171,6 +1176,7 @@ class StreamToken: un_partial_stated_rooms_key, thread_subscriptions_key, sticky_events_key, + quarantined_media_key, ) = keys return cls( @@ -1188,6 +1194,9 @@ class StreamToken: un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), thread_subscriptions_key=int(thread_subscriptions_key), sticky_events_key=int(sticky_events_key), + quarantined_media_key=await MultiWriterStreamToken.parse( + store, quarantined_media_key + ), ) except CancelledError: raise @@ -1212,6 +1221,7 @@ class StreamToken: str(self.un_partial_stated_rooms_key), str(self.thread_subscriptions_key), str(self.sticky_events_key), + await self.quarantined_media_key.to_string(store), ] ) @@ -1241,6 +1251,12 @@ class StreamToken: self.device_list_key.copy_and_advance(new_value), ) return new_token + elif key == StreamKeyType.QUARANTINED_MEDIA: + new_token = self.copy_and_replace( + StreamKeyType.QUARANTINED_MEDIA, + self.quarantined_media_key.copy_and_advance(new_value), + ) + return new_token new_token = self.copy_and_replace(key, new_value) new_id = new_token.get_field(key) @@ -1263,6 +1279,7 @@ class StreamToken: key: Literal[ StreamKeyType.RECEIPT, StreamKeyType.DEVICE_LIST, + StreamKeyType.QUARANTINED_MEDIA, ], ) 
-> MultiWriterStreamToken: ... @@ -1334,7 +1351,8 @@ class StreamToken: f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," - f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})" + f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key}, " + f"quarantined_media: {self.quarantined_media_key})" ) @@ -1351,6 +1369,7 @@ StreamToken.START = StreamToken( un_partial_stated_rooms_key=0, thread_subscriptions_key=0, sticky_events_key=0, + quarantined_media_key=MultiWriterStreamToken(stream=0), ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 507cf10c5d..c4e4170c6f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2549,7 +2549,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -2563,7 +2563,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 28872fa06c..10325c536a 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2248,7 +2248,7 @@ class RoomMessageListTestCase(RoomBase): self.room_id = 
self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2259,7 +2259,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) )