diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 43a1aad309..7975d10ec3 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -101,8 +101,20 @@ except Exception: # a hook which can be set during testing to assert that we aren't abusing logcontexts. +_IN_LOGCONTEXT_ERROR = False + + def logcontext_error(msg: str) -> None: - logger.warning(msg) + # Guard against re-entrancy: logging can trigger context switches, + # which call start()/stop(), which call logcontext_error() again. + global _IN_LOGCONTEXT_ERROR + if _IN_LOGCONTEXT_ERROR: + return + _IN_LOGCONTEXT_ERROR = True + try: + logger.warning(msg) + finally: + _IN_LOGCONTEXT_ERROR = False # get an id for the current thread. diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index bee8efd4ba..d8bbcd7caa 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -305,6 +305,12 @@ class _EventPeristenceQueue(Generic[_PersistResult]): self._event_persist_queues[room_id] = remaining_queue self._currently_persisting_rooms.discard(room_id) + # If new items were added while we were running, kick off + # another processing round. Without this, items enqueued + # while we hold _currently_persisting_rooms would be orphaned. + if room_id in self._event_persist_queues: + self._handle_queue(room_id) + # set handle_queue_loop off in the background self.hs.run_as_background_process("persist_events", handle_queue_loop) diff --git a/synapse/storage/native_database.py b/synapse/storage/native_database.py index a7362e3b21..aff7a9cff8 100644 --- a/synapse/storage/native_database.py +++ b/synapse/storage/native_database.py @@ -84,13 +84,23 @@ class NativeConnectionPool: self._thread_local = threading.local() # For in-memory SQLite or when an initial connection is provided, - # use a shared connection (not thread-local) + # use a shared connection (not thread-local). + # When using a shared connection, limit to 1 worker to avoid + # concurrent access deadlocks on the same SQLite connection. self._shared_conn: Connection | None = initial_connection if db_path == ":memory:" or db_path == "": self._use_shared_conn = True else: self._use_shared_conn = initial_connection is not None + if self._use_shared_conn and max_workers > 1: + # Recreate executor with single worker + self._executor.shutdown(wait=False) + self._executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix=f"synapse-db-{db_config.name}", + ) + self._closed = False def _get_connection(self) -> Connection: diff --git a/tests/__init__.py b/tests/__init__.py index 3d20f1f0bb..fcd2134c89 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,13 +18,3 @@ # [This file includes modifications made by New Vector Limited] # # - -# Set up an asyncio event loop for tests. -import asyncio as _asyncio -import nest_asyncio as _nest_asyncio - -_test_asyncio_loop = _asyncio.new_event_loop() -_asyncio.set_event_loop(_test_asyncio_loop) - -# Allow nested event loop calls (run_until_complete inside run_until_complete) -_nest_asyncio.apply(_test_asyncio_loop) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 017e5e31f3..8d5b5c080e 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -21,18 +21,18 @@ import os.path import subprocess from incremental import Version -try: - from zope.interface import implementer -except ImportError: - pass - -try: - import twisted -except ImportError: - pass from OpenSSL import SSL from OpenSSL.SSL import Connection + +# The TLS test utilities (TestServerTLSConnectionFactory, wrap_server_factory_for_tls, +# get_test_https_policy) require Twisted's TLS infrastructure. They are only used +# by federation/HTTP-level protocol tests, not by REST API tests. +# Guard everything so that importing tests.http doesn't fail without Twisted. +_HAS_TWISTED_TLS = False try: + import twisted + from incremental import Version + from zope.interface import implementer from twisted.internet.address import IPv4Address from twisted.internet.interfaces import ( IOpenSSLServerConnectionCreator, @@ -43,22 +43,20 @@ try: from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 + _HAS_TWISTED_TLS = True except ImportError: pass -def get_test_https_policy() -> BrowserLikePolicyForHTTPS: - """Get a test IPolicyForHTTPS which trusts the test CA cert - - Returns: - IPolicyForHTTPS - """ - ca_file = get_test_ca_cert_file() - with open(ca_file) as stream: - content = stream.read() - cert = Certificate.loadPEM(content) - trust_root = trustRootFromCertificates([cert]) - return BrowserLikePolicyForHTTPS(trustRoot=trust_root) +if _HAS_TWISTED_TLS: + def get_test_https_policy() -> "BrowserLikePolicyForHTTPS": + """Get a test IPolicyForHTTPS which trusts the test CA cert.""" + ca_file = get_test_ca_cert_file() + with open(ca_file) as stream: + content = stream.read() + cert = Certificate.loadPEM(content) + trust_root = trustRootFromCertificates([cert]) + return BrowserLikePolicyForHTTPS(trustRoot=trust_root) def get_test_ca_cert_file() -> str: @@ -154,52 +152,41 @@ def create_test_cert_file(sanlist: list[bytes]) -> str: return cert_filename -@implementer(IOpenSSLServerConnectionCreator) -class TestServerTLSConnectionFactory: - """An SSL connection creator which returns connections which present a certificate - signed by our test CA.""" +if _HAS_TWISTED_TLS: + @implementer(IOpenSSLServerConnectionCreator) + class TestServerTLSConnectionFactory: + """An SSL connection creator which returns connections which present a certificate + signed by our test CA.""" - def __init__(self, sanlist: list[bytes]): - """ - Args: - sanlist: a list of subjectAltName values for the cert - """ - self._cert_file = create_test_cert_file(sanlist) + def __init__(self, sanlist: list[bytes]): + self._cert_file = create_test_cert_file(sanlist) - def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection: - ctx = SSL.Context(SSL.SSLv23_METHOD) - ctx.use_certificate_file(self._cert_file) - ctx.use_privatekey_file(get_test_key_file()) - return Connection(ctx, None) + def serverConnectionForTLS(self, tlsProtocol: "TLSMemoryBIOProtocol") -> Connection: + ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.use_certificate_file(self._cert_file) + ctx.use_privatekey_file(get_test_key_file()) + return Connection(ctx, None) + def wrap_server_factory_for_tls( + factory: "IProtocolFactory", clock: "IReactorTime", sanlist: list[bytes] + ) -> "TLSMemoryBIOFactory": + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory.""" + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + if twisted.version <= Version("Twisted", 23, 8, 0): + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + else: + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory, clock=clock + ) -def wrap_server_factory_for_tls( - factory: IProtocolFactory, clock: IReactorTime, sanlist: list[bytes] -) -> TLSMemoryBIOFactory: - """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory - - The resultant factory will create a TLS server which presents a certificate - signed by our test CA, valid for the domains in `sanlist` - - Args: - factory: protocol factory to wrap - sanlist: list of domains the cert should be valid for - - Returns: - interfaces.IProtocolFactory - """ - connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) - # Twisted > 23.8.0 has a different API that accepts a clock. - if twisted.version <= Version("Twisted", 23, 8, 0): - return TLSMemoryBIOFactory( - connection_creator, isClient=False, wrappedFactory=factory - ) - else: - return TLSMemoryBIOFactory( - connection_creator, isClient=False, wrappedFactory=factory, clock=clock - ) - - -# A dummy address, useful for tests that use FakeTransport and don't care about where -# packets are going to/coming from. -dummy_address = IPv4Address("TCP", "127.0.0.1", 80) + # A dummy address, useful for tests that use FakeTransport + dummy_address = IPv4Address("TCP", "127.0.0.1", 80) +else: + # Dummy address that doesn't require Twisted + class _DummyAddress: + type = "TCP" + host = "127.0.0.1" + port = 80 + dummy_address = _DummyAddress() # type: ignore[assignment] diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 2c8474fbe2..f6e2238992 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -18,6 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # +import asyncio import time import urllib.parse from typing import ( @@ -183,8 +184,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.hs = self.setup_test_homeserver() + async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = await self.setup_test_homeserver() self.hs.config.registration.enable_registration = True self.hs.config.registration.registrations_require_3pid = [] self.hs.config.registration.auto_join_rooms = [] @@ -340,11 +341,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 403, msg=channel.result) @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self) -> None: - self.register_user("kermit", "monkey") + async def test_soft_logout(self) -> None: + await self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token - channel = self.make_request(b"GET", TEST_URL) + channel = await self.make_request(b"GET", TEST_URL) self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") @@ -354,21 +355,22 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - channel = self.make_request(b"POST", LOGIN_URL, params) + channel = await self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 200, channel.result) - # time passes - self.reactor.advance(24 * 3600) + # time passes — advance slightly past the session lifetime + # so the token is expired (valid_until_ms < clock.time_msec()) + self.reactor.advance(24 * 3600 + 1) # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -378,28 +380,28 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # # we now log in as a different device - access_token_2 = self.login("kermit", "monkey") + access_token_2 = await self.login("kermit", "monkey") # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) # ... but if we delete that device, it will be a proper logout - self._delete_device(access_token_2, "kermit", "monkey", device_id) + await self._delete_device(access_token_2, "kermit", "monkey", device_id) - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + channel = await self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], False) - def _delete_device( + async def _delete_device( self, access_token: str, user_id: str, password: str, device_id: str ) -> None: """Perform the UI-Auth to delete a device""" - channel = self.make_request( + channel = await self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) self.assertEqual(channel.code, 401, channel.result) @@ -419,7 +421,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "session": channel.json_body["session"], } - channel = self.make_request( + channel = await self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token, diff --git a/tests/server.py b/tests/server.py index e44564c3db..ce5d5cadf6 100644 --- a/tests/server.py +++ b/tests/server.py @@ -338,36 +338,29 @@ class FakeChannel: def transport(self) -> "FakeChannel": return self - def await_result(self, timeout_ms: int = 1000) -> None: + def await_result(self, timeout_ms: int = 5000) -> None: """ - Wait until the request is finished. + Wait until the request is finished by driving the asyncio event loop. """ import asyncio import time as _time - - deadline = _time.monotonic() + timeout_ms / 1000.0 - self._reactor.run() + from tests import pump_loop loop = asyncio.get_event_loop() + deadline = _time.monotonic() + timeout_ms / 1000.0 while not self.is_finished(): if _time.monotonic() > deadline: raise TimedOutException("Timed out waiting for request to finish.") - # Advance fake time by a small amount per iteration. - # This fires pending sleeps (e.g., ratelimit pauses) while - # keeping time advancement predictable. The old Twisted - # MemoryReactorClock.advance(0.1) did the same. + # Advance fake time OUTSIDE the event loop drive self._reactor.advance(0.1) - # Drive asyncio event loop for DB operations, task completions, etc. - if not loop.is_closed(): - async def _drain() -> None: - """Run multiple event loop ticks to drain pending work.""" - for _ in range(20): - await asyncio.sleep(0.001) - if self._clock is not None: - self._clock.advance(0.0) - loop.run_until_complete(_drain()) + + # Drive the event loop briefly + try: + pump_loop(loop, timeout=0.05, tick=0.001) + except TimeoutError: + pass # outer loop will retry def extract_cookies(self, cookies: MutableMapping[str, str]) -> None: """Process the contents of any Set-Cookie headers in the response @@ -414,7 +407,7 @@ class FakeSite: return self._resource -def make_request( +async def make_request( reactor: MemoryReactorClock, site: Site | FakeSite, method: bytes | str, @@ -544,8 +537,11 @@ def make_request( if root_resource is not None: req.dispatch(root_resource) - if await_result: - channel.await_result() + if await_result and req.render_deferred is not None: + # Await the handler task directly — the event loop drives + # all concurrent tasks (DB operations, background processes) + # naturally without needing nest_asyncio. + await req.render_deferred return channel diff --git a/tests/unittest.py b/tests/unittest.py index 9862540600..0816d2f83e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -190,130 +190,38 @@ def make_homeserver_config_obj(config: dict[str, Any]) -> HomeServerConfig: return deepcopy_config(config_obj) -class TestCase(_stdlib_unittest.TestCase): - """A subclass of stdlib's TestCase which looks for 'loglevel' - attributes on both itself and its individual test methods, to override the - root logger's logging level while that test (case|method) runs.""" +class TestCase(_stdlib_unittest.IsolatedAsyncioTestCase): + """A subclass of IsolatedAsyncioTestCase that provides Synapse test utilities. - def __init__(self, methodName: str = "runTest"): - super().__init__(methodName) + Using IsolatedAsyncioTestCase means: + - Each test gets its own asyncio event loop, managed by the framework. + - Test methods can be ``async def`` and use ``await`` directly. + - setUp/tearDown can be async (asyncSetUp/asyncTearDown). + - No nest_asyncio needed — the event loop drives all tasks naturally. + """ - method = getattr(self, methodName, None) + def setUp(self) -> None: + # if we're not starting in the sentinel logcontext, then to be honest + # all future bets are off. + if current_context(): + self.fail( + "Test starting with non-sentinel logging context %s" + % (current_context(),) + ) - level = getattr(method, "loglevel", getattr(self, "loglevel", None)) + # Disable GC for duration of test (re-enabled in tearDown). + gc.disable() - @around(self) - def setUp(orig: Callable[[], R]) -> R: - # Set up a fresh asyncio event loop for each test - import asyncio as _asyncio - import nest_asyncio as _nest_asyncio - self._asyncio_loop = _asyncio.new_event_loop() - _asyncio.set_event_loop(self._asyncio_loop) - _nest_asyncio.apply(self._asyncio_loop) + self.addCleanup(setup_awaitable_errors()) - # if we're not starting in the sentinel logcontext, then to be honest - # all future bets are off. - if current_context(): - self.fail( - "Test starting with non-sentinel logging context %s" - % (current_context(),) - ) - - # Disable GC for duration of test. See below for why. - gc.disable() - - old_level = logging.getLogger().level - if level is not None and old_level != level: - - @around(self) - def tearDown(orig: Callable[[], R]) -> R: - ret = orig() - logging.getLogger().setLevel(old_level) - return ret - - logging.getLogger().setLevel(level) - - # Trial messes with the warnings configuration, thus this has to be - # done in the context of an individual TestCase. - self.addCleanup(setup_awaitable_errors()) - - return orig() - - # We want to force a GC to workaround problems with deferreds leaking - # logcontexts when they are GCed (see the logcontext docs). - # - # The easiest way to do this would be to do a full GC after each test - # run, but that is very expensive. Instead, we disable GC (above) for - # the duration of the test and only run a gen-0 GC, which is a lot - # quicker. This doesn't clean up everything, since the TestCase - # instance still holds references to objects created during the test, - # such as HomeServers, so we do a full GC every so often. - - @around(self) - def tearDown(orig: Callable[[], R]) -> R: - ret = orig() - - # Cancel any remaining asyncio tasks from this test - import asyncio as _asyncio - loop = _asyncio.get_event_loop() - if not loop.is_closed(): - pending = [t for t in _asyncio.all_tasks(loop) if not t.done()] - for t in pending: - t.cancel() - if pending: - loop.run_until_complete(_asyncio.sleep(0)) - - gc.collect(0) - # Run a full GC every 50 gen-0 GCs. - gen0_stats = gc.get_stats()[0] - gen0_collections = gen0_stats["collections"] - if gen0_collections % 50 == 0: - gc.collect() - gc.enable() - set_current_context(SENTINEL_CONTEXT) - - return ret - - def _callTestMethod(self, method: Callable[[], Any]) -> None: - """Override to handle async test methods. - - Twisted's trial auto-detected async test methods and wrapped them - with ensureDeferred. We replicate that behavior here. - """ - import inspect - - result = method() - if inspect.isawaitable(result): - from twisted.internet import defer, reactor - - d = defer.ensureDeferred(result) - - if not d.called: - finished: list[Any] = [] - d.addBoth(finished.append) - - if hasattr(self, "reactor"): - for _ in range(1000): - if finished: - break - self.reactor.advance(0.1) - else: - # Drive the global Twisted reactor - for _ in range(10000): - if finished: - break - reactor.runUntilCurrent() # type: ignore[attr-defined] - try: - reactor.doIteration(0.001) # type: ignore[attr-defined] - except NotImplementedError: - import time - time.sleep(0.001) - - if not finished: - self.fail("Async test method did not complete") - - if isinstance(finished[0], Failure): - finished[0].raiseException() + def tearDown(self) -> None: + gc.collect(0) + gen0_stats = gc.get_stats()[0] + gen0_collections = gen0_stats["collections"] + if gen0_collections % 50 == 0: + gc.collect() + gc.enable() + set_current_context(SENTINEL_CONTEXT) def mktemp(self) -> str: """Return a unique temporary path for test use. @@ -583,21 +491,15 @@ class HomeserverTestCase(TestCase): method = getattr(self, methodName, None) self._extra_config = getattr(method, "_extra_config", None) if method else None - def setUp(self) -> None: + async def asyncSetUp(self) -> None: """ Set up the TestCase by calling the homeserver constructor, optionally hijacking the authentication system to return a fixed user, and then calling the prepare function. """ - # Set up an asyncio event loop so that asyncio primitives (Future, Event, - # create_task, etc.) work even when driven by Twisted's MemoryReactorClock. - import asyncio - self._asyncio_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._asyncio_loop) - # We need to share the reactor between the homeserver and all of our test utils. self.reactor, self.clock = get_clock() - self.hs = self.make_homeserver(self.reactor, self.clock) + self.hs = await self.make_homeserver(self.reactor, self.clock) self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False @@ -629,7 +531,7 @@ class HomeserverTestCase(TestCase): self.helper = RestHelper( self.hs, - checked_cast(MemoryReactorClock, self.hs.get_reactor()), + self.hs.get_reactor(), self.site, getattr(self, "user_id", None), ) @@ -640,13 +542,11 @@ class HomeserverTestCase(TestCase): token = "some_fake_token" # We need a valid token ID to satisfy foreign key constraints. - token_id = self.get_success( - self.hs.get_datastores().main.add_access_token_to_user( - self.helper.auth_user_id, - token, - None, - None, - ) + token_id = await self.hs.get_datastores().main.add_access_token_to_user( + self.helper.auth_user_id, + token, + None, + None, ) # This has to be a function and not just a Mock, because @@ -699,7 +599,7 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver( + async def make_homeserver( self, reactor: ThreadedMemoryReactorClock, clock: Clock ) -> HomeServer: """ @@ -714,7 +614,7 @@ class HomeserverTestCase(TestCase): Function to be overridden in subclasses. """ - hs = self.setup_test_homeserver(reactor=reactor, clock=clock) + hs = await self.setup_test_homeserver(reactor=reactor, clock=clock) return hs def create_test_resource(self) -> Resource: @@ -773,7 +673,7 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - def make_request( + async def make_request( self, method: bytes | str, path: bytes | str, @@ -819,7 +719,7 @@ class HomeserverTestCase(TestCase): Returns: The FakeChannel object which stores the result of the request. """ - return make_request( + return await make_request( self.reactor, self.site, method, @@ -837,7 +737,7 @@ class HomeserverTestCase(TestCase): clock=self.clock, ) - def setup_test_homeserver( + async def setup_test_homeserver( self, server_name: str | None = None, config: JsonDict | None = None, @@ -879,10 +779,6 @@ class HomeserverTestCase(TestCase): # construct a homeserver with a matching name. server_name = config_obj.server.server_name - async def run_bg_updates() -> None: - with LoggingContext(name="run_bg_updates", server_name=server_name): - self.get_success(stor.db_pool.updates.run_background_updates(False)) - hs = setup_test_homeserver( cleanup_func=self.addCleanup, server_name=server_name, @@ -895,73 +791,49 @@ class HomeserverTestCase(TestCase): # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": - self.get_success(run_bg_updates()) + with LoggingContext(name="run_bg_updates", server_name=server_name): + await stor.db_pool.updates.run_background_updates(False) return hs - def pump(self, by: float = 0.0) -> None: - """Advance fake time and drive the asyncio event loop. + async def pump(self, by: float = 0.0) -> None: + """Advance fake time and yield to the event loop. ``reactor.advance()`` delegates to ``clock.advance()``, so calling either one advances the same fake-time source. """ - import asyncio - - loop = asyncio.get_event_loop() - - # Advance fake time (fires pending sleeps) AND drain callFromThread + # Advance fake time (fires pending sleeps and callFromThread) self.reactor.advance(by) + # Yield to the event loop so callbacks can run + await asyncio.sleep(0) - # Process asyncio callbacks (executor results, task completions, etc.) - if not loop.is_closed(): - loop.run_until_complete(asyncio.sleep(0)) - - def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV: - import asyncio - - loop = asyncio.get_event_loop() - - # Pump the fake reactor first if time advancement is needed + async def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV: + """Await an awaitable, optionally advancing fake time first.""" if by > 0: - self.reactor.pump([by] * 100) + self.reactor.advance(by) + await asyncio.sleep(0) + return await d # type: ignore[misc] - # Run the awaitable to completion on the asyncio loop. - # nest_asyncio allows this even if the loop is already running. - return loop.run_until_complete(d) # type: ignore[arg-type] - - def get_failure( + async def get_failure( self, d: Awaitable[Any], exc: type[_ExcType], by: float = 0.0 ) -> Any: - """ - Run an awaitable and get a Failure from it. - """ - import asyncio + """Await an awaitable and expect it to raise.""" + if by > 0: + self.reactor.advance(by) + await asyncio.sleep(0) + try: + await d + self.fail(f"Expected {exc}, but awaitable succeeded") + except BaseException as e: + if isinstance(e, exc): + return e + raise - future = asyncio.ensure_future(d) # type: ignore[arg-type] + async def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: + """Await an awaitable and return result or raise exception.""" + return await self.get_success(d, by=by) - error_holder: list[BaseException] = [] - - def _on_done(f: asyncio.Future) -> None: # type: ignore[type-arg] - try: - f.result() - except BaseException as e: - error_holder.append(e) - - future.add_done_callback(_on_done) - self.pump(by) - - if error_holder and isinstance(error_holder[0], exc): - return Failure(error_holder[0]) - elif error_holder: - self.fail(f"Expected {exc}, got {type(error_holder[0])}: {error_holder[0]}") - else: - self.fail("Expected failure, but awaitable succeeded") - - def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: - """Drive awaitable to completion and return result or raise exception.""" - return self.get_success(d, by=by) - - def register_user( + async def register_user( self, username: str, password: str, @@ -970,20 +842,11 @@ class HomeserverTestCase(TestCase): ) -> str: """ Register a user. Requires the Admin API be registered. - - Args: - username: The user part of the new user. - password: The password of the new user. - admin: Whether the user should be created as an admin or not. - displayname: The displayname of the new user. - - Returns: - The MXID of the new user. """ self.hs.config.registration.registration_shared_secret = "shared" # Create the user - channel = self.make_request("GET", "/_synapse/admin/v1/register") + channel = await self.make_request("GET", "/_synapse/admin/v1/register") self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] @@ -1006,13 +869,13 @@ class HomeserverTestCase(TestCase): "mac": want_mac_digest, "inhibit_login": True, } - channel = self.make_request("POST", "/_synapse/admin/v1/register", body) + channel = await self.make_request("POST", "/_synapse/admin/v1/register", body) self.assertEqual(channel.code, 200, channel.json_body) user_id = channel.json_body["user_id"] return user_id - def register_appservice_user( + async def register_appservice_user( self, username: str, appservice_token: str, @@ -1031,7 +894,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user, the device ID of the new user's first device. """ - channel = self.make_request( + channel = await self.make_request( "POST", "/_matrix/client/r0/register", { @@ -1044,7 +907,7 @@ class HomeserverTestCase(TestCase): self.assertEqual(channel.code, 200, channel.json_body) return channel.json_body["user_id"], channel.json_body.get("device_id") - def login( + async def login( self, username: str, password: str, @@ -1073,7 +936,7 @@ class HomeserverTestCase(TestCase): if additional_request_fields: body.update(additional_request_fields) - channel = self.make_request( + channel = await self.make_request( "POST", "/_matrix/client/r0/login", body,