diff --git a/synapse/notifier.py b/synapse/notifier.py index d266684a74..6fe66797be 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -51,7 +51,6 @@ from synapse.storage.util.id_generators import ( ) from synapse.streams.config import PaginationConfig from synapse.types import ( - AbstractMultiWriterStreamToken, ISynapseReactor, JsonDict, MultiWriterStreamToken, @@ -896,7 +895,7 @@ class Notifier: async def wait_for_multi_writer_stream_token( self, - token: AbstractMultiWriterStreamToken, + token: MultiWriterStreamToken, id_gen: MultiWriterIdGenerator, ) -> bool: """ @@ -912,7 +911,7 @@ class Notifier: True when this worker has caught up False when we timed out waiting """ - current_token = AbstractMultiWriterStreamToken.from_generator(id_gen) + current_token = MultiWriterStreamToken.from_generator(id_gen) # Return early if we are already caught up if token.is_before_or_eq(current_token): return True @@ -936,7 +935,7 @@ class Notifier: start = self.clock.time_msec() logged = False while True: - current_token = AbstractMultiWriterStreamToken.from_generator(id_gen) + current_token = MultiWriterStreamToken.from_generator(id_gen) if token.is_before_or_eq(current_token): return True diff --git a/tests/test_notifier.py b/tests/test_notifier.py new file mode 100644 index 0000000000..42ff76eab9 --- /dev/null +++ b/tests/test_notifier.py @@ -0,0 +1,163 @@ +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +import logging + +from twisted.internet import defer +from twisted.internet.testing import MemoryReactor + +from synapse.server import HomeServer +from synapse.types import MultiWriterStreamToken +from synapse.util.clock import Clock +from synapse.util.duration import Duration + +import tests.unittest + +logger = logging.getLogger(__name__) + + +class NotifierTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = self.hs.get_datastores().main + self.notifier = self.hs.get_notifier() + + def test_wait_for_multi_writer_stream_token_with_caught_up_token(self) -> None: + """ + Test `wait_for_stream_token` when we receive a token that we are caught up to. + """ + # Create a token + presence_id_gen = self.store.get_presence_stream_id_gen() + token = MultiWriterStreamToken.from_generator(presence_id_gen) + + # Function under test + wait_d = defer.ensureDeferred( + self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen) + ) + + # Done waiting and caught-up (True) + wait_result = self.get_success(wait_d) + self.assertEqual(wait_result, True) + + def test_wait_for_multi_writer_stream_token_with_future_sync_token(self) -> None: + """ + Test `wait_for_stream_token` when we receive a token that is ahead of our + current token, we'll wait until the stream position advances. + + This can happen if replication streams start lagging, and the client's + previous sync request was serviced by a worker ahead of ours. + """ + # We simulate a lagging stream by getting a stream ID from the ID gen + # and then waiting to mark it as "persisted". + presence_id_gen = self.store.get_presence_stream_id_gen() + ctx_mgr = presence_id_gen.get_next() + stream_id = self.get_success(ctx_mgr.__aenter__()) + + # Create the new token based on the stream ID above. + current_token = MultiWriterStreamToken.from_generator(presence_id_gen) + token = current_token.copy_and_advance(MultiWriterStreamToken(stream=stream_id)) + + # Function under test + wait_d = defer.ensureDeferred( + self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen) + ) + + # This should block waiting for the stream to update + # + # Advance time a little bit to make the + # `wait_for_multi_writer_stream_token(...)` sleep loop iterate. + self.reactor.advance(Duration(seconds=2).as_secs()) + # It should still not be done yet + self.assertFalse(wait_d.called) + + # Marking the stream ID as persisted should unblock the request. + self.get_success(ctx_mgr.__aexit__(None, None, None)) + + # Advance time to make another iteration of + # `wait_for_multi_writer_stream_token(...)` sleep loop so it sees that we're + # finally caught up now. + self.reactor.advance(Duration(seconds=1).as_secs()) + + # Done waiting and caught-up (True) + wait_result = self.get_success(wait_d) + self.assertEqual(wait_result, True) + + def test_wait_for_multi_writer_stream_token_with_future_sync_token_timeout( + self, + ) -> None: + """ + Test `wait_for_stream_token` when we receive a token that is ahead of our + current token, we'll wait until the stream position advances *until* we hit the + timeout. + + This can happen if replication streams start lagging, and the client's + previous sync request was serviced by a worker ahead of ours. + """ + # We simulate a lagging stream by getting a stream ID from the ID gen + # and then waiting to mark it as "persisted". + presence_id_gen = self.store.get_presence_stream_id_gen() + ctx_mgr = presence_id_gen.get_next() + stream_id = self.get_success(ctx_mgr.__aenter__()) + + # Create the new token based on the stream ID above. + current_token = MultiWriterStreamToken.from_generator(presence_id_gen) + token = current_token.copy_and_advance(MultiWriterStreamToken(stream=stream_id)) + + # Function under test + wait_d = defer.ensureDeferred( + self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen) + ) + # Advance time a little bit to make the + # `wait_for_multi_writer_stream_token(...)` sleep loop record 0 as the `start` time. + self.reactor.advance(Duration(seconds=0).as_secs()) + + # This should block waiting for the stream to update + # + # Advance time a little bit to make the + # `wait_for_multi_writer_stream_token(...)` sleep loop iterate. + self.reactor.advance(Duration(seconds=5).as_secs()) + # It should still not be done yet (not enough time to hit the timeout) + self.assertFalse(wait_d.called) + # Advance time past the 10 second timeout (5 + 6 = 11 seconds) to make the + # `wait_for_multi_writer_stream_token(...)` sleep loop give up. + self.reactor.advance(Duration(seconds=6).as_secs()) + + # Make sure we gave up waiting and not caught-up (False) + wait_result = self.get_success(wait_d) + self.assertEqual(wait_result, False) + + def test_wait_for_multi_writer_stream_token_with_invalid_future_sync_token( + self, + ) -> None: + """Like the previous test, except we give a token that has a stream + position ahead of what is in the DB, i.e. its invalid and we shouldn't + wait for the stream to advance (as it may never do so). + + This can happen due to older versions of Synapse giving out stream + positions without persisting them in the DB, and so on restart the + stream would get reset back to an older position. + """ + presence_id_gen = self.store.get_presence_stream_id_gen() + + # Create a token and advance one of the streams. + current_token = MultiWriterStreamToken.from_generator(presence_id_gen) + token = current_token.copy_and_advance( + MultiWriterStreamToken(stream=current_token.get_max_stream_pos() + 1) + ) + + # Function under test + wait_d = defer.ensureDeferred( + self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen) + ) + + # Expect to fail. We expect callers to sanitize/validate the tokens they give to + # `wait_for_multi_writer_stream_token` to ensure they aren't in the future. + self.get_failure(wait_d, exc=AssertionError)