Add tests

This commit is contained in:
Eric Eastwood
2026-04-02 22:22:42 -05:00
parent ff5402f414
commit 154b41d0fe
2 changed files with 166 additions and 4 deletions
+3 -4
View File
@@ -51,7 +51,6 @@ from synapse.storage.util.id_generators import (
)
from synapse.streams.config import PaginationConfig
from synapse.types import (
AbstractMultiWriterStreamToken,
ISynapseReactor,
JsonDict,
MultiWriterStreamToken,
@@ -896,7 +895,7 @@ class Notifier:
async def wait_for_multi_writer_stream_token(
self,
token: AbstractMultiWriterStreamToken,
token: MultiWriterStreamToken,
id_gen: MultiWriterIdGenerator,
) -> bool:
"""
@@ -912,7 +911,7 @@ class Notifier:
True when this worker has caught up
False when we timed out waiting
"""
current_token = AbstractMultiWriterStreamToken.from_generator(id_gen)
current_token = MultiWriterStreamToken.from_generator(id_gen)
# Return early if we are already caught up
if token.is_before_or_eq(current_token):
return True
@@ -936,7 +935,7 @@ class Notifier:
start = self.clock.time_msec()
logged = False
while True:
current_token = AbstractMultiWriterStreamToken.from_generator(id_gen)
current_token = MultiWriterStreamToken.from_generator(id_gen)
if token.is_before_or_eq(current_token):
return True
+163
View File
@@ -0,0 +1,163 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2026 Element Creations Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
import logging
from twisted.internet import defer
from twisted.internet.testing import MemoryReactor
from synapse.server import HomeServer
from synapse.types import MultiWriterStreamToken
from synapse.util.clock import Clock
from synapse.util.duration import Duration
import tests.unittest
logger = logging.getLogger(__name__)
class NotifierTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.notifier = self.hs.get_notifier()
def test_wait_for_multi_writer_stream_token_with_caught_up_token(self) -> None:
"""
Test `wait_for_stream_token` when we receive a token that we are caught up to.
"""
# Create a token
presence_id_gen = self.store.get_presence_stream_id_gen()
token = MultiWriterStreamToken.from_generator(presence_id_gen)
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
# Done waiting and caught-up (True)
wait_result = self.get_success(wait_d)
self.assertEqual(wait_result, True)
def test_wait_for_multi_writer_stream_token_with_future_sync_token(self) -> None:
"""
Test `wait_for_stream_token` when we receive a token that is ahead of our
current token, we'll wait until the stream position advances.
This can happen if replication streams start lagging, and the client's
previous sync request was serviced by a worker ahead of ours.
"""
# We simulate a lagging stream by getting a stream ID from the ID gen
# and then waiting to mark it as "persisted".
presence_id_gen = self.store.get_presence_stream_id_gen()
ctx_mgr = presence_id_gen.get_next()
stream_id = self.get_success(ctx_mgr.__aenter__())
# Create the new token based on the stream ID above.
current_token = MultiWriterStreamToken.from_generator(presence_id_gen)
token = current_token.copy_and_advance(MultiWriterStreamToken(stream=stream_id))
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
# This should block waiting for the stream to update
#
# Advance time a little bit to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop iterate.
self.reactor.advance(Duration(seconds=2).as_secs())
# It should still not be done yet
self.assertFalse(wait_d.called)
# Marking the stream ID as persisted should unblock the request.
self.get_success(ctx_mgr.__aexit__(None, None, None))
# Advance time to make another iteration of
# `wait_for_multi_writer_stream_token(...)` sleep loop so it sees that we're
# finally caught up now.
self.reactor.advance(Duration(seconds=1).as_secs())
# Done waiting and caught-up (True)
wait_result = self.get_success(wait_d)
self.assertEqual(wait_result, True)
def test_wait_for_multi_writer_stream_token_with_future_sync_token_timeout(
self,
) -> None:
"""
Test `wait_for_stream_token` when we receive a token that is ahead of our
current token, we'll wait until the stream position advances *until* we hit the
timeout.
This can happen if replication streams start lagging, and the client's
previous sync request was serviced by a worker ahead of ours.
"""
# We simulate a lagging stream by getting a stream ID from the ID gen
# and then waiting to mark it as "persisted".
presence_id_gen = self.store.get_presence_stream_id_gen()
ctx_mgr = presence_id_gen.get_next()
stream_id = self.get_success(ctx_mgr.__aenter__())
# Create the new token based on the stream ID above.
current_token = MultiWriterStreamToken.from_generator(presence_id_gen)
token = current_token.copy_and_advance(MultiWriterStreamToken(stream=stream_id))
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
# Advance time a little bit to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop record 0 as the `start` time.
self.reactor.advance(Duration(seconds=0).as_secs())
# This should block waiting for the stream to update
#
# Advance time a little bit to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop iterate.
self.reactor.advance(Duration(seconds=5).as_secs())
# It should still not be done yet (not enough time to hit the timeout)
self.assertFalse(wait_d.called)
# Advance time past the 10 second timeout (5 + 6 = 11 seconds) to make the
# `wait_for_multi_writer_stream_token(...)` sleep loop give up.
self.reactor.advance(Duration(seconds=6).as_secs())
# Make sure we gave up waiting and not caught-up (False)
wait_result = self.get_success(wait_d)
self.assertEqual(wait_result, False)
def test_wait_for_multi_writer_stream_token_with_invalid_future_sync_token(
self,
) -> None:
"""Like the previous test, except we give a token that has a stream
position ahead of what is in the DB, i.e. its invalid and we shouldn't
wait for the stream to advance (as it may never do so).
This can happen due to older versions of Synapse giving out stream
positions without persisting them in the DB, and so on restart the
stream would get reset back to an older position.
"""
presence_id_gen = self.store.get_presence_stream_id_gen()
# Create a token and advance one of the streams.
current_token = MultiWriterStreamToken.from_generator(presence_id_gen)
token = current_token.copy_and_advance(
MultiWriterStreamToken(stream=current_token.get_max_stream_pos() + 1)
)
# Function under test
wait_d = defer.ensureDeferred(
self.notifier.wait_for_multi_writer_stream_token(token, presence_id_gen)
)
# Expect to fail. We expect callers to sanitize/validate the tokens they give to
# `wait_for_multi_writer_stream_token` to ensure they aren't in the future.
self.get_failure(wait_d, exc=AssertionError)