mirror of
https://github.com/element-hq/synapse.git
synced 2026-03-29 08:50:09 +00:00
Switch to IsolatedAsyncioTestCase to fix deadlocks when doing db
This commit is contained in:
@@ -101,8 +101,20 @@ except Exception:
|
||||
|
||||
|
||||
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
|
||||
_IN_LOGCONTEXT_ERROR = False
|
||||
|
||||
|
||||
def logcontext_error(msg: str) -> None:
|
||||
logger.warning(msg)
|
||||
# Guard against re-entrancy: logging can trigger context switches,
|
||||
# which call start()/stop(), which call logcontext_error() again.
|
||||
global _IN_LOGCONTEXT_ERROR
|
||||
if _IN_LOGCONTEXT_ERROR:
|
||||
return
|
||||
_IN_LOGCONTEXT_ERROR = True
|
||||
try:
|
||||
logger.warning(msg)
|
||||
finally:
|
||||
_IN_LOGCONTEXT_ERROR = False
|
||||
|
||||
|
||||
# get an id for the current thread.
|
||||
|
||||
@@ -305,6 +305,12 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
||||
self._event_persist_queues[room_id] = remaining_queue
|
||||
self._currently_persisting_rooms.discard(room_id)
|
||||
|
||||
# If new items were added while we were running, kick off
|
||||
# another processing round. Without this, items enqueued
|
||||
# while we hold _currently_persisting_rooms would be orphaned.
|
||||
if room_id in self._event_persist_queues:
|
||||
self._handle_queue(room_id)
|
||||
|
||||
# set handle_queue_loop off in the background
|
||||
self.hs.run_as_background_process("persist_events", handle_queue_loop)
|
||||
|
||||
|
||||
@@ -84,13 +84,23 @@ class NativeConnectionPool:
|
||||
self._thread_local = threading.local()
|
||||
|
||||
# For in-memory SQLite or when an initial connection is provided,
|
||||
# use a shared connection (not thread-local)
|
||||
# use a shared connection (not thread-local).
|
||||
# When using a shared connection, limit to 1 worker to avoid
|
||||
# concurrent access deadlocks on the same SQLite connection.
|
||||
self._shared_conn: Connection | None = initial_connection
|
||||
if db_path == ":memory:" or db_path == "":
|
||||
self._use_shared_conn = True
|
||||
else:
|
||||
self._use_shared_conn = initial_connection is not None
|
||||
|
||||
if self._use_shared_conn and max_workers > 1:
|
||||
# Recreate executor with single worker
|
||||
self._executor.shutdown(wait=False)
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=1,
|
||||
thread_name_prefix=f"synapse-db-{db_config.name}",
|
||||
)
|
||||
|
||||
self._closed = False
|
||||
|
||||
def _get_connection(self) -> Connection:
|
||||
|
||||
@@ -18,13 +18,3 @@
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
# Set up an asyncio event loop for tests.
|
||||
import asyncio as _asyncio
|
||||
import nest_asyncio as _nest_asyncio
|
||||
|
||||
_test_asyncio_loop = _asyncio.new_event_loop()
|
||||
_asyncio.set_event_loop(_test_asyncio_loop)
|
||||
|
||||
# Allow nested event loop calls (run_until_complete inside run_until_complete)
|
||||
_nest_asyncio.apply(_test_asyncio_loop)
|
||||
|
||||
@@ -21,18 +21,18 @@ import os.path
|
||||
import subprocess
|
||||
|
||||
from incremental import Version
|
||||
try:
|
||||
from zope.interface import implementer
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import twisted
|
||||
except ImportError:
|
||||
pass
|
||||
from OpenSSL import SSL
|
||||
from OpenSSL.SSL import Connection
|
||||
|
||||
# The TLS test utilities (TestServerTLSConnectionFactory, wrap_server_factory_for_tls,
|
||||
# get_test_https_policy) require Twisted's TLS infrastructure. They are only used
|
||||
# by federation/HTTP-level protocol tests, not by REST API tests.
|
||||
# Guard everything so that importing tests.http doesn't fail without Twisted.
|
||||
_HAS_TWISTED_TLS = False
|
||||
try:
|
||||
import twisted
|
||||
from incremental import Version
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.interfaces import (
|
||||
IOpenSSLServerConnectionCreator,
|
||||
@@ -43,22 +43,20 @@ try:
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
||||
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
||||
_HAS_TWISTED_TLS = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def get_test_https_policy() -> BrowserLikePolicyForHTTPS:
|
||||
"""Get a test IPolicyForHTTPS which trusts the test CA cert
|
||||
|
||||
Returns:
|
||||
IPolicyForHTTPS
|
||||
"""
|
||||
ca_file = get_test_ca_cert_file()
|
||||
with open(ca_file) as stream:
|
||||
content = stream.read()
|
||||
cert = Certificate.loadPEM(content)
|
||||
trust_root = trustRootFromCertificates([cert])
|
||||
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
||||
if _HAS_TWISTED_TLS:
|
||||
def get_test_https_policy() -> "BrowserLikePolicyForHTTPS":
|
||||
"""Get a test IPolicyForHTTPS which trusts the test CA cert."""
|
||||
ca_file = get_test_ca_cert_file()
|
||||
with open(ca_file) as stream:
|
||||
content = stream.read()
|
||||
cert = Certificate.loadPEM(content)
|
||||
trust_root = trustRootFromCertificates([cert])
|
||||
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
||||
|
||||
|
||||
def get_test_ca_cert_file() -> str:
|
||||
@@ -154,52 +152,41 @@ def create_test_cert_file(sanlist: list[bytes]) -> str:
|
||||
return cert_filename
|
||||
|
||||
|
||||
@implementer(IOpenSSLServerConnectionCreator)
|
||||
class TestServerTLSConnectionFactory:
|
||||
"""An SSL connection creator which returns connections which present a certificate
|
||||
signed by our test CA."""
|
||||
if _HAS_TWISTED_TLS:
|
||||
@implementer(IOpenSSLServerConnectionCreator)
|
||||
class TestServerTLSConnectionFactory:
|
||||
"""An SSL connection creator which returns connections which present a certificate
|
||||
signed by our test CA."""
|
||||
|
||||
def __init__(self, sanlist: list[bytes]):
|
||||
"""
|
||||
Args:
|
||||
sanlist: a list of subjectAltName values for the cert
|
||||
"""
|
||||
self._cert_file = create_test_cert_file(sanlist)
|
||||
def __init__(self, sanlist: list[bytes]):
|
||||
self._cert_file = create_test_cert_file(sanlist)
|
||||
|
||||
def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection:
|
||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||
ctx.use_certificate_file(self._cert_file)
|
||||
ctx.use_privatekey_file(get_test_key_file())
|
||||
return Connection(ctx, None)
|
||||
def serverConnectionForTLS(self, tlsProtocol: "TLSMemoryBIOProtocol") -> Connection:
|
||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||
ctx.use_certificate_file(self._cert_file)
|
||||
ctx.use_privatekey_file(get_test_key_file())
|
||||
return Connection(ctx, None)
|
||||
|
||||
def wrap_server_factory_for_tls(
|
||||
factory: "IProtocolFactory", clock: "IReactorTime", sanlist: list[bytes]
|
||||
) -> "TLSMemoryBIOFactory":
|
||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory."""
|
||||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
||||
if twisted.version <= Version("Twisted", 23, 8, 0):
|
||||
return TLSMemoryBIOFactory(
|
||||
connection_creator, isClient=False, wrappedFactory=factory
|
||||
)
|
||||
else:
|
||||
return TLSMemoryBIOFactory(
|
||||
connection_creator, isClient=False, wrappedFactory=factory, clock=clock
|
||||
)
|
||||
|
||||
def wrap_server_factory_for_tls(
|
||||
factory: IProtocolFactory, clock: IReactorTime, sanlist: list[bytes]
|
||||
) -> TLSMemoryBIOFactory:
|
||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||
|
||||
The resultant factory will create a TLS server which presents a certificate
|
||||
signed by our test CA, valid for the domains in `sanlist`
|
||||
|
||||
Args:
|
||||
factory: protocol factory to wrap
|
||||
sanlist: list of domains the cert should be valid for
|
||||
|
||||
Returns:
|
||||
interfaces.IProtocolFactory
|
||||
"""
|
||||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
||||
# Twisted > 23.8.0 has a different API that accepts a clock.
|
||||
if twisted.version <= Version("Twisted", 23, 8, 0):
|
||||
return TLSMemoryBIOFactory(
|
||||
connection_creator, isClient=False, wrappedFactory=factory
|
||||
)
|
||||
else:
|
||||
return TLSMemoryBIOFactory(
|
||||
connection_creator, isClient=False, wrappedFactory=factory, clock=clock
|
||||
)
|
||||
|
||||
|
||||
# A dummy address, useful for tests that use FakeTransport and don't care about where
|
||||
# packets are going to/coming from.
|
||||
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
|
||||
# A dummy address, useful for tests that use FakeTransport
|
||||
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
|
||||
else:
|
||||
# Dummy address that doesn't require Twisted
|
||||
class _DummyAddress:
|
||||
type = "TCP"
|
||||
host = "127.0.0.1"
|
||||
port = 80
|
||||
dummy_address = _DummyAddress() # type: ignore[assignment]
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
import asyncio
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import (
|
||||
@@ -183,8 +184,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
register.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = self.setup_test_homeserver()
|
||||
async def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = await self.setup_test_homeserver()
|
||||
self.hs.config.registration.enable_registration = True
|
||||
self.hs.config.registration.registrations_require_3pid = []
|
||||
self.hs.config.registration.auto_join_rooms = []
|
||||
@@ -340,11 +341,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 403, msg=channel.result)
|
||||
|
||||
@override_config({"session_lifetime": "24h"})
|
||||
def test_soft_logout(self) -> None:
|
||||
self.register_user("kermit", "monkey")
|
||||
async def test_soft_logout(self) -> None:
|
||||
await self.register_user("kermit", "monkey")
|
||||
|
||||
# we shouldn't be able to make requests without an access token
|
||||
channel = self.make_request(b"GET", TEST_URL)
|
||||
channel = await self.make_request(b"GET", TEST_URL)
|
||||
self.assertEqual(channel.code, 401, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
|
||||
|
||||
@@ -354,21 +355,22 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "monkey",
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
channel = await self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
access_token = channel.json_body["access_token"]
|
||||
device_id = channel.json_body["device_id"]
|
||||
|
||||
# we should now be able to make requests with the access token
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# time passes
|
||||
self.reactor.advance(24 * 3600)
|
||||
# time passes — advance slightly past the session lifetime
|
||||
# so the token is expired (valid_until_ms < clock.time_msec())
|
||||
self.reactor.advance(24 * 3600 + 1)
|
||||
|
||||
# ... and we should be soft-logouted
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
@@ -378,28 +380,28 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
#
|
||||
|
||||
# we now log in as a different device
|
||||
access_token_2 = self.login("kermit", "monkey")
|
||||
access_token_2 = await self.login("kermit", "monkey")
|
||||
|
||||
# more requests with the expired token should still return a soft-logout
|
||||
self.reactor.advance(3600)
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
|
||||
# ... but if we delete that device, it will be a proper logout
|
||||
self._delete_device(access_token_2, "kermit", "monkey", device_id)
|
||||
await self._delete_device(access_token_2, "kermit", "monkey", device_id)
|
||||
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
channel = await self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], False)
|
||||
|
||||
def _delete_device(
|
||||
async def _delete_device(
|
||||
self, access_token: str, user_id: str, password: str, device_id: str
|
||||
) -> None:
|
||||
"""Perform the UI-Auth to delete a device"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
b"DELETE", "devices/" + device_id, access_token=access_token
|
||||
)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
@@ -419,7 +421,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"session": channel.json_body["session"],
|
||||
}
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
b"DELETE",
|
||||
"devices/" + device_id,
|
||||
access_token=access_token,
|
||||
|
||||
@@ -338,36 +338,29 @@ class FakeChannel:
|
||||
def transport(self) -> "FakeChannel":
|
||||
return self
|
||||
|
||||
def await_result(self, timeout_ms: int = 1000) -> None:
|
||||
def await_result(self, timeout_ms: int = 5000) -> None:
|
||||
"""
|
||||
Wait until the request is finished.
|
||||
Wait until the request is finished by driving the asyncio event loop.
|
||||
"""
|
||||
import asyncio
|
||||
import time as _time
|
||||
|
||||
deadline = _time.monotonic() + timeout_ms / 1000.0
|
||||
self._reactor.run()
|
||||
from tests import pump_loop
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
deadline = _time.monotonic() + timeout_ms / 1000.0
|
||||
|
||||
while not self.is_finished():
|
||||
if _time.monotonic() > deadline:
|
||||
raise TimedOutException("Timed out waiting for request to finish.")
|
||||
|
||||
# Advance fake time by a small amount per iteration.
|
||||
# This fires pending sleeps (e.g., ratelimit pauses) while
|
||||
# keeping time advancement predictable. The old Twisted
|
||||
# MemoryReactorClock.advance(0.1) did the same.
|
||||
# Advance fake time OUTSIDE the event loop drive
|
||||
self._reactor.advance(0.1)
|
||||
# Drive asyncio event loop for DB operations, task completions, etc.
|
||||
if not loop.is_closed():
|
||||
async def _drain() -> None:
|
||||
"""Run multiple event loop ticks to drain pending work."""
|
||||
for _ in range(20):
|
||||
await asyncio.sleep(0.001)
|
||||
if self._clock is not None:
|
||||
self._clock.advance(0.0)
|
||||
loop.run_until_complete(_drain())
|
||||
|
||||
# Drive the event loop briefly
|
||||
try:
|
||||
pump_loop(loop, timeout=0.05, tick=0.001)
|
||||
except TimeoutError:
|
||||
pass # outer loop will retry
|
||||
|
||||
def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
|
||||
"""Process the contents of any Set-Cookie headers in the response
|
||||
@@ -414,7 +407,7 @@ class FakeSite:
|
||||
return self._resource
|
||||
|
||||
|
||||
def make_request(
|
||||
async def make_request(
|
||||
reactor: MemoryReactorClock,
|
||||
site: Site | FakeSite,
|
||||
method: bytes | str,
|
||||
@@ -544,8 +537,11 @@ def make_request(
|
||||
if root_resource is not None:
|
||||
req.dispatch(root_resource)
|
||||
|
||||
if await_result:
|
||||
channel.await_result()
|
||||
if await_result and req.render_deferred is not None:
|
||||
# Await the handler task directly — the event loop drives
|
||||
# all concurrent tasks (DB operations, background processes)
|
||||
# naturally without needing nest_asyncio.
|
||||
await req.render_deferred
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
@@ -190,130 +190,38 @@ def make_homeserver_config_obj(config: dict[str, Any]) -> HomeServerConfig:
|
||||
return deepcopy_config(config_obj)
|
||||
|
||||
|
||||
class TestCase(_stdlib_unittest.TestCase):
|
||||
"""A subclass of stdlib's TestCase which looks for 'loglevel'
|
||||
attributes on both itself and its individual test methods, to override the
|
||||
root logger's logging level while that test (case|method) runs."""
|
||||
class TestCase(_stdlib_unittest.IsolatedAsyncioTestCase):
|
||||
"""A subclass of IsolatedAsyncioTestCase that provides Synapse test utilities.
|
||||
|
||||
def __init__(self, methodName: str = "runTest"):
|
||||
super().__init__(methodName)
|
||||
Using IsolatedAsyncioTestCase means:
|
||||
- Each test gets its own asyncio event loop, managed by the framework.
|
||||
- Test methods can be ``async def`` and use ``await`` directly.
|
||||
- setUp/tearDown can be async (asyncSetUp/asyncTearDown).
|
||||
- No nest_asyncio needed — the event loop drives all tasks naturally.
|
||||
"""
|
||||
|
||||
method = getattr(self, methodName, None)
|
||||
def setUp(self) -> None:
|
||||
# if we're not starting in the sentinel logcontext, then to be honest
|
||||
# all future bets are off.
|
||||
if current_context():
|
||||
self.fail(
|
||||
"Test starting with non-sentinel logging context %s"
|
||||
% (current_context(),)
|
||||
)
|
||||
|
||||
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
|
||||
# Disable GC for duration of test (re-enabled in tearDown).
|
||||
gc.disable()
|
||||
|
||||
@around(self)
|
||||
def setUp(orig: Callable[[], R]) -> R:
|
||||
# Set up a fresh asyncio event loop for each test
|
||||
import asyncio as _asyncio
|
||||
import nest_asyncio as _nest_asyncio
|
||||
self._asyncio_loop = _asyncio.new_event_loop()
|
||||
_asyncio.set_event_loop(self._asyncio_loop)
|
||||
_nest_asyncio.apply(self._asyncio_loop)
|
||||
self.addCleanup(setup_awaitable_errors())
|
||||
|
||||
# if we're not starting in the sentinel logcontext, then to be honest
|
||||
# all future bets are off.
|
||||
if current_context():
|
||||
self.fail(
|
||||
"Test starting with non-sentinel logging context %s"
|
||||
% (current_context(),)
|
||||
)
|
||||
|
||||
# Disable GC for duration of test. See below for why.
|
||||
gc.disable()
|
||||
|
||||
old_level = logging.getLogger().level
|
||||
if level is not None and old_level != level:
|
||||
|
||||
@around(self)
|
||||
def tearDown(orig: Callable[[], R]) -> R:
|
||||
ret = orig()
|
||||
logging.getLogger().setLevel(old_level)
|
||||
return ret
|
||||
|
||||
logging.getLogger().setLevel(level)
|
||||
|
||||
# Trial messes with the warnings configuration, thus this has to be
|
||||
# done in the context of an individual TestCase.
|
||||
self.addCleanup(setup_awaitable_errors())
|
||||
|
||||
return orig()
|
||||
|
||||
# We want to force a GC to workaround problems with deferreds leaking
|
||||
# logcontexts when they are GCed (see the logcontext docs).
|
||||
#
|
||||
# The easiest way to do this would be to do a full GC after each test
|
||||
# run, but that is very expensive. Instead, we disable GC (above) for
|
||||
# the duration of the test and only run a gen-0 GC, which is a lot
|
||||
# quicker. This doesn't clean up everything, since the TestCase
|
||||
# instance still holds references to objects created during the test,
|
||||
# such as HomeServers, so we do a full GC every so often.
|
||||
|
||||
@around(self)
|
||||
def tearDown(orig: Callable[[], R]) -> R:
|
||||
ret = orig()
|
||||
|
||||
# Cancel any remaining asyncio tasks from this test
|
||||
import asyncio as _asyncio
|
||||
loop = _asyncio.get_event_loop()
|
||||
if not loop.is_closed():
|
||||
pending = [t for t in _asyncio.all_tasks(loop) if not t.done()]
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
if pending:
|
||||
loop.run_until_complete(_asyncio.sleep(0))
|
||||
|
||||
gc.collect(0)
|
||||
# Run a full GC every 50 gen-0 GCs.
|
||||
gen0_stats = gc.get_stats()[0]
|
||||
gen0_collections = gen0_stats["collections"]
|
||||
if gen0_collections % 50 == 0:
|
||||
gc.collect()
|
||||
gc.enable()
|
||||
set_current_context(SENTINEL_CONTEXT)
|
||||
|
||||
return ret
|
||||
|
||||
def _callTestMethod(self, method: Callable[[], Any]) -> None:
|
||||
"""Override to handle async test methods.
|
||||
|
||||
Twisted's trial auto-detected async test methods and wrapped them
|
||||
with ensureDeferred. We replicate that behavior here.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
result = method()
|
||||
if inspect.isawaitable(result):
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
d = defer.ensureDeferred(result)
|
||||
|
||||
if not d.called:
|
||||
finished: list[Any] = []
|
||||
d.addBoth(finished.append)
|
||||
|
||||
if hasattr(self, "reactor"):
|
||||
for _ in range(1000):
|
||||
if finished:
|
||||
break
|
||||
self.reactor.advance(0.1)
|
||||
else:
|
||||
# Drive the global Twisted reactor
|
||||
for _ in range(10000):
|
||||
if finished:
|
||||
break
|
||||
reactor.runUntilCurrent() # type: ignore[attr-defined]
|
||||
try:
|
||||
reactor.doIteration(0.001) # type: ignore[attr-defined]
|
||||
except NotImplementedError:
|
||||
import time
|
||||
time.sleep(0.001)
|
||||
|
||||
if not finished:
|
||||
self.fail("Async test method did not complete")
|
||||
|
||||
if isinstance(finished[0], Failure):
|
||||
finished[0].raiseException()
|
||||
def tearDown(self) -> None:
|
||||
gc.collect(0)
|
||||
gen0_stats = gc.get_stats()[0]
|
||||
gen0_collections = gen0_stats["collections"]
|
||||
if gen0_collections % 50 == 0:
|
||||
gc.collect()
|
||||
gc.enable()
|
||||
set_current_context(SENTINEL_CONTEXT)
|
||||
|
||||
def mktemp(self) -> str:
|
||||
"""Return a unique temporary path for test use.
|
||||
@@ -583,21 +491,15 @@ class HomeserverTestCase(TestCase):
|
||||
method = getattr(self, methodName, None)
|
||||
self._extra_config = getattr(method, "_extra_config", None) if method else None
|
||||
|
||||
def setUp(self) -> None:
|
||||
async def asyncSetUp(self) -> None:
|
||||
"""
|
||||
Set up the TestCase by calling the homeserver constructor, optionally
|
||||
hijacking the authentication system to return a fixed user, and then
|
||||
calling the prepare function.
|
||||
"""
|
||||
# Set up an asyncio event loop so that asyncio primitives (Future, Event,
|
||||
# create_task, etc.) work even when driven by Twisted's MemoryReactorClock.
|
||||
import asyncio
|
||||
self._asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._asyncio_loop)
|
||||
|
||||
# We need to share the reactor between the homeserver and all of our test utils.
|
||||
self.reactor, self.clock = get_clock()
|
||||
self.hs = self.make_homeserver(self.reactor, self.clock)
|
||||
self.hs = await self.make_homeserver(self.reactor, self.clock)
|
||||
|
||||
self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
|
||||
|
||||
@@ -629,7 +531,7 @@ class HomeserverTestCase(TestCase):
|
||||
|
||||
self.helper = RestHelper(
|
||||
self.hs,
|
||||
checked_cast(MemoryReactorClock, self.hs.get_reactor()),
|
||||
self.hs.get_reactor(),
|
||||
self.site,
|
||||
getattr(self, "user_id", None),
|
||||
)
|
||||
@@ -640,13 +542,11 @@ class HomeserverTestCase(TestCase):
|
||||
token = "some_fake_token"
|
||||
|
||||
# We need a valid token ID to satisfy foreign key constraints.
|
||||
token_id = self.get_success(
|
||||
self.hs.get_datastores().main.add_access_token_to_user(
|
||||
self.helper.auth_user_id,
|
||||
token,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
token_id = await self.hs.get_datastores().main.add_access_token_to_user(
|
||||
self.helper.auth_user_id,
|
||||
token,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
# This has to be a function and not just a Mock, because
|
||||
@@ -699,7 +599,7 @@ class HomeserverTestCase(TestCase):
|
||||
store.db_pool.updates.do_next_background_update(False), by=0.1
|
||||
)
|
||||
|
||||
def make_homeserver(
|
||||
async def make_homeserver(
|
||||
self, reactor: ThreadedMemoryReactorClock, clock: Clock
|
||||
) -> HomeServer:
|
||||
"""
|
||||
@@ -714,7 +614,7 @@ class HomeserverTestCase(TestCase):
|
||||
|
||||
Function to be overridden in subclasses.
|
||||
"""
|
||||
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
|
||||
hs = await self.setup_test_homeserver(reactor=reactor, clock=clock)
|
||||
return hs
|
||||
|
||||
def create_test_resource(self) -> Resource:
|
||||
@@ -773,7 +673,7 @@ class HomeserverTestCase(TestCase):
|
||||
Function to optionally be overridden in subclasses.
|
||||
"""
|
||||
|
||||
def make_request(
|
||||
async def make_request(
|
||||
self,
|
||||
method: bytes | str,
|
||||
path: bytes | str,
|
||||
@@ -819,7 +719,7 @@ class HomeserverTestCase(TestCase):
|
||||
Returns:
|
||||
The FakeChannel object which stores the result of the request.
|
||||
"""
|
||||
return make_request(
|
||||
return await make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
method,
|
||||
@@ -837,7 +737,7 @@ class HomeserverTestCase(TestCase):
|
||||
clock=self.clock,
|
||||
)
|
||||
|
||||
def setup_test_homeserver(
|
||||
async def setup_test_homeserver(
|
||||
self,
|
||||
server_name: str | None = None,
|
||||
config: JsonDict | None = None,
|
||||
@@ -879,10 +779,6 @@ class HomeserverTestCase(TestCase):
|
||||
# construct a homeserver with a matching name.
|
||||
server_name = config_obj.server.server_name
|
||||
|
||||
async def run_bg_updates() -> None:
|
||||
with LoggingContext(name="run_bg_updates", server_name=server_name):
|
||||
self.get_success(stor.db_pool.updates.run_background_updates(False))
|
||||
|
||||
hs = setup_test_homeserver(
|
||||
cleanup_func=self.addCleanup,
|
||||
server_name=server_name,
|
||||
@@ -895,73 +791,49 @@ class HomeserverTestCase(TestCase):
|
||||
|
||||
# Run the database background updates, when running against "master".
|
||||
if hs.__class__.__name__ == "TestHomeServer":
|
||||
self.get_success(run_bg_updates())
|
||||
with LoggingContext(name="run_bg_updates", server_name=server_name):
|
||||
await stor.db_pool.updates.run_background_updates(False)
|
||||
|
||||
return hs
|
||||
|
||||
def pump(self, by: float = 0.0) -> None:
|
||||
"""Advance fake time and drive the asyncio event loop.
|
||||
async def pump(self, by: float = 0.0) -> None:
|
||||
"""Advance fake time and yield to the event loop.
|
||||
|
||||
``reactor.advance()`` delegates to ``clock.advance()``, so calling
|
||||
either one advances the same fake-time source.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Advance fake time (fires pending sleeps) AND drain callFromThread
|
||||
# Advance fake time (fires pending sleeps and callFromThread)
|
||||
self.reactor.advance(by)
|
||||
# Yield to the event loop so callbacks can run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Process asyncio callbacks (executor results, task completions, etc.)
|
||||
if not loop.is_closed():
|
||||
loop.run_until_complete(asyncio.sleep(0))
|
||||
|
||||
def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Pump the fake reactor first if time advancement is needed
|
||||
async def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||
"""Await an awaitable, optionally advancing fake time first."""
|
||||
if by > 0:
|
||||
self.reactor.pump([by] * 100)
|
||||
self.reactor.advance(by)
|
||||
await asyncio.sleep(0)
|
||||
return await d # type: ignore[misc]
|
||||
|
||||
# Run the awaitable to completion on the asyncio loop.
|
||||
# nest_asyncio allows this even if the loop is already running.
|
||||
return loop.run_until_complete(d) # type: ignore[arg-type]
|
||||
|
||||
def get_failure(
|
||||
async def get_failure(
|
||||
self, d: Awaitable[Any], exc: type[_ExcType], by: float = 0.0
|
||||
) -> Any:
|
||||
"""
|
||||
Run an awaitable and get a Failure from it.
|
||||
"""
|
||||
import asyncio
|
||||
"""Await an awaitable and expect it to raise."""
|
||||
if by > 0:
|
||||
self.reactor.advance(by)
|
||||
await asyncio.sleep(0)
|
||||
try:
|
||||
await d
|
||||
self.fail(f"Expected {exc}, but awaitable succeeded")
|
||||
except BaseException as e:
|
||||
if isinstance(e, exc):
|
||||
return e
|
||||
raise
|
||||
|
||||
future = asyncio.ensure_future(d) # type: ignore[arg-type]
|
||||
async def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||
"""Await an awaitable and return result or raise exception."""
|
||||
return await self.get_success(d, by=by)
|
||||
|
||||
error_holder: list[BaseException] = []
|
||||
|
||||
def _on_done(f: asyncio.Future) -> None: # type: ignore[type-arg]
|
||||
try:
|
||||
f.result()
|
||||
except BaseException as e:
|
||||
error_holder.append(e)
|
||||
|
||||
future.add_done_callback(_on_done)
|
||||
self.pump(by)
|
||||
|
||||
if error_holder and isinstance(error_holder[0], exc):
|
||||
return Failure(error_holder[0])
|
||||
elif error_holder:
|
||||
self.fail(f"Expected {exc}, got {type(error_holder[0])}: {error_holder[0]}")
|
||||
else:
|
||||
self.fail("Expected failure, but awaitable succeeded")
|
||||
|
||||
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||
"""Drive awaitable to completion and return result or raise exception."""
|
||||
return self.get_success(d, by=by)
|
||||
|
||||
def register_user(
|
||||
async def register_user(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
@@ -970,20 +842,11 @@ class HomeserverTestCase(TestCase):
|
||||
) -> str:
|
||||
"""
|
||||
Register a user. Requires the Admin API be registered.
|
||||
|
||||
Args:
|
||||
username: The user part of the new user.
|
||||
password: The password of the new user.
|
||||
admin: Whether the user should be created as an admin or not.
|
||||
displayname: The displayname of the new user.
|
||||
|
||||
Returns:
|
||||
The MXID of the new user.
|
||||
"""
|
||||
self.hs.config.registration.registration_shared_secret = "shared"
|
||||
|
||||
# Create the user
|
||||
channel = self.make_request("GET", "/_synapse/admin/v1/register")
|
||||
channel = await self.make_request("GET", "/_synapse/admin/v1/register")
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
@@ -1006,13 +869,13 @@ class HomeserverTestCase(TestCase):
|
||||
"mac": want_mac_digest,
|
||||
"inhibit_login": True,
|
||||
}
|
||||
channel = self.make_request("POST", "/_synapse/admin/v1/register", body)
|
||||
channel = await self.make_request("POST", "/_synapse/admin/v1/register", body)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
user_id = channel.json_body["user_id"]
|
||||
return user_id
|
||||
|
||||
def register_appservice_user(
|
||||
async def register_appservice_user(
|
||||
self,
|
||||
username: str,
|
||||
appservice_token: str,
|
||||
@@ -1031,7 +894,7 @@ class HomeserverTestCase(TestCase):
|
||||
Returns:
|
||||
The MXID of the new user, the device ID of the new user's first device.
|
||||
"""
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/register",
|
||||
{
|
||||
@@ -1044,7 +907,7 @@ class HomeserverTestCase(TestCase):
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
return channel.json_body["user_id"], channel.json_body.get("device_id")
|
||||
|
||||
def login(
|
||||
async def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
@@ -1073,7 +936,7 @@ class HomeserverTestCase(TestCase):
|
||||
if additional_request_fields:
|
||||
body.update(additional_request_fields)
|
||||
|
||||
channel = self.make_request(
|
||||
channel = await self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
|
||||
Reference in New Issue
Block a user