Remove wait_for_multi_writer_stream_token

See https://github.com/element-hq/synapse/pull/19644#discussion_r3029861226
This commit is contained in:
Eric Eastwood
2026-05-07 14:35:48 -05:00
parent be0599425c
commit 83d6bdbd77
4 changed files with 50 additions and 109 deletions
+1 -1
View File
@@ -824,7 +824,7 @@ mod tests {
// Serialize the keys in the reverse order.
for (c, _) in ascii_order.iter().rev() {
map_serializer.serialize_entry(c.into(), &1).unwrap();
map_serializer.serialize_entry(c, &1).unwrap();
}
SerializeMap::end(map_serializer).unwrap();
-67
View File
@@ -47,9 +47,6 @@ from synapse.logging import issue9533_logger
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import log_kv, start_active_span
from synapse.metrics import SERVER_NAME_LABEL, LaterGauge
from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
)
from synapse.streams.config import PaginationConfig
from synapse.types import (
ISynapseReactor,
@@ -895,70 +892,6 @@ class Notifier:
# TODO: be better
await self.clock.sleep(Duration(milliseconds=500))
async def wait_for_multi_writer_stream_token(
self,
token: MultiWriterStreamToken,
id_gen: MultiWriterIdGenerator,
) -> bool:
"""
Wait for this worker to catch up with the given stream token.
This is important to ensure that the worker has a proper view of the world
before trying to serve a request. For example, one worker can return a response
with some `next_batch` token, but then the next request goes to another worker
which is behind; if the worker assembles a response up to the token, it could be
missing data in the gap between where it's behind and the requested token.
Returns:
True when this worker has caught up
False when we timed out waiting
"""
current_token = MultiWriterStreamToken.from_generator(id_gen)
# Return early if we are already caught up
if token.is_before_or_eq(current_token):
return True
# Assert as we consider this a Synapse programming error. We shouldn't be
# handing out invalid future tokens and tokens should be validated before it
# reaches this point.
#
# We consider a token invalid, if the token has positions ahead of our persisted
# positions in the database
#
# Previously, we would bound the tokens within this function but that leads to
# bad patterns upstream where people can continue to use the unbounded token.
max_persisted_position = await id_gen.get_max_allocated_token()
assert max_persisted_position >= token.get_max_stream_pos(), (
f"Refusing to wait for invalid future token (token={token} "
"that has positions ahead of our max persisted position {max_persisted_position}) "
"(Synapse programming error)"
)
# Start waiting until we've caught up to the `stream_token`
start = self.clock.time_msec()
logged = False
while True:
current_token = MultiWriterStreamToken.from_generator(id_gen)
if token.is_before_or_eq(current_token):
return True
now = self.clock.time_msec()
# Timed out
if now - start > 10_000:
return False
if not logged:
logger.info(
"Waiting for current token to reach %s; currently at %s",
token,
current_token,
)
logged = True
# TODO: be better
await self.clock.sleep(Duration(milliseconds=500))
async def _get_room_ids(
self, user: UserID, explicit_room_id: str | None
) -> tuple[StrCollection, bool]:
+7
View File
@@ -1312,6 +1312,13 @@ class StreamToken:
self_value = self.get_field(key)
other_value = other_token.get_field(key)
logger.info(
"asdf key=%s self_value=%s, other_value=%s",
key,
self_value,
other_value,
)
if isinstance(self_value, RoomStreamToken):
assert isinstance(other_value, RoomStreamToken)
if not self_value.is_before_or_eq(other_value):
+42 -41
View File
@@ -16,7 +16,7 @@ from twisted.internet import defer
from twisted.internet.testing import MemoryReactor
from synapse.server import HomeServer
from synapse.types import MultiWriterStreamToken
from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.util.clock import Clock
from synapse.util.duration import Duration
@@ -30,24 +30,23 @@ class NotifierTestCase(tests.unittest.HomeserverTestCase):
self.store = self.hs.get_datastores().main
self.notifier = self.hs.get_notifier()
def test_wait_for_multi_writer_stream_token_with_caught_up_token(self) -> None:
def test_wait_for_stream_token_with_caught_up_token(self) -> None:
"""
Test `wait_for_stream_token` when we receive a token that we are caught up to.
"""
# Create a token
presence_id_gen = self.store.get_presence_stream_id_gen()
token = MultiWriterStreamToken.from_generator(presence_id_gen)
receipt_id_gen = self.store.get_receipts_stream_id_gen()
receipt_token = MultiWriterStreamToken.from_generator(receipt_id_gen)
token = StreamToken.START.copy_and_replace(StreamKeyType.RECEIPT, receipt_token)
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
wait_d = defer.ensureDeferred(self.notifier.wait_for_stream_token(token))
# Done waiting and caught-up (True)
wait_result = self.get_success(wait_d)
self.assertEqual(wait_result, True)
def test_wait_for_multi_writer_stream_token_with_future_sync_token(self) -> None:
def test_wait_for_stream_token_with_future_sync_token(self) -> None:
"""
Test `wait_for_stream_token` when we receive a token that is ahead of our
current token, we'll wait until the stream position advances.
@@ -57,23 +56,24 @@ class NotifierTestCase(tests.unittest.HomeserverTestCase):
"""
# We simulate a lagging stream by getting a stream ID from the ID gen
# and then waiting to mark it as "persisted".
presence_id_gen = self.store.get_presence_stream_id_gen()
ctx_mgr = presence_id_gen.get_next()
stream_id = self.get_success(ctx_mgr.__aenter__())
receipt_id_gen = self.store.get_receipts_stream_id_gen()
ctx_mgr = receipt_id_gen.get_next()
receipt_stream_id = self.get_success(ctx_mgr.__aenter__())
# Create the new token based on the stream ID above.
current_token = MultiWriterStreamToken.from_generator(presence_id_gen)
token = current_token.copy_and_advance(MultiWriterStreamToken(stream=stream_id))
current_receipt_token = MultiWriterStreamToken.from_generator(receipt_id_gen)
receipt_token = current_receipt_token.copy_and_advance(
MultiWriterStreamToken(stream=receipt_stream_id)
)
token = StreamToken.START.copy_and_advance(StreamKeyType.RECEIPT, receipt_token)
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
wait_d = defer.ensureDeferred(self.notifier.wait_for_stream_token(token))
# This should block waiting for the stream to update
#
# Advance time a little bit to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop iterate.
# `wait_for_stream_token(...)` sleep loop iterate.
self.reactor.advance(Duration(seconds=2).as_secs())
# It should still not be done yet
self.assertFalse(wait_d.called)
@@ -82,7 +82,7 @@ class NotifierTestCase(tests.unittest.HomeserverTestCase):
self.get_success(ctx_mgr.__aexit__(None, None, None))
# Advance time to make another iteration of
# `wait_for_multi_writer_stream_token(...)` sleep loop so it sees that we're
# `wait_for_stream_token(...)` sleep loop so it sees that we're
# finally caught up now.
self.reactor.advance(Duration(seconds=1).as_secs())
@@ -90,7 +90,7 @@ class NotifierTestCase(tests.unittest.HomeserverTestCase):
wait_result = self.get_success(wait_d)
self.assertEqual(wait_result, True)
def test_wait_for_multi_writer_stream_token_with_future_sync_token_timeout(
def test_wait_for_stream_token_with_future_sync_token_timeout(
self,
) -> None:
"""
@@ -103,42 +103,43 @@ class NotifierTestCase(tests.unittest.HomeserverTestCase):
"""
# We simulate a lagging stream by getting a stream ID from the ID gen
# and then waiting to mark it as "persisted".
presence_id_gen = self.store.get_presence_stream_id_gen()
ctx_mgr = presence_id_gen.get_next()
stream_id = self.get_success(ctx_mgr.__aenter__())
receipt_id_gen = self.store.get_receipts_stream_id_gen()
ctx_mgr = receipt_id_gen.get_next()
receipt_stream_id = self.get_success(ctx_mgr.__aenter__())
# Create the new token based on the stream ID above.
current_token = MultiWriterStreamToken.from_generator(presence_id_gen)
token = current_token.copy_and_advance(MultiWriterStreamToken(stream=stream_id))
current_receipt_token = MultiWriterStreamToken.from_generator(receipt_id_gen)
receipt_token = current_receipt_token.copy_and_advance(
MultiWriterStreamToken(stream=receipt_stream_id)
)
token = StreamToken.START.copy_and_advance(StreamKeyType.RECEIPT, receipt_token)
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
wait_d = defer.ensureDeferred(self.notifier.wait_for_stream_token(token))
# Advance time a little bit to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop record 0 as the `start` time.
# `wait_for_stream_token(...)` sleep loop record 0 as the `start` time.
self.reactor.advance(Duration(seconds=0).as_secs())
# This should block waiting for the stream to update
#
# Advance time a little bit to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop iterate.
# `wait_for_stream_token(...)` sleep loop iterate.
self.reactor.advance(Duration(seconds=5).as_secs())
# It should still not be done yet (not enough time to hit the timeout)
self.assertFalse(wait_d.called)
# Advance time past the 10 second timeout (5 + 6 = 11 seconds) to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop give up.
# `wait_for_stream_token(...)` sleep loop give up.
self.reactor.advance(Duration(seconds=6).as_secs())
# Make sure we gave up waiting and not caught-up (False)
wait_result = self.get_success(wait_d)
self.assertEqual(wait_result, False)
def test_wait_for_multi_writer_stream_token_with_invalid_future_sync_token(
def test_wait_for_stream_token_with_invalid_future_sync_token(
self,
) -> None:
"""
Like `test_wait_for_multi_writer_stream_token_with_future_sync_token`, except we
Like `test_wait_for_stream_token_with_future_sync_token`, except we
give a token that has a stream position ahead of what is in the DB, i.e. its
invalid and we shouldn't wait for the stream to advance (as it may never do so).
@@ -146,19 +147,19 @@ class NotifierTestCase(tests.unittest.HomeserverTestCase):
positions without persisting them in the DB, and so on restart the
stream would get reset back to an older position.
"""
presence_id_gen = self.store.get_presence_stream_id_gen()
# Create a token and advance one of the streams.
current_token = MultiWriterStreamToken.from_generator(presence_id_gen)
token = current_token.copy_and_advance(
MultiWriterStreamToken(stream=current_token.get_max_stream_pos() + 1)
receipt_id_gen = self.store.get_receipts_stream_id_gen()
current_receipt_token = MultiWriterStreamToken.from_generator(receipt_id_gen)
receipt_token = current_receipt_token.copy_and_advance(
MultiWriterStreamToken(
stream=current_receipt_token.get_max_stream_pos() + 1
)
)
token = StreamToken.START.copy_and_advance(StreamKeyType.RECEIPT, receipt_token)
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
wait_d = defer.ensureDeferred(self.notifier.wait_for_stream_token(token))
# Expect to fail. We expect callers to sanitize/validate the tokens they give to
# `wait_for_multi_writer_stream_token` to ensure they aren't in the future.
# `wait_for_stream_token` to ensure they aren't in the future.
self.get_failure(wait_d, exc=AssertionError)