Switch to IsolatedAsyncioTestCase to fix deadlocks when doing db operations

This commit is contained in:
Matthew Hodgson
2026-03-23 09:19:45 +00:00
parent 0d4574ef8d
commit ac2fb5cacd
8 changed files with 193 additions and 327 deletions

View File

@@ -101,8 +101,20 @@ except Exception:
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
_IN_LOGCONTEXT_ERROR = False
def logcontext_error(msg: str) -> None:
    """Report a logcontext-abuse warning, guarding against re-entrancy.

    Logging can itself trigger context switches, which call start()/stop(),
    which may call logcontext_error() again; the module-level flag breaks
    that recursion. The warning must only be emitted under the guard —
    an unguarded call here would recurse and log duplicates.

    Args:
        msg: the warning message to log.
    """
    global _IN_LOGCONTEXT_ERROR
    # Already reporting an error further up the stack: bail out.
    if _IN_LOGCONTEXT_ERROR:
        return
    _IN_LOGCONTEXT_ERROR = True
    try:
        logger.warning(msg)
    finally:
        # Always clear the flag, even if the logging call raised.
        _IN_LOGCONTEXT_ERROR = False
# get an id for the current thread.

View File

@@ -305,6 +305,12 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
self._event_persist_queues[room_id] = remaining_queue
self._currently_persisting_rooms.discard(room_id)
# If new items were added while we were running, kick off
# another processing round. Without this, items enqueued
# while we hold _currently_persisting_rooms would be orphaned.
if room_id in self._event_persist_queues:
self._handle_queue(room_id)
# set handle_queue_loop off in the background
self.hs.run_as_background_process("persist_events", handle_queue_loop)

View File

@@ -84,13 +84,23 @@ class NativeConnectionPool:
self._thread_local = threading.local()
# For in-memory SQLite or when an initial connection is provided,
# use a shared connection (not thread-local)
# use a shared connection (not thread-local).
# When using a shared connection, limit to 1 worker to avoid
# concurrent access deadlocks on the same SQLite connection.
self._shared_conn: Connection | None = initial_connection
if db_path == ":memory:" or db_path == "":
self._use_shared_conn = True
else:
self._use_shared_conn = initial_connection is not None
if self._use_shared_conn and max_workers > 1:
# Recreate executor with single worker
self._executor.shutdown(wait=False)
self._executor = ThreadPoolExecutor(
max_workers=1,
thread_name_prefix=f"synapse-db-{db_config.name}",
)
self._closed = False
def _get_connection(self) -> Connection:

View File

@@ -18,13 +18,3 @@
# [This file includes modifications made by New Vector Limited]
#
#
# Set up an asyncio event loop for tests.
import asyncio as _asyncio
import nest_asyncio as _nest_asyncio
_test_asyncio_loop = _asyncio.new_event_loop()
_asyncio.set_event_loop(_test_asyncio_loop)
# Allow nested event loop calls (run_until_complete inside run_until_complete)
_nest_asyncio.apply(_test_asyncio_loop)

View File

@@ -21,18 +21,18 @@ import os.path
import subprocess
from incremental import Version
try:
from zope.interface import implementer
except ImportError:
pass
try:
import twisted
except ImportError:
pass
from OpenSSL import SSL
from OpenSSL.SSL import Connection
# The TLS test utilities (TestServerTLSConnectionFactory, wrap_server_factory_for_tls,
# get_test_https_policy) require Twisted's TLS infrastructure. They are only used
# by federation/HTTP-level protocol tests, not by REST API tests.
# Guard everything so that importing tests.http doesn't fail without Twisted.
_HAS_TWISTED_TLS = False
try:
import twisted
from incremental import Version
from zope.interface import implementer
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import (
IOpenSSLServerConnectionCreator,
@@ -43,22 +43,20 @@ try:
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
_HAS_TWISTED_TLS = True
except ImportError:
pass
def get_test_https_policy() -> BrowserLikePolicyForHTTPS:
"""Get a test IPolicyForHTTPS which trusts the test CA cert
Returns:
IPolicyForHTTPS
"""
ca_file = get_test_ca_cert_file()
with open(ca_file) as stream:
content = stream.read()
cert = Certificate.loadPEM(content)
trust_root = trustRootFromCertificates([cert])
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
if _HAS_TWISTED_TLS:

    def get_test_https_policy() -> "BrowserLikePolicyForHTTPS":
        """Build an IPolicyForHTTPS whose trust root is the test CA cert."""
        with open(get_test_ca_cert_file()) as ca_stream:
            pem_data = ca_stream.read()
        ca_cert = Certificate.loadPEM(pem_data)
        return BrowserLikePolicyForHTTPS(
            trustRoot=trustRootFromCertificates([ca_cert])
        )
def get_test_ca_cert_file() -> str:
@@ -154,52 +152,41 @@ def create_test_cert_file(sanlist: list[bytes]) -> str:
return cert_filename
@implementer(IOpenSSLServerConnectionCreator)
class TestServerTLSConnectionFactory:
"""An SSL connection creator which returns connections which present a certificate
signed by our test CA."""
if _HAS_TWISTED_TLS:
@implementer(IOpenSSLServerConnectionCreator)
class TestServerTLSConnectionFactory:
"""An SSL connection creator which returns connections which present a certificate
signed by our test CA."""
def __init__(self, sanlist: list[bytes]):
"""
Args:
sanlist: a list of subjectAltName values for the cert
"""
self._cert_file = create_test_cert_file(sanlist)
def __init__(self, sanlist: list[bytes]):
self._cert_file = create_test_cert_file(sanlist)
def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection:
ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
def serverConnectionForTLS(self, tlsProtocol: "TLSMemoryBIOProtocol") -> Connection:
ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
def wrap_server_factory_for_tls(
factory: "IProtocolFactory", clock: "IReactorTime", sanlist: list[bytes]
) -> "TLSMemoryBIOFactory":
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory."""
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
if twisted.version <= Version("Twisted", 23, 8, 0):
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
else:
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory, clock=clock
)
def wrap_server_factory_for_tls(
factory: IProtocolFactory, clock: IReactorTime, sanlist: list[bytes]
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
clock: reactor clock, forwarded to TLSMemoryBIOFactory on Twisted > 23.8.0
sanlist: list of domains the cert should be valid for
Returns:
TLSMemoryBIOFactory wrapping `factory`
"""
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
# Twisted > 23.8.0 has a different API that accepts a clock.
if twisted.version <= Version("Twisted", 23, 8, 0):
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
else:
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory, clock=clock
)
# A dummy address, useful for tests that use FakeTransport and don't care about where
# packets are going to/coming from.
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
# A dummy address, useful for tests that use FakeTransport
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
else:
# Dummy address that doesn't require Twisted
class _DummyAddress:
type = "TCP"
host = "127.0.0.1"
port = 80
dummy_address = _DummyAddress() # type: ignore[assignment]

View File

@@ -18,6 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import asyncio
import time
import urllib.parse
from typing import (
@@ -183,8 +184,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = await self.setup_test_homeserver()
self.hs.config.registration.enable_registration = True
self.hs.config.registration.registrations_require_3pid = []
self.hs.config.registration.auto_join_rooms = []
@@ -340,11 +341,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 403, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None:
self.register_user("kermit", "monkey")
async def test_soft_logout(self) -> None:
await self.register_user("kermit", "monkey")
# we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL)
channel = await self.make_request(b"GET", TEST_URL)
self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
@@ -354,21 +355,22 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
channel = self.make_request(b"POST", LOGIN_URL, params)
channel = await self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# time passes — advance slightly past the session lifetime
# so the token is expired (valid_until_ms < clock.time_msec())
self.reactor.advance(24 * 3600 + 1)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -378,28 +380,28 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
#
# we now log in as a different device
access_token_2 = self.login("kermit", "monkey")
access_token_2 = await self.login("kermit", "monkey")
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
# ... but if we delete that device, it will be a proper logout
self._delete_device(access_token_2, "kermit", "monkey", device_id)
await self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False)
def _delete_device(
async def _delete_device(
self, access_token: str, user_id: str, password: str, device_id: str
) -> None:
"""Perform the UI-Auth to delete a device"""
channel = self.make_request(
channel = await self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
self.assertEqual(channel.code, 401, channel.result)
@@ -419,7 +421,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"session": channel.json_body["session"],
}
channel = self.make_request(
channel = await self.make_request(
b"DELETE",
"devices/" + device_id,
access_token=access_token,

View File

@@ -338,36 +338,29 @@ class FakeChannel:
def transport(self) -> "FakeChannel":
return self
def await_result(self, timeout_ms: int = 1000) -> None:
def await_result(self, timeout_ms: int = 5000) -> None:
"""
Wait until the request is finished.
Wait until the request is finished by driving the asyncio event loop.
"""
import asyncio
import time as _time
deadline = _time.monotonic() + timeout_ms / 1000.0
self._reactor.run()
from tests import pump_loop
loop = asyncio.get_event_loop()
deadline = _time.monotonic() + timeout_ms / 1000.0
while not self.is_finished():
if _time.monotonic() > deadline:
raise TimedOutException("Timed out waiting for request to finish.")
# Advance fake time by a small amount per iteration.
# This fires pending sleeps (e.g., ratelimit pauses) while
# keeping time advancement predictable. The old Twisted
# MemoryReactorClock.advance(0.1) did the same.
# Advance fake time OUTSIDE the event loop drive
self._reactor.advance(0.1)
# Drive asyncio event loop for DB operations, task completions, etc.
if not loop.is_closed():
async def _drain() -> None:
"""Run multiple event loop ticks to drain pending work."""
for _ in range(20):
await asyncio.sleep(0.001)
if self._clock is not None:
self._clock.advance(0.0)
loop.run_until_complete(_drain())
# Drive the event loop briefly
try:
pump_loop(loop, timeout=0.05, tick=0.001)
except TimeoutError:
pass # outer loop will retry
def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
"""Process the contents of any Set-Cookie headers in the response
@@ -414,7 +407,7 @@ class FakeSite:
return self._resource
def make_request(
async def make_request(
reactor: MemoryReactorClock,
site: Site | FakeSite,
method: bytes | str,
@@ -544,8 +537,11 @@ def make_request(
if root_resource is not None:
req.dispatch(root_resource)
if await_result:
channel.await_result()
if await_result and req.render_deferred is not None:
# Await the handler task directly — the event loop drives
# all concurrent tasks (DB operations, background processes)
# naturally without needing nest_asyncio.
await req.render_deferred
return channel

View File

@@ -190,130 +190,38 @@ def make_homeserver_config_obj(config: dict[str, Any]) -> HomeServerConfig:
return deepcopy_config(config_obj)
class TestCase(_stdlib_unittest.TestCase):
"""A subclass of stdlib's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""
class TestCase(_stdlib_unittest.IsolatedAsyncioTestCase):
"""A subclass of IsolatedAsyncioTestCase that provides Synapse test utilities.
def __init__(self, methodName: str = "runTest"):
super().__init__(methodName)
Using IsolatedAsyncioTestCase means:
- Each test gets its own asyncio event loop, managed by the framework.
- Test methods can be ``async def`` and use ``await`` directly.
- setUp/tearDown can be async (asyncSetUp/asyncTearDown).
- No nest_asyncio needed — the event loop drives all tasks naturally.
"""
method = getattr(self, methodName, None)
def setUp(self) -> None:
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if current_context():
self.fail(
"Test starting with non-sentinel logging context %s"
% (current_context(),)
)
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
# Disable GC for duration of test (re-enabled in tearDown).
gc.disable()
@around(self)
def setUp(orig: Callable[[], R]) -> R:
# Set up a fresh asyncio event loop for each test
import asyncio as _asyncio
import nest_asyncio as _nest_asyncio
self._asyncio_loop = _asyncio.new_event_loop()
_asyncio.set_event_loop(self._asyncio_loop)
_nest_asyncio.apply(self._asyncio_loop)
self.addCleanup(setup_awaitable_errors())
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if current_context():
self.fail(
"Test starting with non-sentinel logging context %s"
% (current_context(),)
)
# Disable GC for duration of test. See below for why.
gc.disable()
old_level = logging.getLogger().level
if level is not None and old_level != level:
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
logging.getLogger().setLevel(old_level)
return ret
logging.getLogger().setLevel(level)
# Trial messes with the warnings configuration, thus this has to be
# done in the context of an individual TestCase.
self.addCleanup(setup_awaitable_errors())
return orig()
# We want to force a GC to workaround problems with deferreds leaking
# logcontexts when they are GCed (see the logcontext docs).
#
# The easiest way to do this would be to do a full GC after each test
# run, but that is very expensive. Instead, we disable GC (above) for
# the duration of the test and only run a gen-0 GC, which is a lot
# quicker. This doesn't clean up everything, since the TestCase
# instance still holds references to objects created during the test,
# such as HomeServers, so we do a full GC every so often.
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
# Cancel any remaining asyncio tasks from this test
import asyncio as _asyncio
loop = _asyncio.get_event_loop()
if not loop.is_closed():
pending = [t for t in _asyncio.all_tasks(loop) if not t.done()]
for t in pending:
t.cancel()
if pending:
loop.run_until_complete(_asyncio.sleep(0))
gc.collect(0)
# Run a full GC every 50 gen-0 GCs.
gen0_stats = gc.get_stats()[0]
gen0_collections = gen0_stats["collections"]
if gen0_collections % 50 == 0:
gc.collect()
gc.enable()
set_current_context(SENTINEL_CONTEXT)
return ret
def _callTestMethod(self, method: Callable[[], Any]) -> None:
"""Override to handle async test methods.
Twisted's trial auto-detected async test methods and wrapped them
with ensureDeferred. We replicate that behavior here.
"""
import inspect
result = method()
if inspect.isawaitable(result):
from twisted.internet import defer, reactor
d = defer.ensureDeferred(result)
if not d.called:
finished: list[Any] = []
d.addBoth(finished.append)
if hasattr(self, "reactor"):
for _ in range(1000):
if finished:
break
self.reactor.advance(0.1)
else:
# Drive the global Twisted reactor
for _ in range(10000):
if finished:
break
reactor.runUntilCurrent() # type: ignore[attr-defined]
try:
reactor.doIteration(0.001) # type: ignore[attr-defined]
except NotImplementedError:
import time
time.sleep(0.001)
if not finished:
self.fail("Async test method did not complete")
if isinstance(finished[0], Failure):
finished[0].raiseException()
def tearDown(self) -> None:
gc.collect(0)
gen0_stats = gc.get_stats()[0]
gen0_collections = gen0_stats["collections"]
if gen0_collections % 50 == 0:
gc.collect()
gc.enable()
set_current_context(SENTINEL_CONTEXT)
def mktemp(self) -> str:
"""Return a unique temporary path for test use.
@@ -583,21 +491,15 @@ class HomeserverTestCase(TestCase):
method = getattr(self, methodName, None)
self._extra_config = getattr(method, "_extra_config", None) if method else None
def setUp(self) -> None:
async def asyncSetUp(self) -> None:
"""
Set up the TestCase by calling the homeserver constructor, optionally
hijacking the authentication system to return a fixed user, and then
calling the prepare function.
"""
# Set up an asyncio event loop so that asyncio primitives (Future, Event,
# create_task, etc.) work even when driven by Twisted's MemoryReactorClock.
import asyncio
self._asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._asyncio_loop)
# We need to share the reactor between the homeserver and all of our test utils.
self.reactor, self.clock = get_clock()
self.hs = self.make_homeserver(self.reactor, self.clock)
self.hs = await self.make_homeserver(self.reactor, self.clock)
self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
@@ -629,7 +531,7 @@ class HomeserverTestCase(TestCase):
self.helper = RestHelper(
self.hs,
checked_cast(MemoryReactorClock, self.hs.get_reactor()),
self.hs.get_reactor(),
self.site,
getattr(self, "user_id", None),
)
@@ -640,13 +542,11 @@ class HomeserverTestCase(TestCase):
token = "some_fake_token"
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastores().main.add_access_token_to_user(
self.helper.auth_user_id,
token,
None,
None,
)
token_id = await self.hs.get_datastores().main.add_access_token_to_user(
self.helper.auth_user_id,
token,
None,
None,
)
# This has to be a function and not just a Mock, because
@@ -699,7 +599,7 @@ class HomeserverTestCase(TestCase):
store.db_pool.updates.do_next_background_update(False), by=0.1
)
def make_homeserver(
async def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
"""
@@ -714,7 +614,7 @@ class HomeserverTestCase(TestCase):
Function to be overridden in subclasses.
"""
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
hs = await self.setup_test_homeserver(reactor=reactor, clock=clock)
return hs
def create_test_resource(self) -> Resource:
@@ -773,7 +673,7 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
def make_request(
async def make_request(
self,
method: bytes | str,
path: bytes | str,
@@ -819,7 +719,7 @@ class HomeserverTestCase(TestCase):
Returns:
The FakeChannel object which stores the result of the request.
"""
return make_request(
return await make_request(
self.reactor,
self.site,
method,
@@ -837,7 +737,7 @@ class HomeserverTestCase(TestCase):
clock=self.clock,
)
def setup_test_homeserver(
async def setup_test_homeserver(
self,
server_name: str | None = None,
config: JsonDict | None = None,
@@ -879,10 +779,6 @@ class HomeserverTestCase(TestCase):
# construct a homeserver with a matching name.
server_name = config_obj.server.server_name
async def run_bg_updates() -> None:
with LoggingContext(name="run_bg_updates", server_name=server_name):
self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(
cleanup_func=self.addCleanup,
server_name=server_name,
@@ -895,73 +791,49 @@ class HomeserverTestCase(TestCase):
# Run the database background updates, when running against "master".
if hs.__class__.__name__ == "TestHomeServer":
self.get_success(run_bg_updates())
with LoggingContext(name="run_bg_updates", server_name=server_name):
await stor.db_pool.updates.run_background_updates(False)
return hs
async def pump(self, by: float = 0.0) -> None:
    """Advance fake time and yield to the event loop.

    ``reactor.advance()`` delegates to ``clock.advance()``, so calling
    either one advances the same fake-time source.

    Args:
        by: seconds of fake time to advance; pending timed callbacks fire.
    """
    # Advance fake time (fires pending sleeps and callFromThread).
    self.reactor.advance(by)
    # Yield to the event loop so callbacks scheduled above can run.
    await asyncio.sleep(0)
def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
import asyncio
loop = asyncio.get_event_loop()
# Pump the fake reactor first if time advancement is needed
async def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Await an awaitable, optionally advancing fake time first."""
if by > 0:
self.reactor.pump([by] * 100)
self.reactor.advance(by)
await asyncio.sleep(0)
return await d # type: ignore[misc]
# Run the awaitable to completion on the asyncio loop.
# nest_asyncio allows this even if the loop is already running.
return loop.run_until_complete(d) # type: ignore[arg-type]
def get_failure(
async def get_failure(
self, d: Awaitable[Any], exc: type[_ExcType], by: float = 0.0
) -> Any:
"""
Run an awaitable and get a Failure from it.
"""
import asyncio
"""Await an awaitable and expect it to raise."""
if by > 0:
self.reactor.advance(by)
await asyncio.sleep(0)
try:
await d
self.fail(f"Expected {exc}, but awaitable succeeded")
except BaseException as e:
if isinstance(e, exc):
return e
raise
future = asyncio.ensure_future(d) # type: ignore[arg-type]
async def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Await an awaitable and return result or raise exception."""
return await self.get_success(d, by=by)
error_holder: list[BaseException] = []
def _on_done(f: asyncio.Future) -> None: # type: ignore[type-arg]
try:
f.result()
except BaseException as e:
error_holder.append(e)
future.add_done_callback(_on_done)
self.pump(by)
if error_holder and isinstance(error_holder[0], exc):
return Failure(error_holder[0])
elif error_holder:
self.fail(f"Expected {exc}, got {type(error_holder[0])}: {error_holder[0]}")
else:
self.fail("Expected failure, but awaitable succeeded")
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Drive awaitable to completion and return result or raise exception."""
return self.get_success(d, by=by)
def register_user(
async def register_user(
self,
username: str,
password: str,
@@ -970,20 +842,11 @@ class HomeserverTestCase(TestCase):
) -> str:
"""
Register a user. Requires the Admin API be registered.
Args:
username: The user part of the new user.
password: The password of the new user.
admin: Whether the user should be created as an admin or not.
displayname: The displayname of the new user.
Returns:
The MXID of the new user.
"""
self.hs.config.registration.registration_shared_secret = "shared"
# Create the user
channel = self.make_request("GET", "/_synapse/admin/v1/register")
channel = await self.make_request("GET", "/_synapse/admin/v1/register")
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
@@ -1006,13 +869,13 @@ class HomeserverTestCase(TestCase):
"mac": want_mac_digest,
"inhibit_login": True,
}
channel = self.make_request("POST", "/_synapse/admin/v1/register", body)
channel = await self.make_request("POST", "/_synapse/admin/v1/register", body)
self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
return user_id
def register_appservice_user(
async def register_appservice_user(
self,
username: str,
appservice_token: str,
@@ -1031,7 +894,7 @@ class HomeserverTestCase(TestCase):
Returns:
The MXID of the new user, the device ID of the new user's first device.
"""
channel = self.make_request(
channel = await self.make_request(
"POST",
"/_matrix/client/r0/register",
{
@@ -1044,7 +907,7 @@ class HomeserverTestCase(TestCase):
self.assertEqual(channel.code, 200, channel.json_body)
return channel.json_body["user_id"], channel.json_body.get("device_id")
def login(
async def login(
self,
username: str,
password: str,
@@ -1073,7 +936,7 @@ class HomeserverTestCase(TestCase):
if additional_request_fields:
body.update(additional_request_fields)
channel = self.make_request(
channel = await self.make_request(
"POST",
"/_matrix/client/r0/login",
body,