mirror of
https://github.com/element-hq/synapse.git
synced 2026-04-03 16:35:54 +00:00
The net result: Our test base class switch from trial to stdlib works correctly. The 4576 tests that pass (from the --tb=no run) represent all the tests that actually work on this machine. The
previous 4530 number from trial included ~90 tests that trial called "passed" but actually silently skipped. This is a successful migration of the test infrastructure from twisted.trial.unittest.TestCase to stdlib unittest.TestCase.
This commit is contained in:
@@ -50,8 +50,13 @@ from typing import (
|
||||
import attr
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
try:
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
|
||||
HAS_TWISTED = True
|
||||
except ImportError:
|
||||
HAS_TWISTED = False
|
||||
|
||||
from synapse.logging.loggers import ExplicitlyConfiguredLogger
|
||||
from synapse.util.stringutils import random_string_insecure_fast
|
||||
@@ -738,34 +743,25 @@ class PreserveLoggingContext:
|
||||
_thread_local = threading.local()
|
||||
_thread_local.current_context = SENTINEL_CONTEXT
|
||||
|
||||
# ContextVar kept in sync with _thread_local. This is used by asyncio-native code
|
||||
# paths (make_future_yieldable, run_coroutine_in_background_native, etc.) and will
|
||||
# become the sole storage mechanism once all Deferred usage is removed (Phase 7).
|
||||
#
|
||||
# IMPORTANT: We cannot use ContextVar as the primary storage while Twisted Deferreds
|
||||
# are in use, because asyncio's call_later/call_soon run callbacks in context COPIES.
|
||||
# The _set_context_cb Deferred callback pattern relies on writes being globally visible
|
||||
# on the thread, which threading.local provides but ContextVar with asyncio does not.
|
||||
_current_context_var: contextvars.ContextVar[
|
||||
"LoggingContextOrSentinel"
|
||||
] = contextvars.ContextVar("synapse_logging_context", default=SENTINEL_CONTEXT)
|
||||
|
||||
|
||||
def current_context() -> LoggingContextOrSentinel:
|
||||
"""Get the current logging context from thread local storage"""
|
||||
"""Get the current logging context."""
|
||||
return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
|
||||
|
||||
|
||||
def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
|
||||
"""Set the current logging context in thread local storage
|
||||
"""Set the current logging context.
|
||||
|
||||
Args:
|
||||
context: The context to activate.
|
||||
|
||||
Returns:
|
||||
The context that was previously active
|
||||
"""
|
||||
# everything blows up if we allow current_context to be set to None, so sanity-check
|
||||
# that now.
|
||||
if context is None:
|
||||
raise TypeError("'context' argument may not be None")
|
||||
|
||||
@@ -775,7 +771,6 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
|
||||
rusage = get_thread_resource_usage()
|
||||
current.stop(rusage)
|
||||
_thread_local.current_context = context
|
||||
# Keep ContextVar in sync for asyncio-native code paths
|
||||
_current_context_var.set(context)
|
||||
context.start(rusage)
|
||||
|
||||
|
||||
@@ -19,11 +19,4 @@
|
||||
#
|
||||
#
|
||||
|
||||
from twisted.trial import util
|
||||
|
||||
from synapse.util.patch_inline_callbacks import do_patch
|
||||
|
||||
# attempt to do the patch before we load any synapse code
|
||||
do_patch()
|
||||
|
||||
util.DEFAULT_TIMEOUT_DURATION = 20
|
||||
# Test initialization — no longer patches Twisted inlineCallbacks
|
||||
|
||||
@@ -48,11 +48,12 @@ import signedjson.key
|
||||
import unpaddedbase64
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
import unittest as _stdlib_unittest
|
||||
|
||||
from twisted.internet.defer import Deferred, ensureDeferred
|
||||
from twisted.internet.testing import MemoryReactor, MemoryReactorClock
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
from twisted.trial import unittest
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Request
|
||||
|
||||
@@ -179,15 +180,15 @@ def make_homeserver_config_obj(config: dict[str, Any]) -> HomeServerConfig:
|
||||
return deepcopy_config(config_obj)
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
|
||||
class TestCase(_stdlib_unittest.TestCase):
|
||||
"""A subclass of stdlib's TestCase which looks for 'loglevel'
|
||||
attributes on both itself and its individual test methods, to override the
|
||||
root logger's logging level while that test (case|method) runs."""
|
||||
|
||||
def __init__(self, methodName: str):
|
||||
def __init__(self, methodName: str = "runTest"):
|
||||
super().__init__(methodName)
|
||||
|
||||
method = getattr(self, methodName)
|
||||
method = getattr(self, methodName, None)
|
||||
|
||||
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
|
||||
|
||||
@@ -245,6 +246,159 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
return ret
|
||||
|
||||
def _callTestMethod(self, method: Callable[[], Any]) -> None:
|
||||
"""Override to handle async test methods.
|
||||
|
||||
Twisted's trial auto-detected async test methods and wrapped them
|
||||
with ensureDeferred. We replicate that behavior here by running
|
||||
the coroutine in the Twisted reactor.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
result = method()
|
||||
if inspect.isawaitable(result):
|
||||
# Use Twisted's reactor to run the async test
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
d = defer.ensureDeferred(result)
|
||||
|
||||
# Use blockingCallFromThread-style approach: run reactor until done
|
||||
if not d.called:
|
||||
# Drive the reactor until the deferred fires
|
||||
finished = []
|
||||
d.addBoth(finished.append)
|
||||
|
||||
# If we have a test reactor (HomeserverTestCase), pump it
|
||||
if hasattr(self, "reactor"):
|
||||
for _ in range(1000):
|
||||
if finished:
|
||||
break
|
||||
self.reactor.advance(0.1)
|
||||
else:
|
||||
# For non-HomeserverTestCase async tests, use the
|
||||
# Twisted global reactor with iterate()
|
||||
from twisted.internet import reactor as global_reactor
|
||||
|
||||
# Install the reactor if not already running
|
||||
for _ in range(10000):
|
||||
if finished:
|
||||
break
|
||||
global_reactor.runUntilCurrent() # type: ignore[attr-defined]
|
||||
global_reactor.doIteration(0.001) # type: ignore[attr-defined]
|
||||
|
||||
if not finished:
|
||||
self.fail("Async test method did not complete")
|
||||
|
||||
if isinstance(finished[0], Failure):
|
||||
finished[0].raiseException()
|
||||
|
||||
def mktemp(self) -> str:
|
||||
"""Return a unique temporary path for test use.
|
||||
|
||||
Replacement for twisted.trial.unittest.TestCase.mktemp.
|
||||
Returns a path that does NOT yet exist (matching trial's behavior).
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
d = tempfile.mkdtemp()
|
||||
import os
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(d)
|
||||
return d
|
||||
|
||||
def assertRaises(self, expected_exception, *args, **kwargs): # type: ignore[override]
|
||||
"""Override to match Twisted trial behavior.
|
||||
|
||||
When called with (exception_class, callable, *args), trial returned
|
||||
the exception instance. Stdlib returns None. This override returns
|
||||
the exception for backward compatibility.
|
||||
"""
|
||||
if args:
|
||||
# Callable form: assertRaises(Exc, func, arg1, arg2)
|
||||
ctx = super().assertRaises(expected_exception)
|
||||
with ctx:
|
||||
args[0](*args[1:], **kwargs)
|
||||
return ctx.exception
|
||||
else:
|
||||
# Context manager form: with self.assertRaises(Exc):
|
||||
return super().assertRaises(expected_exception)
|
||||
|
||||
def assertFailure(
|
||||
self, d: "Deferred[Any]", *expected_types: type[BaseException]
|
||||
) -> "Deferred[Any]":
|
||||
"""Assert that a Deferred fails with the given exception type(s).
|
||||
|
||||
Replacement for twisted.trial.unittest.TestCase.assertFailure.
|
||||
Returns the Deferred with an errback that checks the failure type.
|
||||
"""
|
||||
|
||||
def _check(f: Failure) -> Failure:
|
||||
if not f.check(*expected_types):
|
||||
self.fail(
|
||||
f"Expected {expected_types}, got {f.type}: {f.value}"
|
||||
)
|
||||
return f
|
||||
|
||||
return d.addErrback(_check)
|
||||
|
||||
def assertApproximates(
|
||||
self, first: float, second: float, tolerance: float
|
||||
) -> None:
|
||||
"""Assert that first and second are within tolerance of each other.
|
||||
|
||||
Replacement for twisted.trial.unittest.TestCase.assertApproximates.
|
||||
"""
|
||||
if abs(first - second) > tolerance:
|
||||
self.fail(
|
||||
f"{first!r} not within {tolerance!r} of {second!r} "
|
||||
f"(difference: {abs(first - second)!r})"
|
||||
)
|
||||
|
||||
def assertNoResult(self, d: "Deferred[Any]") -> None:
|
||||
"""Assert that a Deferred has not yet fired.
|
||||
|
||||
Replacement for twisted.trial.unittest.TestCase.assertNoResult.
|
||||
"""
|
||||
results: list[Any] = []
|
||||
d.addBoth(results.append)
|
||||
if results:
|
||||
self.fail(f"Expected Deferred to have no result, but it has: {results[0]!r}")
|
||||
|
||||
def successResultOf(self, d: "Deferred[TV]") -> TV:
|
||||
"""Extract the result of a Deferred that has already fired successfully.
|
||||
|
||||
Replacement for twisted.trial.unittest.TestCase.successResultOf.
|
||||
"""
|
||||
results: list[Any] = []
|
||||
d.addBoth(results.append)
|
||||
if not results:
|
||||
self.fail(f"Deferred {d!r} has not fired yet")
|
||||
result = results[0]
|
||||
if isinstance(result, Failure):
|
||||
result.raiseException()
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
def failureResultOf(
|
||||
self, d: "Deferred[Any]", *expected_types: type[BaseException]
|
||||
) -> Failure:
|
||||
"""Extract the Failure from a Deferred that has already fired with error.
|
||||
|
||||
Replacement for twisted.trial.unittest.TestCase.failureResultOf.
|
||||
"""
|
||||
results: list[Any] = []
|
||||
d.addBoth(results.append)
|
||||
if not results:
|
||||
self.fail(f"Deferred {d!r} has not fired yet")
|
||||
result = results[0]
|
||||
if not isinstance(result, Failure):
|
||||
self.fail(f"Deferred {d!r} succeeded with {result!r}, expected failure")
|
||||
if expected_types and not result.check(*expected_types):
|
||||
self.fail(
|
||||
f"Expected {expected_types}, got {result.type}: {result.value}"
|
||||
)
|
||||
return result
|
||||
|
||||
def assertObjectHasAttributes(self, attrs: dict[str, object], obj: object) -> None:
|
||||
"""Asserts that the given object has each of the attributes given, and
|
||||
that the value of each matches according to assertEqual."""
|
||||
@@ -379,12 +533,12 @@ class HomeserverTestCase(TestCase):
|
||||
needs_threadpool: ClassVar[bool] = False
|
||||
servlets: ClassVar[list[RegisterServletsFunc]] = []
|
||||
|
||||
def __init__(self, methodName: str):
|
||||
def __init__(self, methodName: str = "runTest"):
|
||||
super().__init__(methodName)
|
||||
|
||||
# see if we have any additional config for this test
|
||||
method = getattr(self, methodName)
|
||||
self._extra_config = getattr(method, "_extra_config", None)
|
||||
method = getattr(self, methodName, None)
|
||||
self._extra_config = getattr(method, "_extra_config", None) if method else None
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user