mirror of
https://github.com/element-hq/synapse.git
synced 2026-03-31 12:45:44 +00:00
--- Summary: Synapse Twisted → asyncio Migration What's been built (Phases 0-7) — 10 new files, ~3200 lines, 96 tests Every Twisted component has an asyncio-native replacement ready: ┌──────────────────┬───────────────────────────────────────────────┬─────────────────────────────────────────────────────────┬───────────────────────────────────────────────┐ │ Component │ Twisted Original │ asyncio Replacement │ File │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ LoggingContext │ threading.local + Deferred callbacks │ ContextVar + asyncio Task │ synapse/logging/context.py (native functions) │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ Async primitives │ ObservableDeferred, Linearizer, ReadWriteLock │ ObservableFuture, NativeLinearizer, NativeReadWriteLock │ synapse/util/async_helpers.py │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ Cache │ DeferredCache │ FutureCache │ synapse/util/caches/future_cache.py │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ Clock │ Clock (Twisted reactor) │ NativeClock (asyncio) │ synapse/util/clock.py │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ Database │ adbapi.ConnectionPool │ NativeConnectionPool (ThreadPoolExecutor) │ synapse/storage/native_database.py │ 
├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ HTTP Client │ treq + Twisted Agent │ aiohttp.ClientSession │ synapse/http/native_client.py │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ HTTP Server │ JsonResource + Twisted Site │ NativeJsonResource + aiohttp.web │ synapse/http/native_server.py │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ Replication │ LineOnlyReceiver (Twisted Protocol) │ asyncio.StreamReader/Writer │ synapse/replication/tcp/native_protocol.py │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ Event loop │ ISynapseReactor │ ISynapseEventLoop │ synapse/types/__init__.py │ ├──────────────────┼───────────────────────────────────────────────┼─────────────────────────────────────────────────────────┼───────────────────────────────────────────────┤ │ Test helper │ — │ FakeAsyncioLoop │ tests/async_helpers.py │ └──────────────────┴───────────────────────────────────────────────┴─────────────────────────────────────────────────────────┴───────────────────────────────────────────────┘ What's been wired in safely — 224 files changed, 0 regressions - MemoryReactor type hint → Any across 198 test files (cosmetic) - synapse/http/server.py — catches both Twisted and asyncio CancelledError - All 4530 tests still pass (minus the 2 pre-existing failures) What remains for the flag day The actual switchover requires rewriting 5 core files simultaneously, then running a migration script across ~500 files: 1. 
tests/unittest.py + tests/server.py — switch from twisted.trial.TestCase to unittest.TestCase, MemoryReactorClock to FakeAsyncioLoop, get_success() to asyncio run_until_complete() 2. synapse/logging/context.py — switch current_context() to ContextVar, make_deferred_yieldable() to async, run_in_background() to create_task() 3. synapse/util/async_helpers.py — rename Native* classes to canonical names, remove Deferred-based originals 4. Migration script — update all CancelledError, defer.*, Deferred imports across ~500 files 5. pyproject.toml — remove Twisted dependency This is an atomic change because: ContextVar can't coexist with Twisted's reactor callbacks, make_deferred_yieldable's signature change breaks all callers, and CancelledError is a different class between Twisted and asyncio.
1694 lines
56 KiB
Python
1694 lines
56 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright (C) 2025 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
|
|
"""Tests for Phase 0 asyncio-native parallel implementations.
|
|
|
|
These test the new asyncio-native primitives added alongside the existing
|
|
Twisted-based ones, ensuring they work correctly before being swapped in
|
|
during later migration phases.
|
|
"""
|
|
|
|
import asyncio
|
|
import unittest
|
|
|
|
from synapse.logging.context import (
|
|
SENTINEL_CONTEXT,
|
|
LoggingContext,
|
|
_current_context_var,
|
|
_native_current_context,
|
|
_native_set_current_context,
|
|
make_future_yieldable,
|
|
run_coroutine_in_background_native,
|
|
run_in_background_native,
|
|
)
|
|
from synapse.util.async_helpers import (
|
|
NativeLinearizer,
|
|
NativeReadWriteLock,
|
|
ObservableFuture,
|
|
)
|
|
|
|
|
|
class ContextVarContextTest(unittest.IsolatedAsyncioTestCase):
|
|
"""Tests for contextvars-based context tracking."""
|
|
|
|
def setUp(self) -> None:
|
|
# Ensure we start from sentinel
|
|
_current_context_var.set(SENTINEL_CONTEXT)
|
|
|
|
def test_default_is_sentinel(self) -> None:
|
|
self.assertIs(_native_current_context(), SENTINEL_CONTEXT)
|
|
|
|
def test_set_and_get(self) -> None:
|
|
ctx = LoggingContext(name="test", server_name="test.server")
|
|
old = _native_set_current_context(ctx)
|
|
self.assertIs(old, SENTINEL_CONTEXT)
|
|
self.assertIs(_native_current_context(), ctx)
|
|
# Restore
|
|
_native_set_current_context(SENTINEL_CONTEXT)
|
|
|
|
def test_set_returns_previous(self) -> None:
|
|
ctx1 = LoggingContext(name="ctx1", server_name="test.server")
|
|
ctx2 = LoggingContext(name="ctx2", server_name="test.server")
|
|
_native_set_current_context(ctx1)
|
|
old = _native_set_current_context(ctx2)
|
|
self.assertIs(old, ctx1)
|
|
_native_set_current_context(SENTINEL_CONTEXT)
|
|
|
|
def test_none_raises(self) -> None:
|
|
with self.assertRaises(TypeError):
|
|
_native_set_current_context(None) # type: ignore[arg-type]
|
|
|
|
async def test_task_inherits_context(self) -> None:
|
|
"""asyncio.Tasks inherit the parent's contextvars by default."""
|
|
ctx = LoggingContext(name="parent", server_name="test.server")
|
|
_native_set_current_context(ctx)
|
|
|
|
result = None
|
|
|
|
async def child() -> None:
|
|
nonlocal result
|
|
result = _native_current_context()
|
|
|
|
task = asyncio.create_task(child())
|
|
await task
|
|
self.assertIs(result, ctx)
|
|
_native_set_current_context(SENTINEL_CONTEXT)
|
|
|
|
|
|
class MakeFutureYieldableTest(unittest.IsolatedAsyncioTestCase):
|
|
def setUp(self) -> None:
|
|
_current_context_var.set(SENTINEL_CONTEXT)
|
|
|
|
async def test_already_done_future(self) -> None:
|
|
loop = asyncio.get_running_loop()
|
|
f: asyncio.Future[int] = loop.create_future()
|
|
f.set_result(42)
|
|
|
|
result = await make_future_yieldable(f)
|
|
self.assertEqual(result, 42)
|
|
|
|
async def test_pending_future_preserves_context(self) -> None:
|
|
ctx = LoggingContext(name="test", server_name="test.server")
|
|
_native_set_current_context(ctx)
|
|
|
|
loop = asyncio.get_running_loop()
|
|
f: asyncio.Future[str] = loop.create_future()
|
|
|
|
# Schedule the resolution so the coroutine can actually await
|
|
loop.call_soon(f.set_result, "hello")
|
|
result = await make_future_yieldable(f)
|
|
|
|
# Context should be restored after awaiting
|
|
self.assertIs(_native_current_context(), ctx)
|
|
self.assertEqual(result, "hello")
|
|
|
|
_native_set_current_context(SENTINEL_CONTEXT)
|
|
|
|
async def test_pending_future_exception(self) -> None:
|
|
ctx = LoggingContext(name="test", server_name="test.server")
|
|
_native_set_current_context(ctx)
|
|
|
|
loop = asyncio.get_running_loop()
|
|
f: asyncio.Future[str] = loop.create_future()
|
|
|
|
yieldable = make_future_yieldable(f)
|
|
f.set_exception(ValueError("boom"))
|
|
|
|
with self.assertRaises(ValueError):
|
|
await yieldable
|
|
|
|
# Context should still be restored after exception
|
|
self.assertIs(_native_current_context(), ctx)
|
|
_native_set_current_context(SENTINEL_CONTEXT)
|
|
|
|
|
|
class RunCoroutineInBackgroundNativeTest(unittest.IsolatedAsyncioTestCase):
|
|
def setUp(self) -> None:
|
|
_current_context_var.set(SENTINEL_CONTEXT)
|
|
|
|
async def test_preserves_calling_context(self) -> None:
|
|
ctx = LoggingContext(name="caller", server_name="test.server")
|
|
_native_set_current_context(ctx)
|
|
|
|
results: list[object] = []
|
|
|
|
async def bg_work() -> str:
|
|
results.append(_native_current_context())
|
|
return "done"
|
|
|
|
task = run_coroutine_in_background_native(bg_work())
|
|
|
|
# Calling context should be preserved
|
|
self.assertIs(_native_current_context(), ctx)
|
|
|
|
result = await task
|
|
self.assertEqual(result, "done")
|
|
|
|
_native_set_current_context(SENTINEL_CONTEXT)
|
|
|
|
async def test_resets_to_sentinel_on_completion(self) -> None:
|
|
async def bg_work() -> str:
|
|
return "done"
|
|
|
|
task = run_coroutine_in_background_native(bg_work())
|
|
result = await task
|
|
|
|
# Verify the task completed successfully
|
|
self.assertEqual(result, "done")
|
|
|
|
|
|
class RunInBackgroundNativeTest(unittest.IsolatedAsyncioTestCase):
|
|
def setUp(self) -> None:
|
|
_current_context_var.set(SENTINEL_CONTEXT)
|
|
|
|
async def test_with_coroutine_function(self) -> None:
|
|
async def my_func(x: int) -> int:
|
|
return x * 2
|
|
|
|
task = run_in_background_native(my_func, 21)
|
|
result = await task
|
|
self.assertEqual(result, 42)
|
|
|
|
async def test_with_sync_function(self) -> None:
|
|
def my_func(x: int) -> int:
|
|
return x * 2
|
|
|
|
task = run_in_background_native(my_func, 21)
|
|
result = await task
|
|
self.assertEqual(result, 42)
|
|
|
|
async def test_with_exception(self) -> None:
|
|
def my_func() -> None:
|
|
raise ValueError("sync error")
|
|
|
|
task = run_in_background_native(my_func)
|
|
with self.assertRaises(ValueError):
|
|
await task
|
|
|
|
|
|
class ObservableFutureTest(unittest.IsolatedAsyncioTestCase):
|
|
async def test_succeed(self) -> None:
|
|
loop = asyncio.get_running_loop()
|
|
origin: asyncio.Future[int] = loop.create_future()
|
|
observable = ObservableFuture(origin)
|
|
|
|
obs1 = observable.observe()
|
|
obs2 = observable.observe()
|
|
|
|
self.assertFalse(observable.has_called())
|
|
self.assertTrue(observable.has_observers())
|
|
|
|
origin.set_result(42)
|
|
# Give the event loop a chance to process callbacks
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertTrue(observable.has_called())
|
|
self.assertTrue(observable.has_succeeded())
|
|
self.assertEqual(await obs1, 42)
|
|
self.assertEqual(await obs2, 42)
|
|
|
|
async def test_fail(self) -> None:
|
|
loop = asyncio.get_running_loop()
|
|
origin: asyncio.Future[int] = loop.create_future()
|
|
observable = ObservableFuture(origin)
|
|
|
|
obs1 = observable.observe()
|
|
obs2 = observable.observe()
|
|
|
|
origin.set_exception(ValueError("boom"))
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertTrue(observable.has_called())
|
|
self.assertFalse(observable.has_succeeded())
|
|
|
|
with self.assertRaises(ValueError):
|
|
await obs1
|
|
with self.assertRaises(ValueError):
|
|
await obs2
|
|
|
|
async def test_observe_after_resolution(self) -> None:
|
|
loop = asyncio.get_running_loop()
|
|
origin: asyncio.Future[str] = loop.create_future()
|
|
observable = ObservableFuture(origin)
|
|
|
|
origin.set_result("hello")
|
|
await asyncio.sleep(0)
|
|
|
|
obs = observable.observe()
|
|
self.assertEqual(await obs, "hello")
|
|
|
|
async def test_no_observers(self) -> None:
|
|
loop = asyncio.get_running_loop()
|
|
origin: asyncio.Future[int] = loop.create_future()
|
|
observable = ObservableFuture(origin)
|
|
|
|
self.assertFalse(observable.has_observers())
|
|
origin.set_result(1)
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertFalse(observable.has_observers())
|
|
|
|
async def test_get_result(self) -> None:
|
|
loop = asyncio.get_running_loop()
|
|
origin: asyncio.Future[int] = loop.create_future()
|
|
observable = ObservableFuture(origin)
|
|
|
|
with self.assertRaises(ValueError):
|
|
observable.get_result()
|
|
|
|
origin.set_result(99)
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertEqual(observable.get_result(), 99)
|
|
|
|
|
|
class NativeLinearizerTest(unittest.IsolatedAsyncioTestCase):
|
|
async def test_uncontended(self) -> None:
|
|
linearizer = NativeLinearizer("test")
|
|
async with linearizer.queue("key"):
|
|
pass
|
|
|
|
async def test_serializes_access(self) -> None:
|
|
linearizer = NativeLinearizer("test")
|
|
order: list[int] = []
|
|
|
|
async def worker(n: int) -> None:
|
|
async with linearizer.queue("key"):
|
|
order.append(n)
|
|
await asyncio.sleep(0.01)
|
|
|
|
tasks = [asyncio.create_task(worker(i)) for i in range(3)]
|
|
await asyncio.gather(*tasks)
|
|
|
|
# All three should have run in order
|
|
self.assertEqual(order, [0, 1, 2])
|
|
|
|
async def test_different_keys_concurrent(self) -> None:
|
|
linearizer = NativeLinearizer("test")
|
|
running: list[str] = []
|
|
|
|
async def worker(key: str) -> None:
|
|
async with linearizer.queue(key):
|
|
running.append(key)
|
|
await asyncio.sleep(0.01)
|
|
|
|
tasks = [
|
|
asyncio.create_task(worker("a")),
|
|
asyncio.create_task(worker("b")),
|
|
]
|
|
await asyncio.gather(*tasks)
|
|
|
|
# Both should have started (different keys are independent)
|
|
self.assertEqual(set(running), {"a", "b"})
|
|
|
|
async def test_max_count(self) -> None:
|
|
linearizer = NativeLinearizer("test", max_count=2)
|
|
concurrent = 0
|
|
max_concurrent = 0
|
|
|
|
async def worker() -> None:
|
|
nonlocal concurrent, max_concurrent
|
|
async with linearizer.queue("key"):
|
|
concurrent += 1
|
|
max_concurrent = max(max_concurrent, concurrent)
|
|
await asyncio.sleep(0.01)
|
|
concurrent -= 1
|
|
|
|
tasks = [asyncio.create_task(worker()) for _ in range(5)]
|
|
await asyncio.gather(*tasks)
|
|
|
|
self.assertLessEqual(max_concurrent, 2)
|
|
|
|
async def test_is_queued(self) -> None:
|
|
linearizer = NativeLinearizer("test")
|
|
|
|
self.assertFalse(linearizer.is_queued("key"))
|
|
|
|
acquired = asyncio.Event()
|
|
release = asyncio.Event()
|
|
|
|
async def holder() -> None:
|
|
async with linearizer.queue("key"):
|
|
acquired.set()
|
|
await release.wait()
|
|
|
|
task1 = asyncio.create_task(holder())
|
|
await acquired.wait()
|
|
|
|
# Start a second worker that will be queued
|
|
async def waiter() -> None:
|
|
async with linearizer.queue("key"):
|
|
pass
|
|
|
|
task2 = asyncio.create_task(waiter())
|
|
await asyncio.sleep(0) # Let task2 start and get queued
|
|
|
|
self.assertTrue(linearizer.is_queued("key"))
|
|
|
|
release.set()
|
|
await asyncio.gather(task1, task2)
|
|
|
|
async def test_cancellation(self) -> None:
|
|
linearizer = NativeLinearizer("test")
|
|
|
|
acquired = asyncio.Event()
|
|
release = asyncio.Event()
|
|
|
|
async def holder() -> None:
|
|
async with linearizer.queue("key"):
|
|
acquired.set()
|
|
await release.wait()
|
|
|
|
task1 = asyncio.create_task(holder())
|
|
await acquired.wait()
|
|
|
|
async def waiter() -> None:
|
|
async with linearizer.queue("key"):
|
|
pass
|
|
|
|
task2 = asyncio.create_task(waiter())
|
|
await asyncio.sleep(0)
|
|
|
|
task2.cancel()
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await task2
|
|
|
|
release.set()
|
|
await task1
|
|
|
|
|
|
class NativeReadWriteLockTest(unittest.IsolatedAsyncioTestCase):
|
|
async def test_readers_concurrent(self) -> None:
|
|
lock = NativeReadWriteLock()
|
|
concurrent = 0
|
|
max_concurrent = 0
|
|
|
|
async def reader() -> None:
|
|
nonlocal concurrent, max_concurrent
|
|
async with lock.read("key"):
|
|
concurrent += 1
|
|
max_concurrent = max(max_concurrent, concurrent)
|
|
await asyncio.sleep(0.01)
|
|
concurrent -= 1
|
|
|
|
tasks = [asyncio.create_task(reader()) for _ in range(3)]
|
|
await asyncio.gather(*tasks)
|
|
|
|
# All readers should run concurrently
|
|
self.assertEqual(max_concurrent, 3)
|
|
|
|
async def test_writer_exclusive(self) -> None:
|
|
lock = NativeReadWriteLock()
|
|
order: list[str] = []
|
|
|
|
async def writer(name: str) -> None:
|
|
async with lock.write("key"):
|
|
order.append(f"{name}_start")
|
|
await asyncio.sleep(0.01)
|
|
order.append(f"{name}_end")
|
|
|
|
tasks = [
|
|
asyncio.create_task(writer("w1")),
|
|
asyncio.create_task(writer("w2")),
|
|
]
|
|
await asyncio.gather(*tasks)
|
|
|
|
# Writers should be serialized: w1 should finish before w2 starts
|
|
self.assertEqual(
|
|
order, ["w1_start", "w1_end", "w2_start", "w2_end"]
|
|
)
|
|
|
|
async def test_writer_blocks_reader(self) -> None:
|
|
lock = NativeReadWriteLock()
|
|
order: list[str] = []
|
|
|
|
writer_acquired = asyncio.Event()
|
|
writer_release = asyncio.Event()
|
|
|
|
async def writer() -> None:
|
|
async with lock.write("key"):
|
|
order.append("writer_start")
|
|
writer_acquired.set()
|
|
await writer_release.wait()
|
|
order.append("writer_end")
|
|
|
|
async def reader() -> None:
|
|
await writer_acquired.wait()
|
|
async with lock.read("key"):
|
|
order.append("reader")
|
|
|
|
w_task = asyncio.create_task(writer())
|
|
r_task = asyncio.create_task(reader())
|
|
|
|
await asyncio.sleep(0.01) # Let writer acquire and reader queue
|
|
writer_release.set()
|
|
|
|
await asyncio.gather(w_task, r_task)
|
|
|
|
self.assertEqual(order, ["writer_start", "writer_end", "reader"])
|
|
|
|
async def test_reader_blocks_writer(self) -> None:
|
|
lock = NativeReadWriteLock()
|
|
order: list[str] = []
|
|
|
|
reader_acquired = asyncio.Event()
|
|
reader_release = asyncio.Event()
|
|
|
|
async def reader() -> None:
|
|
async with lock.read("key"):
|
|
order.append("reader_start")
|
|
reader_acquired.set()
|
|
await reader_release.wait()
|
|
order.append("reader_end")
|
|
|
|
async def writer() -> None:
|
|
await reader_acquired.wait()
|
|
async with lock.write("key"):
|
|
order.append("writer")
|
|
|
|
r_task = asyncio.create_task(reader())
|
|
w_task = asyncio.create_task(writer())
|
|
|
|
await asyncio.sleep(0.01)
|
|
reader_release.set()
|
|
|
|
await asyncio.gather(r_task, w_task)
|
|
|
|
self.assertEqual(order, ["reader_start", "reader_end", "writer"])
|
|
|
|
|
|
class NativeClockTest(unittest.IsolatedAsyncioTestCase):
|
|
"""Tests for the asyncio-native NativeClock."""
|
|
|
|
async def test_time(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
t = clock.time()
|
|
self.assertIsInstance(t, float)
|
|
self.assertGreater(t, 0)
|
|
|
|
async def test_time_msec(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
t = clock.time_msec()
|
|
self.assertIsInstance(t, int)
|
|
self.assertGreater(t, 0)
|
|
|
|
async def test_sleep(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
before = clock.time()
|
|
await clock.sleep(Duration(milliseconds=50))
|
|
after = clock.time()
|
|
self.assertGreaterEqual(after - before, 0.04)
|
|
|
|
async def test_call_later(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
called = asyncio.Event()
|
|
|
|
def callback() -> None:
|
|
called.set()
|
|
|
|
wrapper = clock.call_later(Duration(milliseconds=20), callback)
|
|
self.assertTrue(wrapper.active())
|
|
|
|
await asyncio.wait_for(called.wait(), timeout=1.0)
|
|
self.assertTrue(called.is_set())
|
|
|
|
async def test_call_later_cancel(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
called = False
|
|
|
|
def callback() -> None:
|
|
nonlocal called
|
|
called = True
|
|
|
|
wrapper = clock.call_later(Duration(milliseconds=50), callback)
|
|
self.assertTrue(wrapper.active())
|
|
|
|
wrapper.cancel()
|
|
self.assertFalse(wrapper.active())
|
|
|
|
await asyncio.sleep(0.1)
|
|
self.assertFalse(called)
|
|
|
|
async def test_call_later_getTime(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
wrapper = clock.call_later(Duration(seconds=10), lambda: None)
|
|
# getTime should return a time in the future
|
|
self.assertGreater(wrapper.getTime(), 0)
|
|
wrapper.cancel()
|
|
|
|
async def test_looping_call(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
call_count = 0
|
|
|
|
def callback() -> None:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
|
|
call = clock.looping_call(callback, Duration(milliseconds=30))
|
|
|
|
# looping_call waits `duration` before first call
|
|
await asyncio.sleep(0.01)
|
|
self.assertEqual(call_count, 0)
|
|
|
|
await asyncio.sleep(0.05)
|
|
self.assertGreaterEqual(call_count, 1)
|
|
|
|
call.stop()
|
|
|
|
old_count = call_count
|
|
await asyncio.sleep(0.05)
|
|
self.assertEqual(call_count, old_count)
|
|
|
|
async def test_looping_call_now(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
call_count = 0
|
|
|
|
def callback() -> None:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
|
|
call = clock.looping_call_now(callback, Duration(milliseconds=30))
|
|
|
|
# looping_call_now should call immediately
|
|
await asyncio.sleep(0.01)
|
|
self.assertGreaterEqual(call_count, 1)
|
|
|
|
call.stop()
|
|
|
|
async def test_looping_call_waits_for_completion(self) -> None:
|
|
"""Test that the next iteration waits for the previous to complete."""
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
concurrent = 0
|
|
max_concurrent = 0
|
|
|
|
async def slow_callback() -> None:
|
|
nonlocal concurrent, max_concurrent
|
|
concurrent += 1
|
|
max_concurrent = max(max_concurrent, concurrent)
|
|
await asyncio.sleep(0.04)
|
|
concurrent -= 1
|
|
|
|
call = clock.looping_call_now(slow_callback, Duration(milliseconds=10))
|
|
|
|
await asyncio.sleep(0.15)
|
|
call.stop()
|
|
|
|
# Should never have more than 1 concurrent execution
|
|
self.assertEqual(max_concurrent, 1)
|
|
|
|
async def test_looping_call_survives_error(self) -> None:
|
|
"""Test that an error in the callback doesn't stop the loop."""
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
call_count = 0
|
|
|
|
def flaky_callback() -> None:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
if call_count == 1:
|
|
raise ValueError("first call fails")
|
|
|
|
call = clock.looping_call_now(flaky_callback, Duration(milliseconds=20))
|
|
|
|
await asyncio.sleep(0.08)
|
|
call.stop()
|
|
|
|
# Should have been called more than once despite the error
|
|
self.assertGreaterEqual(call_count, 2)
|
|
|
|
async def test_shutdown(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
looping_called = False
|
|
delayed_called = False
|
|
|
|
def looping_cb() -> None:
|
|
nonlocal looping_called
|
|
looping_called = True
|
|
|
|
def delayed_cb() -> None:
|
|
nonlocal delayed_called
|
|
delayed_called = True
|
|
|
|
clock.looping_call(looping_cb, Duration(seconds=1))
|
|
clock.call_later(Duration(seconds=1), delayed_cb)
|
|
|
|
# Shutdown immediately before any callbacks can fire
|
|
clock.shutdown()
|
|
|
|
await asyncio.sleep(0.05)
|
|
self.assertFalse(looping_called)
|
|
self.assertFalse(delayed_called)
|
|
|
|
async def test_cancel_all_delayed_calls(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
called = False
|
|
|
|
def callback() -> None:
|
|
nonlocal called
|
|
called = True
|
|
|
|
clock.call_later(Duration(milliseconds=50), callback)
|
|
clock.call_later(Duration(milliseconds=50), callback)
|
|
|
|
clock.cancel_all_delayed_calls()
|
|
|
|
await asyncio.sleep(0.1)
|
|
self.assertFalse(called)
|
|
|
|
async def test_call_when_running(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
called = asyncio.Event()
|
|
|
|
def callback() -> None:
|
|
called.set()
|
|
|
|
clock.call_when_running(callback)
|
|
|
|
await asyncio.wait_for(called.wait(), timeout=1.0)
|
|
self.assertTrue(called.is_set())
|
|
|
|
async def test_add_system_event_trigger(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
|
|
trigger_id = clock.add_system_event_trigger(
|
|
"before", "shutdown", lambda: None
|
|
)
|
|
self.assertIsInstance(trigger_id, int)
|
|
self.assertEqual(len(clock._shutdown_callbacks), 1)
|
|
|
|
async def test_shutdown_prevents_new_calls(self) -> None:
|
|
from synapse.util.clock import NativeClock
|
|
from synapse.util.duration import Duration
|
|
|
|
clock = NativeClock(server_name="test.server")
|
|
clock.shutdown()
|
|
|
|
with self.assertRaises(Exception):
|
|
clock.looping_call(lambda: None, Duration(seconds=1))
|
|
|
|
with self.assertRaises(Exception):
|
|
clock.call_later(Duration(seconds=1), lambda: None)
|
|
|
|
|
|
class NativeConnectionPoolTest(unittest.IsolatedAsyncioTestCase):
|
|
"""Tests for the asyncio-native NativeConnectionPool using SQLite."""
|
|
|
|
async def asyncSetUp(self) -> None:
|
|
from synapse.config.database import DatabaseConnectionConfig
|
|
from synapse.storage.engines.sqlite import Sqlite3Engine
|
|
from synapse.storage.native_database import NativeConnectionPool
|
|
|
|
db_conf = {"name": "sqlite3", "args": {"database": ":memory:"}}
|
|
self.engine = Sqlite3Engine(db_conf)
|
|
self.db_config = DatabaseConnectionConfig("test_db", db_conf)
|
|
|
|
self.pool = NativeConnectionPool(
|
|
db_config=self.db_config,
|
|
engine=self.engine,
|
|
server_name="test.server",
|
|
max_workers=2,
|
|
)
|
|
|
|
async def asyncTearDown(self) -> None:
|
|
self.pool.close()
|
|
|
|
async def test_run_with_connection(self) -> None:
|
|
def create_and_query(conn: object) -> list:
|
|
assert hasattr(conn, "execute")
|
|
conn.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY, val TEXT)") # type: ignore[union-attr]
|
|
conn.execute("INSERT INTO test VALUES (1, 'hello')") # type: ignore[union-attr]
|
|
conn.commit() # type: ignore[union-attr]
|
|
cursor = conn.execute("SELECT val FROM test WHERE id = 1") # type: ignore[union-attr]
|
|
return cursor.fetchall()
|
|
|
|
result = await self.pool.runWithConnection(create_and_query)
|
|
self.assertEqual(result, [("hello",)])
|
|
|
|
async def test_run_interaction_commits(self) -> None:
|
|
# First create the table
|
|
def create_table(conn: object) -> None:
|
|
conn.execute("CREATE TABLE IF NOT EXISTS test2 (id INTEGER PRIMARY KEY, val TEXT)") # type: ignore[union-attr]
|
|
|
|
await self.pool.runInteraction(create_table)
|
|
|
|
# Then insert in a transaction
|
|
def insert(conn: object) -> None:
|
|
conn.execute("INSERT INTO test2 VALUES (1, 'world')") # type: ignore[union-attr]
|
|
|
|
await self.pool.runInteraction(insert)
|
|
|
|
# Verify it was committed
|
|
def query(conn: object) -> list:
|
|
cursor = conn.execute("SELECT val FROM test2 WHERE id = 1") # type: ignore[union-attr]
|
|
return cursor.fetchall()
|
|
|
|
result = await self.pool.runWithConnection(query)
|
|
self.assertEqual(result, [("world",)])
|
|
|
|
async def test_run_interaction_rolls_back_on_error(self) -> None:
|
|
# Create table first
|
|
def create_table(conn: object) -> None:
|
|
conn.execute("CREATE TABLE IF NOT EXISTS test3 (id INTEGER PRIMARY KEY, val TEXT)") # type: ignore[union-attr]
|
|
|
|
await self.pool.runInteraction(create_table)
|
|
|
|
# Insert that should be rolled back
|
|
def failing_insert(conn: object) -> None:
|
|
conn.execute("INSERT INTO test3 VALUES (1, 'should_rollback')") # type: ignore[union-attr]
|
|
raise ValueError("deliberate error")
|
|
|
|
with self.assertRaises(ValueError):
|
|
await self.pool.runInteraction(failing_insert)
|
|
|
|
# Verify nothing was inserted
|
|
def query(conn: object) -> list:
|
|
cursor = conn.execute("SELECT COUNT(*) FROM test3") # type: ignore[union-attr]
|
|
return cursor.fetchall()
|
|
|
|
result = await self.pool.runWithConnection(query)
|
|
self.assertEqual(result, [(0,)])
|
|
|
|
async def test_connection_reuse(self) -> None:
|
|
"""Verify the same thread reuses its connection."""
|
|
connection_ids: list[int] = []
|
|
|
|
def get_conn_id(conn: object) -> int:
|
|
conn_id = id(conn)
|
|
connection_ids.append(conn_id)
|
|
return conn_id
|
|
|
|
id1 = await self.pool.runWithConnection(get_conn_id)
|
|
id2 = await self.pool.runWithConnection(get_conn_id)
|
|
|
|
# With a single-threaded pool, the same connection should be reused
|
|
# With multi-threaded, it depends on which thread runs.
|
|
# At minimum, we should get valid connection IDs.
|
|
self.assertIsInstance(id1, int)
|
|
self.assertIsInstance(id2, int)
|
|
|
|
async def test_closed_pool_raises(self) -> None:
|
|
self.pool.close()
|
|
with self.assertRaises(Exception):
|
|
await self.pool.runWithConnection(lambda conn: None)
|
|
|
|
async def test_concurrent_operations(self) -> None:
|
|
"""Test that multiple concurrent operations work correctly."""
|
|
import asyncio
|
|
|
|
results: list[int] = []
|
|
|
|
def work(conn: object, n: int) -> int:
|
|
return n * 2
|
|
|
|
tasks = [self.pool.runWithConnection(work, i) for i in range(10)]
|
|
results = await asyncio.gather(*tasks)
|
|
|
|
self.assertEqual(sorted(results), [0, 2, 4, 6, 8, 10, 12, 14, 16, 18])
|
|
|
|
|
|
class NativeSimpleHttpClientTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the asyncio-native NativeSimpleHttpClient.

    asyncSetUp boots a real aiohttp server on an ephemeral loopback port and
    the tests drive the client's JSON / raw / file / form helpers against it.
    """

    async def asyncSetUp(self) -> None:
        from aiohttp import web

        # Create a simple test HTTP server
        self.app = web.Application()
        self.app.router.add_get("/json", self._handle_json)
        self.app.router.add_post("/json", self._handle_json_post)
        self.app.router.add_put("/json", self._handle_json_put)
        self.app.router.add_get("/raw", self._handle_raw)
        self.app.router.add_get("/file", self._handle_file)
        self.app.router.add_get("/error", self._handle_error)
        self.app.router.add_post("/form", self._handle_form)

        self.runner = web.AppRunner(self.app)
        await self.runner.setup()
        # Port 0 asks the OS for any free port; we read the real one back below.
        self.site = web.TCPSite(self.runner, "127.0.0.1", 0)
        await self.site.start()

        # Get the actual bound port
        # (reaches into aiohttp's private _server attr; hence the type: ignore)
        sock = self.site._server.sockets[0]  # type: ignore[union-attr]
        self.port = sock.getsockname()[1]
        self.base_url = f"http://127.0.0.1:{self.port}"

        from synapse.http.native_client import NativeSimpleHttpClient

        self.client = NativeSimpleHttpClient(
            user_agent="test-agent/1.0",
        )

    async def asyncTearDown(self) -> None:
        # Close the client before tearing down the server it talks to.
        await self.client.close()
        await self.runner.cleanup()

    @staticmethod
    async def _handle_json(request: "aiohttp.web.Request") -> "aiohttp.web.Response":
        # Fixed JSON payload for GET /json.
        from aiohttp import web

        return web.json_response({"hello": "world"})

    @staticmethod
    async def _handle_json_post(
        request: "aiohttp.web.Request",
    ) -> "aiohttp.web.Response":
        # Echoes the POSTed JSON body back under "received".
        from aiohttp import web

        body = await request.json()
        return web.json_response({"received": body})

    @staticmethod
    async def _handle_json_put(
        request: "aiohttp.web.Request",
    ) -> "aiohttp.web.Response":
        # Echoes the PUT JSON body back under "updated".
        from aiohttp import web

        body = await request.json()
        return web.json_response({"updated": body})

    @staticmethod
    async def _handle_raw(request: "aiohttp.web.Request") -> "aiohttp.web.Response":
        # Non-JSON body, for get_raw().
        from aiohttp import web

        return web.Response(body=b"raw bytes here")

    @staticmethod
    async def _handle_file(request: "aiohttp.web.Request") -> "aiohttp.web.Response":
        # 1300-byte octet-stream body, for the get_file() download tests.
        from aiohttp import web

        return web.Response(
            body=b"file content " * 100,
            content_type="application/octet-stream",
        )

    @staticmethod
    async def _handle_error(request: "aiohttp.web.Request") -> "aiohttp.web.Response":
        # Deliberate 500, for the error-propagation test.
        from aiohttp import web

        return web.Response(status=500, body=b"Internal Server Error")

    @staticmethod
    async def _handle_form(request: "aiohttp.web.Request") -> "aiohttp.web.Response":
        # Reflects urlencoded form fields back as JSON.
        from aiohttp import web

        data = await request.post()
        return web.json_response({"form_data": dict(data)})

    async def test_get_json(self) -> None:
        """GET /json decodes the response body into a dict."""
        result = await self.client.get_json(f"{self.base_url}/json")
        self.assertEqual(result, {"hello": "world"})

    async def test_get_json_with_args(self) -> None:
        """Query args are accepted; the handler ignores them, response unchanged."""
        result = await self.client.get_json(
            f"{self.base_url}/json", args={"foo": "bar"}
        )
        self.assertEqual(result, {"hello": "world"})

    async def test_post_json_get_json(self) -> None:
        """POSTed dict is JSON-encoded and round-trips through the echo handler."""
        result = await self.client.post_json_get_json(
            f"{self.base_url}/json", {"key": "value"}
        )
        self.assertEqual(result, {"received": {"key": "value"}})

    async def test_put_json(self) -> None:
        """PUT body round-trips through the echo handler."""
        result = await self.client.put_json(
            f"{self.base_url}/json", {"key": "updated_value"}
        )
        self.assertEqual(result, {"updated": {"key": "updated_value"}})

    async def test_get_raw(self) -> None:
        """get_raw returns the body as bytes, no JSON decoding."""
        result = await self.client.get_raw(f"{self.base_url}/raw")
        self.assertEqual(result, b"raw bytes here")

    async def test_get_file(self) -> None:
        """get_file streams the body into the output and reports its length."""
        from io import BytesIO

        output = BytesIO()
        length, headers, url, code = await self.client.get_file(
            f"{self.base_url}/file", output
        )
        self.assertEqual(code, 200)
        self.assertGreater(length, 0)
        # Reported length must match what was actually written.
        self.assertEqual(len(output.getvalue()), length)

    async def test_get_file_max_size(self) -> None:
        """Downloads beyond max_size abort with a SynapseError."""
        from io import BytesIO

        from synapse.api.errors import SynapseError

        output = BytesIO()
        # Body is 1300 bytes, cap at 10 to force the size check to trip.
        with self.assertRaises(SynapseError) as ctx:
            await self.client.get_file(
                f"{self.base_url}/file", output, max_size=10
            )
        # NOTE(review): assumes the client's error message contains "too large"
        # — confirm against NativeSimpleHttpClient.get_file.
        self.assertIn("too large", str(ctx.exception))

    async def test_error_response(self) -> None:
        """Non-2xx responses surface as HttpResponseException with the code."""
        from synapse.api.errors import HttpResponseException

        with self.assertRaises(HttpResponseException) as ctx:
            await self.client.get_json(f"{self.base_url}/error")
        self.assertEqual(ctx.exception.code, 500)

    async def test_post_urlencoded_get_json(self) -> None:
        """Form-encoded POST fields reach the server intact."""
        result = await self.client.post_urlencoded_get_json(
            f"{self.base_url}/form", args={"username": "test"}
        )
        self.assertEqual(result["form_data"]["username"], "test")

    async def test_request_method(self) -> None:
        """Low-level request() exposes status and a readable body."""
        response = await self.client.request("GET", f"{self.base_url}/raw")
        self.assertEqual(response.status, 200)
        body = await response.read()
        self.assertEqual(body, b"raw bytes here")
|
|
|
|
|
|
class NativeJsonResourceTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the asyncio-native NativeJsonResource and NativeSynapseRequest.

    Registers regex-routed handlers (mirroring RestServlet.register()) on a
    NativeJsonResource, serves the built app with aiohttp's TestServer, and
    drives it with a plain aiohttp ClientSession.
    """

    async def asyncSetUp(self) -> None:
        import re

        from aiohttp import web
        from aiohttp.test_utils import TestServer

        from synapse.http.native_server import NativeJsonResource

        self.resource = NativeJsonResource(server_name="test.server")

        # Register handlers using the same pattern as RestServlet.register()
        self.resource.register_paths(
            "GET",
            [re.compile("^/_test/hello$")],
            self._handle_hello,
            "TestHelloServlet",
        )
        self.resource.register_paths(
            "GET",
            # Named group exercises path-parameter extraction.
            [re.compile("^/_test/user/(?P<user_id>[^/]+)$")],
            self._handle_user,
            "TestUserServlet",
        )
        self.resource.register_paths(
            "POST",
            [re.compile("^/_test/echo$")],
            self._handle_echo,
            "TestEchoServlet",
        )
        self.resource.register_paths(
            "GET",
            [re.compile("^/_test/error$")],
            self._handle_error,
            "TestErrorServlet",
        )
        self.resource.register_paths(
            "GET",
            [re.compile("^/_test/sync$")],
            self._handle_sync,
            "TestSyncServlet",
        )

        app = self.resource.build_app()
        self.server = TestServer(app)
        await self.server.start_server()
        self.base_url = f"http://{self.server.host}:{self.server.port}"

        import aiohttp as aio

        self.session = aio.ClientSession()

    async def asyncTearDown(self) -> None:
        # Close the client session before the server it points at.
        await self.session.close()
        await self.server.close()

    # --- Test handlers (mimic servlet pattern) ---

    @staticmethod
    async def _handle_hello(request: Any) -> tuple[int, dict]:
        # Simple fixed (status, json) servlet-style return.
        return 200, {"message": "hello world"}

    @staticmethod
    async def _handle_user(request: Any, user_id: str) -> tuple[int, dict]:
        # user_id is supplied from the regex named group.
        return 200, {"user_id": user_id}

    @staticmethod
    async def _handle_echo(request: Any) -> tuple[int, dict]:
        # Uses the real servlet helper so request-body parsing is exercised
        # through the NativeSynapseRequest shim.
        from synapse.http.servlet import parse_json_object_from_request

        body = parse_json_object_from_request(request)
        return 200, {"echo": body}

    @staticmethod
    async def _handle_error(request: Any) -> tuple[int, dict]:
        # Raises so the resource's SynapseError → JSON error mapping is tested.
        from synapse.api.errors import Codes, SynapseError

        raise SynapseError(403, "Forbidden", Codes.FORBIDDEN)

    @staticmethod
    def _handle_sync(request: Any) -> tuple[int, dict]:
        """Synchronous handler (not async)."""
        return 200, {"sync": True}

    # --- Tests ---

    async def test_get_json(self) -> None:
        """A registered GET path returns the handler's JSON payload."""
        async with self.session.get(f"{self.base_url}/_test/hello") as resp:
            self.assertEqual(resp.status, 200)
            data = await resp.json()
            self.assertEqual(data["message"], "hello world")

    async def test_path_parameters(self) -> None:
        """Regex named groups are passed through to the handler."""
        async with self.session.get(
            f"{self.base_url}/_test/user/@alice:example.com"
        ) as resp:
            self.assertEqual(resp.status, 200)
            data = await resp.json()
            self.assertEqual(data["user_id"], "@alice:example.com")

    async def test_url_encoded_path_params(self) -> None:
        """Percent-encoded path segments are decoded before matching."""
        async with self.session.get(
            f"{self.base_url}/_test/user/%40bob%3Aexample.com"
        ) as resp:
            self.assertEqual(resp.status, 200)
            data = await resp.json()
            self.assertEqual(data["user_id"], "@bob:example.com")

    async def test_post_json(self) -> None:
        """POST bodies are parseable via parse_json_object_from_request."""
        async with self.session.post(
            f"{self.base_url}/_test/echo",
            json={"key": "value"},
        ) as resp:
            self.assertEqual(resp.status, 200)
            data = await resp.json()
            self.assertEqual(data["echo"]["key"], "value")

    async def test_synapse_error(self) -> None:
        """A raised SynapseError becomes a matching HTTP status + errcode."""
        async with self.session.get(f"{self.base_url}/_test/error") as resp:
            self.assertEqual(resp.status, 403)
            data = await resp.json()
            self.assertEqual(data["errcode"], "M_FORBIDDEN")

    async def test_404_not_found(self) -> None:
        """Unknown paths return Synapse's M_UNRECOGNIZED JSON 404."""
        async with self.session.get(
            f"{self.base_url}/_test/nonexistent"
        ) as resp:
            self.assertEqual(resp.status, 404)
            data = await resp.json()
            self.assertEqual(data["errcode"], "M_UNRECOGNIZED")

    async def test_405_method_not_allowed(self) -> None:
        """Known path, unregistered method → 405 with M_UNRECOGNIZED."""
        async with self.session.delete(f"{self.base_url}/_test/hello") as resp:
            self.assertEqual(resp.status, 405)
            data = await resp.json()
            self.assertEqual(data["errcode"], "M_UNRECOGNIZED")

    async def test_options_cors(self) -> None:
        """OPTIONS preflight gets 204 with a wildcard CORS origin."""
        async with self.session.options(f"{self.base_url}/_test/hello") as resp:
            self.assertEqual(resp.status, 204)
            self.assertIn("Access-Control-Allow-Origin", resp.headers)
            self.assertEqual(resp.headers["Access-Control-Allow-Origin"], "*")

    async def test_sync_handler(self) -> None:
        """Plain (non-async) handlers are supported too."""
        async with self.session.get(f"{self.base_url}/_test/sync") as resp:
            self.assertEqual(resp.status, 200)
            data = await resp.json()
            self.assertTrue(data["sync"])

    async def test_cors_on_success(self) -> None:
        """Ordinary successful responses also carry CORS headers."""
        async with self.session.get(f"{self.base_url}/_test/hello") as resp:
            self.assertEqual(resp.status, 200)
            self.assertIn("Access-Control-Allow-Origin", resp.headers)

    async def test_json_content_type(self) -> None:
        """Responses are served with an application/json content type."""
        async with self.session.get(f"{self.base_url}/_test/hello") as resp:
            self.assertEqual(resp.status, 200)
            self.assertIn("application/json", resp.headers["Content-Type"])
|
|
|
|
|
|
class NativeSynapseRequestTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the NativeSynapseRequest compatibility shim."""

    async def test_args_parsing(self) -> None:
        """Query args are exposed Twisted-style: bytes keys mapping to value lists."""
        from aiohttp.test_utils import make_mocked_request

        from synapse.http.native_server import NativeSynapseRequest

        mocked = make_mocked_request("GET", "/_test?foo=bar&foo=baz&num=42")
        shim = NativeSynapseRequest(mocked, b"")

        self.assertIn(b"foo", shim.args)
        # Repeated parameters accumulate in order; single ones are 1-lists.
        self.assertEqual(shim.args[b"foo"], [b"bar", b"baz"])
        self.assertEqual(shim.args[b"num"], [b"42"])

    async def test_method_and_path(self) -> None:
        """Method and path come back as bytes, matching Twisted's Request."""
        from aiohttp.test_utils import make_mocked_request

        from synapse.http.native_server import NativeSynapseRequest

        mocked = make_mocked_request("POST", "/_test/path")
        shim = NativeSynapseRequest(mocked, b'{"key":"val"}')

        self.assertEqual(shim.method, b"POST")
        self.assertEqual(shim.path, b"/_test/path")

    async def test_content_body(self) -> None:
        """The pre-read body is re-exposed as a readable .content stream."""
        from aiohttp.test_utils import make_mocked_request

        from synapse.http.native_server import NativeSynapseRequest

        payload = b'{"hello": "world"}'
        mocked = make_mocked_request("POST", "/_test")
        shim = NativeSynapseRequest(mocked, payload)

        self.assertEqual(shim.content.read(), payload)

    async def test_request_headers(self) -> None:
        """Headers are reachable via the Twisted-style requestHeaders API."""
        from aiohttp.test_utils import make_mocked_request

        from synapse.http.native_server import NativeSynapseRequest

        mocked = make_mocked_request(
            "GET", "/_test", headers={"Authorization": "Bearer token123"}
        )
        shim = NativeSynapseRequest(mocked, b"")

        bearer = shim.requestHeaders.getRawHeaders("Authorization")
        self.assertIsNotNone(bearer)
        self.assertEqual(bearer, [b"Bearer token123"])

    async def test_response_building(self) -> None:
        """setResponseCode/setHeader/write/finish accumulate into build_response."""
        from aiohttp.test_utils import make_mocked_request

        from synapse.http.native_server import NativeSynapseRequest

        mocked = make_mocked_request("GET", "/_test")
        shim = NativeSynapseRequest(mocked, b"")

        shim.setResponseCode(201)
        shim.setHeader(b"X-Custom", b"value")
        # Successive writes must be concatenated in order.
        shim.write(b"hello ")
        shim.write(b"world")
        shim.finish()

        built = shim.build_response()
        self.assertEqual(built.status, 201)
        self.assertEqual(built.headers["X-Custom"], "value")
        self.assertEqual(built.body, b"hello world")
|
|
|
|
|
|
class NativeReplicationProtocolTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the asyncio-native replication protocol.

    Each test builds a real TCP loopback connection (client + server
    reader/writer pairs) via _make_pipe and layers the protocol on top.
    """

    async def _make_pipe(
        self,
    ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter, asyncio.StreamReader, asyncio.StreamWriter]:
        """Create a connected pair of (reader, writer) using a TCP loopback server."""
        connections: list[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = []
        ready = asyncio.Event()

        async def on_connect(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
            # Capture the server-side endpoints and signal the waiter.
            connections.append((r, w))
            ready.set()

        # Port 0 = ephemeral; the real address is read back from the socket.
        server = await asyncio.start_server(on_connect, "127.0.0.1", 0)
        addr = server.sockets[0].getsockname()
        client_r, client_w = await asyncio.open_connection(addr[0], addr[1])
        # Wait until the server side has actually accepted the connection.
        await ready.wait()
        server_r, server_w = connections[0]
        # Stash the listener so asyncTearDown can close it.
        self._server_to_close = server
        return client_r, client_w, server_r, server_w

    async def asyncTearDown(self) -> None:
        # Only set if a test actually called _make_pipe.
        if hasattr(self, "_server_to_close"):
            self._server_to_close.close()
            await self._server_to_close.wait_closed()

    async def test_send_and_receive_command(self) -> None:
        """A line written by the client is dispatched to on_<COMMAND>."""
        from synapse.replication.tcp.commands import PingCommand
        from synapse.replication.tcp.native_protocol import NativeReplicationProtocol
        from synapse.replication.tcp.protocol import (
            VALID_CLIENT_COMMANDS,
            VALID_SERVER_COMMANDS,
        )

        client_r, client_w, server_r, server_w = await self._make_pipe()

        received_commands: list[str] = []

        class TestProtocol(NativeReplicationProtocol):
            async def on_PING(self, cmd: object) -> None:
                received_commands.append("PING")

        # Server protocol receives from client
        server_proto = TestProtocol(
            server_name="test.server",
            valid_inbound_commands=VALID_CLIENT_COMMANDS,
            valid_outbound_commands=VALID_SERVER_COMMANDS,
        )
        await server_proto.start(server_r, server_w)

        # Client sends a PING directly via the writer
        client_w.write(b"PING 12345\n")
        await client_w.drain()

        # Give time for the read loop to process
        await asyncio.sleep(0.05)

        # Server should have received the PING (from start's initial ping + our manual one)
        self.assertIn("PING", received_commands)

        await server_proto.close()
        client_w.close()

    async def test_protocol_sends_initial_ping(self) -> None:
        """start() writes a PING line on the wire straight away."""
        from synapse.replication.tcp.native_protocol import NativeReplicationProtocol
        from synapse.replication.tcp.protocol import (
            VALID_CLIENT_COMMANDS,
            VALID_SERVER_COMMANDS,
        )

        client_r, client_w, server_r, server_w = await self._make_pipe()

        proto = NativeReplicationProtocol(
            server_name="test.server",
            valid_inbound_commands=VALID_CLIENT_COMMANDS,
            valid_outbound_commands=VALID_SERVER_COMMANDS,
        )
        await proto.start(server_r, server_w)

        # Read the initial ping sent by the protocol
        line = await asyncio.wait_for(client_r.readline(), timeout=2.0)
        self.assertTrue(line.startswith(b"PING "))

        await proto.close()
        client_w.close()

    async def test_close_connection(self) -> None:
        """close() fires on_connection_lost and lands in CLOSED state."""
        from synapse.replication.tcp.native_protocol import (
            ConnectionState,
            NativeReplicationProtocol,
        )
        from synapse.replication.tcp.protocol import (
            VALID_CLIENT_COMMANDS,
            VALID_SERVER_COMMANDS,
        )

        client_r, client_w, server_r, server_w = await self._make_pipe()

        closed = asyncio.Event()

        class TestProtocol(NativeReplicationProtocol):
            async def on_connection_lost(self) -> None:
                closed.set()

        proto = TestProtocol(
            server_name="test.server",
            valid_inbound_commands=VALID_CLIENT_COMMANDS,
            valid_outbound_commands=VALID_SERVER_COMMANDS,
        )
        await proto.start(server_r, server_w)

        await proto.close()

        await asyncio.wait_for(closed.wait(), timeout=2.0)
        # Inspects protocol-private state; acceptable in this unit test.
        self.assertEqual(proto._state, ConnectionState.CLOSED)
        client_w.close()

    async def test_eof_triggers_close(self) -> None:
        """Remote hangup (EOF) also drives the protocol to CLOSED."""
        from synapse.replication.tcp.native_protocol import (
            ConnectionState,
            NativeReplicationProtocol,
        )
        from synapse.replication.tcp.protocol import (
            VALID_CLIENT_COMMANDS,
            VALID_SERVER_COMMANDS,
        )

        client_r, client_w, server_r, server_w = await self._make_pipe()

        closed = asyncio.Event()

        class TestProtocol(NativeReplicationProtocol):
            async def on_connection_lost(self) -> None:
                closed.set()

        proto = TestProtocol(
            server_name="test.server",
            valid_inbound_commands=VALID_CLIENT_COMMANDS,
            valid_outbound_commands=VALID_SERVER_COMMANDS,
        )
        await proto.start(server_r, server_w)

        # Close the client side — server should detect EOF
        client_w.close()

        await asyncio.wait_for(closed.wait(), timeout=2.0)
        self.assertEqual(proto._state, ConnectionState.CLOSED)

    async def test_server_and_client_helpers(self) -> None:
        """start_native_replication_server accepts a client and pings it."""
        from synapse.replication.tcp.native_protocol import (
            NativeReplicationProtocol,
            start_native_replication_server,
        )
        from synapse.replication.tcp.protocol import (
            VALID_CLIENT_COMMANDS,
            VALID_SERVER_COMMANDS,
        )

        server_connected = asyncio.Event()

        def server_factory() -> NativeReplicationProtocol:
            # Called by the server once per inbound connection.
            proto = NativeReplicationProtocol(
                server_name="test.server",
                valid_inbound_commands=VALID_CLIENT_COMMANDS,
                valid_outbound_commands=VALID_SERVER_COMMANDS,
            )
            server_connected.set()
            return proto

        server = await start_native_replication_server(
            "127.0.0.1", 0, server_factory
        )
        addr = server.sockets[0].getsockname()

        # Connect a client
        reader, writer = await asyncio.open_connection(addr[0], addr[1])

        await asyncio.wait_for(server_connected.wait(), timeout=2.0)

        # Read the server's initial PING
        line = await asyncio.wait_for(reader.readline(), timeout=2.0)
        self.assertTrue(line.startswith(b"PING "))

        writer.close()
        server.close()
        await server.wait_closed()
|
|
|
|
|
|
class NativeAsyncUtilitiesTest(unittest.IsolatedAsyncioTestCase):
    """Tests for Phase 7 native async utility functions."""

    async def test_native_gather_results(self) -> None:
        """native_gather_results maps the coroutine over inputs, keeping order."""
        from synapse.util.async_helpers import native_gather_results

        async def double(x: int) -> int:
            return x * 2

        results = await native_gather_results(double, [1, 2, 3])
        self.assertEqual(results, [2, 4, 6])

    async def test_native_concurrently_execute(self) -> None:
        """All items are processed while respecting the concurrency limit."""
        from synapse.util.async_helpers import native_concurrently_execute

        results: list[int] = []
        # Track in-flight count to verify the limit is honoured.
        concurrent = 0
        max_concurrent = 0

        async def work(x: int) -> None:
            nonlocal concurrent, max_concurrent
            concurrent += 1
            max_concurrent = max(max_concurrent, concurrent)
            results.append(x)
            # Sleep keeps tasks overlapping so the limit can actually be hit.
            await asyncio.sleep(0.01)
            concurrent -= 1

        await native_concurrently_execute(work, range(10), limit=3)

        self.assertEqual(sorted(results), list(range(10)))
        self.assertLessEqual(max_concurrent, 3)

    async def test_native_stop_cancellation(self) -> None:
        """Cancelling the wrapper must not propagate to the inner future."""
        from synapse.util.async_helpers import native_stop_cancellation

        loop = asyncio.get_running_loop()
        inner: asyncio.Future[str] = loop.create_future()
        shielded = native_stop_cancellation(inner)

        # Cancel the shielded future
        shielded.cancel()

        # Inner should NOT be cancelled
        self.assertFalse(inner.cancelled())

        # Resolve inner
        inner.set_result("ok")
        self.assertEqual(inner.result(), "ok")

    async def test_native_awakeable_sleeper(self) -> None:
        """wake() releases a sleeper well before its nominal delay."""
        from synapse.util.async_helpers import NativeAwakenableSleeper

        sleeper = NativeAwakenableSleeper()

        woke_early = False

        async def sleeping() -> None:
            nonlocal woke_early
            # 5s delay — only reachable early via wake().
            await sleeper.sleep("test", delay_ms=5000)
            woke_early = True

        task = asyncio.create_task(sleeping())
        await asyncio.sleep(0.01)

        sleeper.wake("test")
        await asyncio.sleep(0.01)

        self.assertTrue(woke_early)
        # Cleanup: cancel is a no-op if the task already finished normally.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def test_native_awakeable_sleeper_timeout(self) -> None:
        """Without wake(), sleep() returns on its own after the delay."""
        from synapse.util.async_helpers import NativeAwakenableSleeper

        sleeper = NativeAwakenableSleeper()

        # Should return after timeout without wake
        await sleeper.sleep("test", delay_ms=20)

    async def test_native_event(self) -> None:
        """NativeEvent supports set/is_set/wait/clear round trip."""
        from synapse.util.async_helpers import NativeEvent

        event = NativeEvent()
        self.assertFalse(event.is_set())

        event.set()
        self.assertTrue(event.is_set())

        # Already-set event: wait() returns True immediately.
        result = await event.wait(timeout_seconds=1.0)
        self.assertTrue(result)

        event.clear()
        self.assertFalse(event.is_set())

    async def test_native_event_timeout(self) -> None:
        """wait() on an unset event returns False after the timeout."""
        from synapse.util.async_helpers import NativeEvent

        event = NativeEvent()
        result = await event.wait(timeout_seconds=0.02)
        self.assertFalse(result)
|
|
|
|
|
|
class FutureCacheTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the asyncio-native FutureCache.

    Tests use `await asyncio.sleep(0)` after resolving a future to give the
    cache's done-callbacks one event-loop tick to run.
    """

    async def test_set_and_get(self) -> None:
        """A resolved future's value is retrievable via get()."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test", max_entries=100)

        loop = asyncio.get_running_loop()
        f: asyncio.Future[str] = loop.create_future()
        cache.set("key1", f)

        # Resolve the future
        f.set_result("value1")
        await asyncio.sleep(0)  # Let callbacks run

        # Now get should return the cached value
        result = await cache.get("key1")
        self.assertEqual(result, "value1")

    async def test_get_pending(self) -> None:
        """get() on a pending entry returns an observer that resolves later."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test")

        loop = asyncio.get_running_loop()
        f: asyncio.Future[str] = loop.create_future()
        cache.set("key1", f)

        # Get while still pending — returns observer
        observer = cache.get("key1")
        self.assertFalse(observer.done())

        # Resolve
        f.set_result("hello")
        await asyncio.sleep(0)

        result = await observer
        self.assertEqual(result, "hello")

    async def test_get_missing_raises_keyerror(self) -> None:
        """get() of an unknown key raises KeyError (not None)."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test")

        with self.assertRaises(KeyError):
            cache.get("nonexistent")

    async def test_failed_future_not_cached(self) -> None:
        """A future that fails is dropped, not cached as an error."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test")

        loop = asyncio.get_running_loop()
        f: asyncio.Future[str] = loop.create_future()
        cache.set("key1", f)

        f.set_exception(ValueError("boom"))
        await asyncio.sleep(0)

        # Should not be cached
        with self.assertRaises(KeyError):
            cache.get("key1")

    async def test_invalidate(self) -> None:
        """invalidate() removes a completed entry (checked via __contains__)."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test")

        loop = asyncio.get_running_loop()
        f: asyncio.Future[str] = loop.create_future()
        cache.set("key1", f)
        f.set_result("value1")
        await asyncio.sleep(0)

        self.assertIn("key1", cache)

        cache.invalidate("key1")
        self.assertNotIn("key1", cache)

    async def test_invalidation_callback(self) -> None:
        """The per-entry callback passed to set() fires on invalidation."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test")
        callback_called = False

        def on_invalidate() -> None:
            nonlocal callback_called
            callback_called = True

        loop = asyncio.get_running_loop()
        f: asyncio.Future[str] = loop.create_future()
        cache.set("key1", f, callback=on_invalidate)
        f.set_result("value1")
        await asyncio.sleep(0)

        cache.invalidate("key1")
        self.assertTrue(callback_called)

    async def test_invalidate_all(self) -> None:
        """invalidate_all() empties the cache (len() goes to zero)."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test")

        loop = asyncio.get_running_loop()
        for i in range(5):
            f: asyncio.Future[str] = loop.create_future()
            cache.set(f"key{i}", f)
            f.set_result(f"val{i}")
        await asyncio.sleep(0)

        self.assertEqual(len(cache), 5)
        cache.invalidate_all()
        self.assertEqual(len(cache), 0)

    async def test_max_entries_eviction(self) -> None:
        """Completed entries beyond max_entries are evicted."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[int] = FutureCache("test", max_entries=3)

        loop = asyncio.get_running_loop()
        for i in range(5):
            f: asyncio.Future[int] = loop.create_future()
            cache.set(f"key{i}", f)
            f.set_result(i)
        await asyncio.sleep(0)

        # Should have evicted oldest entries
        # (peeks at the private _completed store to count real entries)
        self.assertLessEqual(len(cache._completed), 3)

    async def test_multiple_observers(self) -> None:
        """Two observers of one pending entry both see the same result."""
        from synapse.util.caches.future_cache import FutureCache

        cache: FutureCache[str] = FutureCache("test")

        loop = asyncio.get_running_loop()
        f: asyncio.Future[str] = loop.create_future()
        cache.set("key1", f)

        obs1 = cache.get("key1")
        obs2 = cache.get("key1")

        f.set_result("shared_value")
        await asyncio.sleep(0)

        r1 = await obs1
        r2 = await obs2
        self.assertEqual(r1, "shared_value")
        self.assertEqual(r2, "shared_value")
|
|
|
|
|
|
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|