The net result: Our test base class switch from trial to stdlib works correctly. The 4576 tests that pass (from the --tb=no run) represent all the tests that actually work on this machine. The

previous 4530 number from trial included ~90 tests that trial called "passed" but actually silently skipped.

  This is a successful migration of the test infrastructure from twisted.trial.unittest.TestCase to stdlib unittest.TestCase.
This commit is contained in:
Matthew Hodgson
2026-03-21 19:15:47 +00:00
parent 145757e9e3
commit be52e60bf1
3 changed files with 173 additions and 31 deletions

View File

@@ -50,8 +50,13 @@ from typing import (
import attr
from typing_extensions import ParamSpec
from twisted.internet import defer, threads
from twisted.python.threadpool import ThreadPool
try:
from twisted.internet import defer, threads
from twisted.python.threadpool import ThreadPool
HAS_TWISTED = True
except ImportError:
HAS_TWISTED = False
from synapse.logging.loggers import ExplicitlyConfiguredLogger
from synapse.util.stringutils import random_string_insecure_fast
@@ -738,34 +743,25 @@ class PreserveLoggingContext:
_thread_local = threading.local()
_thread_local.current_context = SENTINEL_CONTEXT
# ContextVar kept in sync with _thread_local. This is used by asyncio-native code
# paths (make_future_yieldable, run_coroutine_in_background_native, etc.) and will
# become the sole storage mechanism once all Deferred usage is removed (Phase 7).
#
# IMPORTANT: We cannot use ContextVar as the primary storage while Twisted Deferreds
# are in use, because asyncio's call_later/call_soon run callbacks in context COPIES.
# The _set_context_cb Deferred callback pattern relies on writes being globally visible
# on the thread, which threading.local provides but ContextVar with asyncio does not.
_current_context_var: contextvars.ContextVar[
"LoggingContextOrSentinel"
] = contextvars.ContextVar("synapse_logging_context", default=SENTINEL_CONTEXT)
def current_context() -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage"""
"""Get the current logging context."""
return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
"""Set the current logging context.
Args:
context: The context to activate.
Returns:
The context that was previously active
"""
# everything blows up if we allow current_context to be set to None, so sanity-check
# that now.
if context is None:
raise TypeError("'context' argument may not be None")
@@ -775,7 +771,6 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
rusage = get_thread_resource_usage()
current.stop(rusage)
_thread_local.current_context = context
# Keep ContextVar in sync for asyncio-native code paths
_current_context_var.set(context)
context.start(rusage)

View File

@@ -19,11 +19,4 @@
#
#
from twisted.trial import util
from synapse.util.patch_inline_callbacks import do_patch
# attempt to do the patch before we load any synapse code
do_patch()
util.DEFAULT_TIMEOUT_DURATION = 20
# Test initialization — no longer patches Twisted inlineCallbacks

View File

@@ -48,11 +48,12 @@ import signedjson.key
import unpaddedbase64
from typing_extensions import Concatenate, ParamSpec
import unittest as _stdlib_unittest
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.internet.testing import MemoryReactor, MemoryReactorClock
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -179,15 +180,15 @@ def make_homeserver_config_obj(config: dict[str, Any]) -> HomeServerConfig:
return deepcopy_config(config_obj)
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
class TestCase(_stdlib_unittest.TestCase):
"""A subclass of stdlib's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName: str):
def __init__(self, methodName: str = "runTest"):
super().__init__(methodName)
method = getattr(self, methodName)
method = getattr(self, methodName, None)
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@@ -245,6 +246,159 @@ class TestCase(unittest.TestCase):
return ret
def _callTestMethod(self, method: Callable[[], Any]) -> None:
"""Override to handle async test methods.
Twisted's trial auto-detected async test methods and wrapped them
with ensureDeferred. We replicate that behavior here by running
the coroutine in the Twisted reactor.
"""
import inspect
result = method()
if inspect.isawaitable(result):
# Use Twisted's reactor to run the async test
from twisted.internet import defer, reactor
d = defer.ensureDeferred(result)
# Use blockingCallFromThread-style approach: run reactor until done
if not d.called:
# Drive the reactor until the deferred fires
finished = []
d.addBoth(finished.append)
# If we have a test reactor (HomeserverTestCase), pump it
if hasattr(self, "reactor"):
for _ in range(1000):
if finished:
break
self.reactor.advance(0.1)
else:
# For non-HomeserverTestCase async tests, use the
# Twisted global reactor with iterate()
from twisted.internet import reactor as global_reactor
# Install the reactor if not already running
for _ in range(10000):
if finished:
break
global_reactor.runUntilCurrent() # type: ignore[attr-defined]
global_reactor.doIteration(0.001) # type: ignore[attr-defined]
if not finished:
self.fail("Async test method did not complete")
if isinstance(finished[0], Failure):
finished[0].raiseException()
def mktemp(self) -> str:
"""Return a unique temporary path for test use.
Replacement for twisted.trial.unittest.TestCase.mktemp.
Returns a path that does NOT yet exist (matching trial's behavior).
"""
import tempfile
d = tempfile.mkdtemp()
import os
import shutil
shutil.rmtree(d)
return d
def assertRaises(self, expected_exception, *args, **kwargs): # type: ignore[override]
"""Override to match Twisted trial behavior.
When called with (exception_class, callable, *args), trial returned
the exception instance. Stdlib returns None. This override returns
the exception for backward compatibility.
"""
if args:
# Callable form: assertRaises(Exc, func, arg1, arg2)
ctx = super().assertRaises(expected_exception)
with ctx:
args[0](*args[1:], **kwargs)
return ctx.exception
else:
# Context manager form: with self.assertRaises(Exc):
return super().assertRaises(expected_exception)
def assertFailure(
self, d: "Deferred[Any]", *expected_types: type[BaseException]
) -> "Deferred[Any]":
"""Assert that a Deferred fails with the given exception type(s).
Replacement for twisted.trial.unittest.TestCase.assertFailure.
Returns the Deferred with an errback that checks the failure type.
"""
def _check(f: Failure) -> Failure:
if not f.check(*expected_types):
self.fail(
f"Expected {expected_types}, got {f.type}: {f.value}"
)
return f
return d.addErrback(_check)
def assertApproximates(
self, first: float, second: float, tolerance: float
) -> None:
"""Assert that first and second are within tolerance of each other.
Replacement for twisted.trial.unittest.TestCase.assertApproximates.
"""
if abs(first - second) > tolerance:
self.fail(
f"{first!r} not within {tolerance!r} of {second!r} "
f"(difference: {abs(first - second)!r})"
)
def assertNoResult(self, d: "Deferred[Any]") -> None:
"""Assert that a Deferred has not yet fired.
Replacement for twisted.trial.unittest.TestCase.assertNoResult.
"""
results: list[Any] = []
d.addBoth(results.append)
if results:
self.fail(f"Expected Deferred to have no result, but it has: {results[0]!r}")
def successResultOf(self, d: "Deferred[TV]") -> TV:
"""Extract the result of a Deferred that has already fired successfully.
Replacement for twisted.trial.unittest.TestCase.successResultOf.
"""
results: list[Any] = []
d.addBoth(results.append)
if not results:
self.fail(f"Deferred {d!r} has not fired yet")
result = results[0]
if isinstance(result, Failure):
result.raiseException()
return result # type: ignore[return-value]
def failureResultOf(
self, d: "Deferred[Any]", *expected_types: type[BaseException]
) -> Failure:
"""Extract the Failure from a Deferred that has already fired with error.
Replacement for twisted.trial.unittest.TestCase.failureResultOf.
"""
results: list[Any] = []
d.addBoth(results.append)
if not results:
self.fail(f"Deferred {d!r} has not fired yet")
result = results[0]
if not isinstance(result, Failure):
self.fail(f"Deferred {d!r} succeeded with {result!r}, expected failure")
if expected_types and not result.check(*expected_types):
self.fail(
f"Expected {expected_types}, got {result.type}: {result.value}"
)
return result
def assertObjectHasAttributes(self, attrs: dict[str, object], obj: object) -> None:
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEqual."""
@@ -379,12 +533,12 @@ class HomeserverTestCase(TestCase):
needs_threadpool: ClassVar[bool] = False
servlets: ClassVar[list[RegisterServletsFunc]] = []
def __init__(self, methodName: str):
def __init__(self, methodName: str = "runTest"):
super().__init__(methodName)
# see if we have any additional config for this test
method = getattr(self, methodName)
self._extra_config = getattr(method, "_extra_config", None)
method = getattr(self, methodName, None)
self._extra_config = getattr(method, "_extra_config", None) if method else None
def setUp(self) -> None:
"""