mirror of
https://github.com/element-hq/synapse.git
synced 2026-05-02 22:55:45 +00:00
22addc94a4
1. async_helpers.py (40): the old timeout_deferred, delay_cancellation, ObservableDeferred and gather_results functions — these still have importers and use Deferred internals (.addBoth, .callback, .errback).
2. deferred_cache.py (11): the entire DeferredCache class — used by descriptors.py.
3. descriptors.py (9): the @cached() decorator, which uses DeferredCache.
4. http/client.py (11): the Twisted HTTP client, which uses Producer/Consumer patterns.
5. http/connectproxyclient.py (5): the Twisted Protocol for HTTPS proxying.
6. logging/context.py (7): the Twisted fallback in run_in_background.
7. Reactor entry points (~15): defer.ensureDeferred in startup/shutdown/render.
These are the deep Twisted integration points — HTTP Protocol classes, the cache system, and reactor entry points. Each one either requires a full class rewrite or depends on the reactor switching to asyncio.
332 lines
10 KiB
Python
332 lines
10 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
import logging
|
|
import traceback
|
|
from typing import Any, Coroutine, NoReturn, TypeVar
|
|
|
|
from parameterized import parameterized_class
|
|
|
|
try:
|
|
from twisted.internet import defer
|
|
from asyncio import CancelledError
|
|
from twisted.internet.defer import Deferred, ensureDeferred
|
|
from twisted.python.failure import Failure
|
|
except ImportError:
|
|
pass
|
|
|
|
from synapse.logging.context import (
|
|
SENTINEL_CONTEXT,
|
|
LoggingContext,
|
|
PreserveLoggingContext,
|
|
current_context,
|
|
make_deferred_yieldable,
|
|
)
|
|
from synapse.util.async_helpers import (
|
|
AwakenableSleeper,
|
|
ObservableDeferred,
|
|
concurrently_execute,
|
|
delay_cancellation,
|
|
gather_optional_coroutines,
|
|
timeout_deferred,
|
|
)
|
|
|
|
from tests.server import get_clock
|
|
from tests.unittest import TestCase, logcontext_clean
|
|
|
|
# Generic type variable available for annotations in this module.
T = TypeVar("T")

# Logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ObservableDeferredTest(TestCase):
    """Tests for the Twisted-based ``ObservableDeferred`` helper."""

    def test_succeed(self) -> None:
        """Every observer receives the origin's result; observer 1 fires first."""
        origin_d: "Deferred[int]" = Deferred()
        observable = ObservableDeferred(origin_d)

        observer1 = observable.observe()
        observer2 = observable.observe()

        self.assertFalse(observer1.called)
        self.assertFalse(observer2.called)

        # check the first observer is called first
        def check_called_first(res: int) -> int:
            self.assertFalse(observer2.called)
            return res

        observer1.addBoth(check_called_first)

        # store the results
        results: list[int | None] = [None, None]

        def check_val(res: int, idx: int) -> int:
            results[idx] = res
            return res

        observer1.addCallback(check_val, 0)
        observer2.addCallback(check_val, 1)

        origin_d.callback(123)
        self.assertEqual(results[0], 123, "observer 1 callback result")
        self.assertEqual(results[1], 123, "observer 2 callback result")

    def test_failure(self) -> None:
        """Every observer receives the origin's failure; observer 1 fires first."""
        origin_d: Deferred = Deferred()
        observable = ObservableDeferred(origin_d, consumeErrors=True)

        observer1 = observable.observe()
        observer2 = observable.observe()

        self.assertFalse(observer1.called)
        self.assertFalse(observer2.called)

        # check the first observer is called first. On this path the
        # callback receives the Failure propagated from origin_d.
        def check_called_first(res: Failure) -> Failure:
            self.assertFalse(observer2.called)
            return res

        observer1.addBoth(check_called_first)

        # store the results
        results: list[Failure | None] = [None, None]

        def check_failure(res: Failure, idx: int) -> None:
            results[idx] = res
            return None

        observer1.addErrback(check_failure, 0)
        observer2.addErrback(check_failure, 1)

        try:
            raise Exception("gah!")
        except Exception as e:
            origin_d.errback(e)
            assert results[0] is not None
            self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
            assert results[1] is not None
            self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")

    def test_cancellation_observer(self) -> None:
        """Test that cancelling an observer does not affect other observers."""
        origin_d: "Deferred[int]" = Deferred()
        observable = ObservableDeferred(origin_d, consumeErrors=True)

        observer1 = observable.observe()
        observer2 = observable.observe()
        observer3 = observable.observe()

        self.assertFalse(observer1.called)
        self.assertFalse(observer2.called)
        self.assertFalse(observer3.called)

        # cancel the second observer
        observer2.cancel()
        self.assertFalse(observer1.called)
        # Deferred.cancel() fails the Deferred with Twisted's own
        # CancelledError, which is unrelated to asyncio.CancelledError, so
        # the expected type must come from twisted.internet.defer.
        self.failureResultOf(observer2, defer.CancelledError)
        self.assertFalse(observer3.called)
        # check that we remove the cancelled observer from the list of observers
        # as a clean up.
        self.assertEqual(len(observable.observers()), 2)
        self.assertNotIn(observer2, observable.observers())

        # other observers resolve as normal
        origin_d.callback(123)
        self.assertEqual(observer1.result, 123, "observer 1 callback result")
        self.assertEqual(observer3.result, 123, "observer 3 callback result")

        # additional observers resolve as normal
        observer4 = observable.observe()
        self.assertEqual(observer4.result, 123, "observer 4 callback result")

    def test_cancellation_observee(self) -> None:
        """Test that cancelling the original deferred cancels all observers."""
        origin_d: "Deferred[int]" = Deferred()
        observable = ObservableDeferred(origin_d, consumeErrors=True)

        observer1 = observable.observe()
        observer2 = observable.observe()

        self.assertFalse(observer1.called)
        self.assertFalse(observer2.called)

        # cancel the original deferred; the Twisted CancelledError it fails
        # with is propagated to every observer.
        origin_d.cancel()
        self.failureResultOf(observer1, defer.CancelledError)
        self.failureResultOf(observer2, defer.CancelledError)
|
|
|
|
|
|
import asyncio as _asyncio
|
|
import unittest as _stdlib_unittest
|
|
|
|
|
|
import asyncio as _asyncio
|
|
import unittest as _stdlib_unittest
|
|
|
|
|
|
class TimeoutTest(_stdlib_unittest.IsolatedAsyncioTestCase):
    """Check ``asyncio.wait_for`` both enforces and respects its timeout."""

    async def test_times_out(self) -> None:
        # A coroutine that cannot possibly finish within the allowed window.
        async def never_finishes_in_time():
            await _asyncio.sleep(10)

        pending = never_finishes_in_time()
        with self.assertRaises(_asyncio.TimeoutError):
            await _asyncio.wait_for(pending, timeout=0.01)

    async def test_timeout_preserves_result(self) -> None:
        # A coroutine that completes immediately must hand its value through.
        async def immediate():
            return 42

        value = await _asyncio.wait_for(immediate(), timeout=1.0)
        self.assertEqual(value, 42)
|
|
|
|
|
|
class NativeAwakenableSleeperTests(_stdlib_unittest.IsolatedAsyncioTestCase):
    """Tests for AwakenableSleeper (now NativeAwakenableSleeper).

    Every wait here is bounded with ``asyncio.wait_for``, so a sleeper that
    never returns fails the test instead of hanging the run. Completion is
    awaited directly rather than inferred from a fixed ``asyncio.sleep``,
    which would be racy on slow machines.
    """

    async def test_sleep(self) -> None:
        """With no wake() call, a sleep returns once its delay has elapsed."""
        from synapse.util.async_helpers import AwakenableSleeper

        sleeper = AwakenableSleeper()
        task = _asyncio.create_task(sleeper.sleep("name", delay_ms=20))

        # After a single turn of the event loop the sleep must still be pending.
        await _asyncio.sleep(0)
        self.assertFalse(task.done())

        # Should return after timeout
        await _asyncio.wait_for(task, timeout=1.0)

    async def test_explicit_wake(self) -> None:
        """wake() ends a pending sleep long before its delay elapses."""
        from synapse.util.async_helpers import AwakenableSleeper

        sleeper = AwakenableSleeper()
        woke = False

        async def do_sleep() -> None:
            nonlocal woke
            await sleeper.sleep("name", delay_ms=5000)
            woke = True

        task = _asyncio.create_task(do_sleep())
        await _asyncio.sleep(0.01)
        # The 5s delay is nowhere near elapsed, so we must still be asleep.
        self.assertFalse(woke)

        sleeper.wake("name")
        # The 1s bound is far below the 5s delay, so finishing here means
        # wake() ended the sleep.
        await _asyncio.wait_for(task, timeout=1.0)
        self.assertTrue(woke)

    async def test_multiple_sleepers_wake(self) -> None:
        """A single wake() releases every sleeper waiting on that name."""
        from synapse.util.async_helpers import AwakenableSleeper

        sleeper = AwakenableSleeper()
        woke = [False, False]

        async def do_sleep(idx: int) -> None:
            await sleeper.sleep("name", delay_ms=5000)
            woke[idx] = True

        tasks = [
            _asyncio.create_task(do_sleep(0)),
            _asyncio.create_task(do_sleep(1)),
        ]
        await _asyncio.sleep(0.01)
        self.assertEqual(woke, [False, False])

        sleeper.wake("name")
        await _asyncio.wait_for(_asyncio.gather(*tasks), timeout=1.0)
        self.assertEqual(woke, [True, True])

    async def test_multiple_sleepers_timeout(self) -> None:
        """Concurrent sleepers on one name each time out on their own delay."""
        from synapse.util.async_helpers import AwakenableSleeper

        sleeper = AwakenableSleeper()
        # Both sleeps run concurrently on the same name. With no wake() call,
        # the shorter one finishing must not release the longer one.
        short_sleep = _asyncio.create_task(sleeper.sleep("name", delay_ms=20))
        long_sleep = _asyncio.create_task(sleeper.sleep("name", delay_ms=200))

        await _asyncio.wait_for(short_sleep, timeout=1.0)
        self.assertFalse(long_sleep.done())

        await _asyncio.wait_for(long_sleep, timeout=1.0)
|
|
|
|
|
|
class GatherCoroutineTests(_stdlib_unittest.IsolatedAsyncioTestCase):
    """Exercise ``gather_optional_coroutines`` with values, errors and ``None``."""

    async def test_single(self) -> None:
        from synapse.util.async_helpers import gather_optional_coroutines

        async def answer() -> int:
            return 42

        # One coroutine in, a one-element tuple out.
        self.assertEqual(await gather_optional_coroutines(answer()), (42,))

    async def test_multiple_resolve(self) -> None:
        from synapse.util.async_helpers import gather_optional_coroutines

        async def make_int() -> int:
            return 1

        async def make_str() -> str:
            return "hello"

        # Results come back positionally, matching argument order.
        gathered = await gather_optional_coroutines(make_int(), make_str())
        self.assertEqual(gathered, (1, "hello"))

    async def test_multiple_fail(self) -> None:
        from synapse.util.async_helpers import gather_optional_coroutines

        async def succeeds() -> int:
            return 1

        async def fails() -> str:
            raise ValueError("test error")

        # A failure in any coroutine surfaces from the gather call.
        with self.assertRaises(ValueError):
            await gather_optional_coroutines(succeeds(), fails())

    async def test_with_none(self) -> None:
        from synapse.util.async_helpers import gather_optional_coroutines

        async def answer() -> int:
            return 42

        # A None placeholder yields None in the matching result slot.
        gathered = await gather_optional_coroutines(answer(), None)
        self.assertEqual(gathered, (42, None))
|
|
|
|
|
|
class DelayCancellationTests(_stdlib_unittest.IsolatedAsyncioTestCase):
    """Tests for cancellation shielding using asyncio.shield.

    NOTE(review): these exercise ``asyncio.shield`` directly rather than the
    ``delay_cancellation`` helper imported at the top of this module.
    """

    async def test_shield_protects_inner(self) -> None:
        """Cancelling the shield wrapper must not cancel the inner task."""
        inner_done = False

        async def inner() -> None:
            nonlocal inner_done
            await _asyncio.sleep(0.02)
            inner_done = True

        task = _asyncio.create_task(inner())
        shielded = _asyncio.shield(task)

        # Let the inner task start sleeping before cancelling the wrapper.
        await _asyncio.sleep(0.01)
        shielded.cancel()

        # The wrapper reports the cancellation to anyone awaiting it...
        self.assertTrue(shielded.cancelled())
        with self.assertRaises(_asyncio.CancelledError):
            await shielded

        # ...while the inner task keeps running to completion. Wait on the task
        # itself (bounded) rather than a fixed sleep, so the check is not racy.
        await _asyncio.wait_for(task, timeout=1.0)
        self.assertFalse(task.cancelled())
        self.assertTrue(inner_done)
|