diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 55ff80a2e3..af28082301 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -50,8 +50,13 @@ from typing import ( import attr from typing_extensions import ParamSpec -from twisted.internet import defer, threads -from twisted.python.threadpool import ThreadPool +try: + from twisted.internet import defer, threads + from twisted.python.threadpool import ThreadPool + + HAS_TWISTED = True +except ImportError: + HAS_TWISTED = False from synapse.logging.loggers import ExplicitlyConfiguredLogger from synapse.util.stringutils import random_string_insecure_fast @@ -738,34 +743,25 @@ class PreserveLoggingContext: _thread_local = threading.local() _thread_local.current_context = SENTINEL_CONTEXT -# ContextVar kept in sync with _thread_local. This is used by asyncio-native code -# paths (make_future_yieldable, run_coroutine_in_background_native, etc.) and will -# become the sole storage mechanism once all Deferred usage is removed (Phase 7). -# -# IMPORTANT: We cannot use ContextVar as the primary storage while Twisted Deferreds -# are in use, because asyncio's call_later/call_soon run callbacks in context COPIES. -# The _set_context_cb Deferred callback pattern relies on writes being globally visible -# on the thread, which threading.local provides but ContextVar with asyncio does not. _current_context_var: contextvars.ContextVar[ "LoggingContextOrSentinel" ] = contextvars.ContextVar("synapse_logging_context", default=SENTINEL_CONTEXT) def current_context() -> LoggingContextOrSentinel: - """Get the current logging context from thread local storage""" + """Get the current logging context.""" return getattr(_thread_local, "current_context", SENTINEL_CONTEXT) def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: - """Set the current logging context in thread local storage + """Set the current logging context. + Args: context: The context to activate. Returns: The context that was previously active """ - # everything blows up if we allow current_context to be set to None, so sanity-check - # that now. if context is None: raise TypeError("'context' argument may not be None") @@ -775,7 +771,6 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe rusage = get_thread_resource_usage() current.stop(rusage) _thread_local.current_context = context - # Keep ContextVar in sync for asyncio-native code paths _current_context_var.set(context) context.start(rusage) diff --git a/tests/__init__.py b/tests/__init__.py index 4c8633b445..13406f8f7a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -19,11 +19,4 @@ # # -from twisted.trial import util - -from synapse.util.patch_inline_callbacks import do_patch - -# attempt to do the patch before we load any synapse code -do_patch() - -util.DEFAULT_TIMEOUT_DURATION = 20 +# Test initialization — no longer patches Twisted inlineCallbacks diff --git a/tests/unittest.py b/tests/unittest.py index 6022c750d0..77fb994fe9 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -48,11 +48,12 @@ import signedjson.key import unpaddedbase64 from typing_extensions import Concatenate, ParamSpec +import unittest as _stdlib_unittest + from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.testing import MemoryReactor, MemoryReactorClock from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool -from twisted.trial import unittest from twisted.web.resource import Resource from twisted.web.server import Request @@ -179,15 +180,15 @@ def make_homeserver_config_obj(config: dict[str, Any]) -> HomeServerConfig: return deepcopy_config(config_obj) -class TestCase(unittest.TestCase): - """A subclass of twisted.trial's TestCase which looks for 'loglevel' +class TestCase(_stdlib_unittest.TestCase): + """A subclass of stdlib's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the root logger's logging level while that test (case|method) runs.""" - def __init__(self, methodName: str): + def __init__(self, methodName: str = "runTest"): super().__init__(methodName) - method = getattr(self, methodName) + method = getattr(self, methodName, None) level = getattr(method, "loglevel", getattr(self, "loglevel", None)) @@ -245,6 +246,159 @@ class TestCase(unittest.TestCase): return ret + def _callTestMethod(self, method: Callable[[], Any]) -> None: + """Override to handle async test methods. + + Twisted's trial auto-detected async test methods and wrapped them + with ensureDeferred. We replicate that behavior here by running + the coroutine in the Twisted reactor. + """ + import inspect + + result = method() + if inspect.isawaitable(result): + # Use Twisted's reactor to run the async test + from twisted.internet import defer, reactor + + d = defer.ensureDeferred(result) + + # Use blockingCallFromThread-style approach: run reactor until done + if not d.called: + # Drive the reactor until the deferred fires + finished = [] + d.addBoth(finished.append) + + # If we have a test reactor (HomeserverTestCase), pump it + if hasattr(self, "reactor"): + for _ in range(1000): + if finished: + break + self.reactor.advance(0.1) + else: + # For non-HomeserverTestCase async tests, use the + # Twisted global reactor with iterate() + from twisted.internet import reactor as global_reactor + + # Install the reactor if not already running + for _ in range(10000): + if finished: + break + global_reactor.runUntilCurrent() # type: ignore[attr-defined] + global_reactor.doIteration(0.001) # type: ignore[attr-defined] + + if not finished: + self.fail("Async test method did not complete") + + if isinstance(finished[0], Failure): + finished[0].raiseException() + + def mktemp(self) -> str: + """Return a unique temporary path for test use. + + Replacement for twisted.trial.unittest.TestCase.mktemp. + Returns a path that does NOT yet exist (matching trial's behavior). + """ + import tempfile + + d = tempfile.mkdtemp() + import os + import shutil + + shutil.rmtree(d) + return d + + def assertRaises(self, expected_exception, *args, **kwargs): # type: ignore[override] + """Override to match Twisted trial behavior. + + When called with (exception_class, callable, *args), trial returned + the exception instance. Stdlib returns None. This override returns + the exception for backward compatibility. + """ + if args: + # Callable form: assertRaises(Exc, func, arg1, arg2) + ctx = super().assertRaises(expected_exception) + with ctx: + args[0](*args[1:], **kwargs) + return ctx.exception + else: + # Context manager form: with self.assertRaises(Exc): + return super().assertRaises(expected_exception) + + def assertFailure( + self, d: "Deferred[Any]", *expected_types: type[BaseException] + ) -> "Deferred[Any]": + """Assert that a Deferred fails with the given exception type(s). + + Replacement for twisted.trial.unittest.TestCase.assertFailure. + Returns the Deferred with an errback that checks the failure type. + """ + + def _check(f: Failure) -> Failure: + if not f.check(*expected_types): + self.fail( + f"Expected {expected_types}, got {f.type}: {f.value}" + ) + return f + + return d.addErrback(_check) + + def assertApproximates( + self, first: float, second: float, tolerance: float + ) -> None: + """Assert that first and second are within tolerance of each other. + + Replacement for twisted.trial.unittest.TestCase.assertApproximates. + """ + if abs(first - second) > tolerance: + self.fail( + f"{first!r} not within {tolerance!r} of {second!r} " + f"(difference: {abs(first - second)!r})" + ) + + def assertNoResult(self, d: "Deferred[Any]") -> None: + """Assert that a Deferred has not yet fired. + + Replacement for twisted.trial.unittest.TestCase.assertNoResult. + """ + results: list[Any] = [] + d.addBoth(results.append) + if results: + self.fail(f"Expected Deferred to have no result, but it has: {results[0]!r}") + + def successResultOf(self, d: "Deferred[TV]") -> TV: + """Extract the result of a Deferred that has already fired successfully. + + Replacement for twisted.trial.unittest.TestCase.successResultOf. + """ + results: list[Any] = [] + d.addBoth(results.append) + if not results: + self.fail(f"Deferred {d!r} has not fired yet") + result = results[0] + if isinstance(result, Failure): + result.raiseException() + return result # type: ignore[return-value] + + def failureResultOf( + self, d: "Deferred[Any]", *expected_types: type[BaseException] + ) -> Failure: + """Extract the Failure from a Deferred that has already fired with error. + + Replacement for twisted.trial.unittest.TestCase.failureResultOf. + """ + results: list[Any] = [] + d.addBoth(results.append) + if not results: + self.fail(f"Deferred {d!r} has not fired yet") + result = results[0] + if not isinstance(result, Failure): + self.fail(f"Deferred {d!r} succeeded with {result!r}, expected failure") + if expected_types and not result.check(*expected_types): + self.fail( + f"Expected {expected_types}, got {result.type}: {result.value}" + ) + return result + def assertObjectHasAttributes(self, attrs: dict[str, object], obj: object) -> None: """Asserts that the given object has each of the attributes given, and that the value of each matches according to assertEqual.""" @@ -379,12 +533,12 @@ class HomeserverTestCase(TestCase): needs_threadpool: ClassVar[bool] = False servlets: ClassVar[list[RegisterServletsFunc]] = [] - def __init__(self, methodName: str): + def __init__(self, methodName: str = "runTest"): super().__init__(methodName) # see if we have any additional config for this test - method = getattr(self, methodName) - self._extra_config = getattr(method, "_extra_config", None) + method = getattr(self, methodName, None) + self._extra_config = getattr(method, "_extra_config", None) if method else None def setUp(self) -> None: """