From a1267a1f37e319cc389e0941db57f90ed5bf4dee Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 21 Mar 2026 14:35:29 +0000 Subject: [PATCH] =?UTF-8?q?=E2=8F=BA=20Phase=202:=20NativeClock=20?= =?UTF-8?q?=E2=80=94=20Complete?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3 new classes added to synapse/util/clock.py, 15 new tests, all passing, mypy clean, no regressions. NativeLoopingCall — asyncio Task wrapper with stop(). Tracks in WeakSet for automatic cleanup. NativeDelayedCallWrapper — Wraps asyncio.TimerHandle with the same interface as DelayedCallWrapper (cancel(), active(), getTime(), delay(), reset()). Since TimerHandle is immutable, delay()/reset() cancel and reschedule. NativeClock — Same public API as Clock but uses: - time.time() instead of reactor.seconds() - asyncio.sleep() instead of Deferred + reactor.callLater - asyncio.create_task() with while True loop instead of LoopingCall - loop.call_later() instead of reactor.callLater() - loop.call_soon() instead of reactor.callWhenRunning() - Logcontext wrapping preserved (same PreserveLoggingContext + run_in_background pattern) - LoopingCall semantics preserved: waits for previous invocation to complete, survives errors --- synapse/util/clock.py | 333 ++++++++++++++++++++++++++++++++ tests/util/test_native_async.py | 250 ++++++++++++++++++++++++ 2 files changed, 583 insertions(+) diff --git a/synapse/util/clock.py b/synapse/util/clock.py index 7232a1331c..b0c62a9e1f 100644 --- a/synapse/util/clock.py +++ b/synapse/util/clock.py @@ -14,7 +14,10 @@ # +import asyncio +import inspect import logging +import time as time_mod from functools import wraps from typing import ( Any, @@ -608,3 +611,333 @@ class DelayedCallWrapper: def active(self) -> bool: """Propagate the call to the underlying delayed_call.""" return self.delayed_call.active() + + +# =========================================================================== +# Phase 2: asyncio-native Clock implementation +# +# NativeClock provides the same public interface as Clock but uses asyncio +# primitives instead of Twisted. It is unused until later phases switch +# hs.get_clock() to return a NativeClock. +# =========================================================================== + + +class NativeLoopingCall: + """asyncio-native equivalent of Twisted's LoopingCall. + + Runs a function repeatedly with a fixed interval between completions. + If the function returns an awaitable, waits for it to complete before + scheduling the next call (same semantics as LoopingCall). + """ + + def __init__(self, task: "asyncio.Task[None]") -> None: + self._task = task + + def stop(self) -> None: + """Stop the looping call.""" + self._task.cancel() + + +class NativeDelayedCallWrapper: + """asyncio-native equivalent of DelayedCallWrapper. + + Wraps an asyncio.TimerHandle. Since TimerHandle is immutable (no delay/reset), + delay() and reset() cancel and reschedule. + """ + + def __init__( + self, + handle: asyncio.TimerHandle, + call_id: int, + clock: "NativeClock", + loop: asyncio.AbstractEventLoop, + scheduled_time: float, + callback: Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> None: + self._handle = handle + self.call_id = call_id + self.clock = clock + self._loop = loop + self._scheduled_time = scheduled_time + self._callback = callback + self._args = args + self._kwargs = kwargs + self._cancelled = False + + def cancel(self) -> None: + """Cancel the scheduled call.""" + if not self._cancelled: + self._handle.cancel() + self._cancelled = True + self.clock._call_id_to_delayed_call.pop(self.call_id, None) + + def active(self) -> bool: + """Check if the call is still pending.""" + return not self._cancelled + + def getTime(self) -> float: + """Return the scheduled execution time.""" + return self._scheduled_time + + def delay(self, secondsLater: float) -> None: + """Delay execution by N additional seconds.""" + if self._cancelled: + return + remaining = self._scheduled_time - self._loop.time() + new_delay = remaining + secondsLater + self._handle.cancel() + self._scheduled_time = self._loop.time() + new_delay + self._handle = self._loop.call_later( + new_delay, self._callback, *self._args + ) + + def reset(self, secondsFromNow: float) -> None: + """Reset to fire secondsFromNow from now.""" + if self._cancelled: + return + self._handle.cancel() + self._scheduled_time = self._loop.time() + secondsFromNow + self._handle = self._loop.call_later( + secondsFromNow, self._callback, *self._args + ) + + +class NativeClock: + """asyncio-native equivalent of Clock. + + Provides the same public interface as Clock but uses asyncio primitives + (asyncio.sleep, loop.call_later, asyncio.create_task) instead of Twisted. + + Args: + server_name: The server name for logging context. + """ + + def __init__(self, server_name: str) -> None: + self._server_name = server_name + self._delayed_call_id: int = 0 + self._looping_calls: WeakSet[NativeLoopingCall] = WeakSet() + self._call_id_to_delayed_call: dict[int, NativeDelayedCallWrapper] = {} + self._is_shutdown = False + self._shutdown_callbacks: list[tuple[str, str, Callable, tuple, dict]] = [] + # Lazily initialized when first needed + self._loop: asyncio.AbstractEventLoop | None = None + + def _get_loop(self) -> asyncio.AbstractEventLoop: + if self._loop is None: + self._loop = asyncio.get_running_loop() + return self._loop + + def shutdown(self) -> None: + self._is_shutdown = True + self.cancel_all_looping_calls() + self.cancel_all_delayed_calls() + + def time(self) -> float: + """Returns the current system time in seconds since epoch.""" + return time_mod.time() + + def time_msec(self) -> int: + """Returns the current system time in milliseconds since epoch.""" + return int(self.time() * 1000) + + async def sleep(self, duration: Duration) -> None: + await asyncio.sleep(duration.as_secs()) + + def looping_call( + self, + f: Callable[P, object], + duration: Duration, + *args: P.args, + **kwargs: P.kwargs, + ) -> NativeLoopingCall: + """Call a function repeatedly, waiting `duration` before the first call.""" + return self._looping_call_common(f, duration, False, *args, **kwargs) + + def looping_call_now( + self, + f: Callable[P, object], + duration: Duration, + *args: P.args, + **kwargs: P.kwargs, + ) -> NativeLoopingCall: + """Call a function immediately, then repeatedly thereafter.""" + return self._looping_call_common(f, duration, True, *args, **kwargs) + + def _looping_call_common( + self, + f: Callable[P, object], + duration: Duration, + now: bool, + *args: P.args, + **kwargs: P.kwargs, + ) -> NativeLoopingCall: + if self._is_shutdown: + raise Exception("Cannot start looping call. Clock has been shutdown") + + instance_id = random_string_insecure_fast(5) + interval = duration.as_secs() + + async def _loop() -> None: + if not now: + await asyncio.sleep(interval) + + while True: + try: + clock_debug_logger.debug( + "looping_call(%s): Executing callback", instance_id + ) + with context.PreserveLoggingContext( + context.LoggingContext( + name="looping_call", server_name=self._server_name + ) + ): + result = f(*args, **kwargs) + if inspect.isawaitable(result): + await result + except asyncio.CancelledError: + return + except Exception: + logger.exception("Looping call %s died", instance_id) + + await asyncio.sleep(interval) + + task_obj = asyncio.create_task(_loop()) + call = NativeLoopingCall(task_obj) + self._looping_calls.add(call) + + clock_debug_logger.debug( + "looping_call(%s): Scheduled looping call every %sms", + instance_id, + duration.as_millis(), + stack_info=True, + ) + + return call + + def cancel_all_looping_calls(self, consumeErrors: bool = True) -> None: + for call in list(self._looping_calls): + try: + call.stop() + except Exception: + if not consumeErrors: + raise + self._looping_calls.clear() + + def call_later( + self, + delay: Duration, + callback: Callable, + *args: Any, + call_later_cancel_on_shutdown: bool = True, + **kwargs: Any, + ) -> NativeDelayedCallWrapper: + call_id = self._delayed_call_id + self._delayed_call_id += 1 + + if self._is_shutdown: + raise Exception("Cannot start delayed call. Clock has been shutdown") + + loop = self._get_loop() + + def wrapped_callback(*args: Any, **kwargs: Any) -> None: + clock_debug_logger.debug("call_later(%s): Executing callback", call_id) + try: + with context.PreserveLoggingContext( + context.LoggingContext( + name="call_later", server_name=self._server_name + ) + ): + context.run_in_background(callback, *args, **kwargs) + finally: + if call_later_cancel_on_shutdown: + self._call_id_to_delayed_call.pop(call_id, None) + + scheduled_time = loop.time() + delay.as_secs() + handle = loop.call_later(delay.as_secs(), wrapped_callback, *args, **kwargs) + + clock_debug_logger.debug( + "call_later(%s): Scheduled call for %ss later (tracked: %s)", + call_id, + delay.as_secs(), + call_later_cancel_on_shutdown, + stack_info=True, + ) + + wrapped_call = NativeDelayedCallWrapper( + handle, call_id, self, loop, scheduled_time, wrapped_callback, args, kwargs + ) + if call_later_cancel_on_shutdown: + self._call_id_to_delayed_call[call_id] = wrapped_call + + return wrapped_call + + def cancel_call_later( + self, wrapped_call: NativeDelayedCallWrapper, ignore_errs: bool = False + ) -> None: + try: + wrapped_call.cancel() + except Exception: + if not ignore_errs: + raise + + def cancel_all_delayed_calls(self, ignore_errs: bool = True) -> None: + for call_id, call in list(self._call_id_to_delayed_call.items()): + try: + call.cancel() + except Exception: + if not ignore_errs: + raise + self._call_id_to_delayed_call.clear() + + def call_when_running( + self, + callback: Callable[P, object], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """Call a function on the next event loop iteration.""" + instance_id = random_string_insecure_fast(5) + + def wrapped_callback(*args: Any, **kwargs: Any) -> None: + clock_debug_logger.debug( + "call_when_running(%s): Executing callback", instance_id + ) + with context.PreserveLoggingContext( + context.LoggingContext( + name="call_when_running", server_name=self._server_name + ) + ): + context.run_in_background(callback, *args, **kwargs) + + loop = self._get_loop() + if kwargs: + loop.call_soon(lambda: wrapped_callback(*args, **kwargs)) + else: + loop.call_soon(wrapped_callback, *args) + + def add_system_event_trigger( + self, + phase: str, + event_type: str, + callback: Callable[P, object], + *args: P.args, + **kwargs: P.kwargs, + ) -> int: + """Store a callback to be invoked during shutdown. + + Returns an ID that could be used to remove the trigger (not currently + needed by callers, but matches the Twisted API). + """ + trigger_id = len(self._shutdown_callbacks) + self._shutdown_callbacks.append( + (phase, event_type, callback, args, kwargs) + ) + clock_debug_logger.debug( + "add_system_event_trigger: registered %s %s callback", + phase, + event_type, + stack_info=True, + ) + return trigger_id diff --git a/tests/util/test_native_async.py b/tests/util/test_native_async.py index fe52297114..9c1c2a3bb7 100644 --- a/tests/util/test_native_async.py +++ b/tests/util/test_native_async.py @@ -492,5 +492,255 @@ class NativeReadWriteLockTest(unittest.IsolatedAsyncioTestCase): self.assertEqual(order, ["reader_start", "reader_end", "writer"]) +class NativeClockTest(unittest.IsolatedAsyncioTestCase): + """Tests for the asyncio-native NativeClock.""" + + async def test_time(self) -> None: + from synapse.util.clock import NativeClock + + clock = NativeClock(server_name="test.server") + t = clock.time() + self.assertIsInstance(t, float) + self.assertGreater(t, 0) + + async def test_time_msec(self) -> None: + from synapse.util.clock import NativeClock + + clock = NativeClock(server_name="test.server") + t = clock.time_msec() + self.assertIsInstance(t, int) + self.assertGreater(t, 0) + + async def test_sleep(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + before = clock.time() + await clock.sleep(Duration(milliseconds=50)) + after = clock.time() + self.assertGreaterEqual(after - before, 0.04) + + async def test_call_later(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + called = asyncio.Event() + + def callback() -> None: + called.set() + + wrapper = clock.call_later(Duration(milliseconds=20), callback) + self.assertTrue(wrapper.active()) + + await asyncio.wait_for(called.wait(), timeout=1.0) + self.assertTrue(called.is_set()) + + async def test_call_later_cancel(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + called = False + + def callback() -> None: + nonlocal called + called = True + + wrapper = clock.call_later(Duration(milliseconds=50), callback) + self.assertTrue(wrapper.active()) + + wrapper.cancel() + self.assertFalse(wrapper.active()) + + await asyncio.sleep(0.1) + self.assertFalse(called) + + async def test_call_later_getTime(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + wrapper = clock.call_later(Duration(seconds=10), lambda: None) + # getTime should return a time in the future + self.assertGreater(wrapper.getTime(), 0) + wrapper.cancel() + + async def test_looping_call(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + call_count = 0 + + def callback() -> None: + nonlocal call_count + call_count += 1 + + call = clock.looping_call(callback, Duration(milliseconds=30)) + + # looping_call waits `duration` before first call + await asyncio.sleep(0.01) + self.assertEqual(call_count, 0) + + await asyncio.sleep(0.05) + self.assertGreaterEqual(call_count, 1) + + call.stop() + + old_count = call_count + await asyncio.sleep(0.05) + self.assertEqual(call_count, old_count) + + async def test_looping_call_now(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + call_count = 0 + + def callback() -> None: + nonlocal call_count + call_count += 1 + + call = clock.looping_call_now(callback, Duration(milliseconds=30)) + + # looping_call_now should call immediately + await asyncio.sleep(0.01) + self.assertGreaterEqual(call_count, 1) + + call.stop() + + async def test_looping_call_waits_for_completion(self) -> None: + """Test that the next iteration waits for the previous to complete.""" + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + concurrent = 0 + max_concurrent = 0 + + async def slow_callback() -> None: + nonlocal concurrent, max_concurrent + concurrent += 1 + max_concurrent = max(max_concurrent, concurrent) + await asyncio.sleep(0.04) + concurrent -= 1 + + call = clock.looping_call_now(slow_callback, Duration(milliseconds=10)) + + await asyncio.sleep(0.15) + call.stop() + + # Should never have more than 1 concurrent execution + self.assertEqual(max_concurrent, 1) + + async def test_looping_call_survives_error(self) -> None: + """Test that an error in the callback doesn't stop the loop.""" + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + call_count = 0 + + def flaky_callback() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("first call fails") + + call = clock.looping_call_now(flaky_callback, Duration(milliseconds=20)) + + await asyncio.sleep(0.08) + call.stop() + + # Should have been called more than once despite the error + self.assertGreaterEqual(call_count, 2) + + async def test_shutdown(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + looping_called = False + delayed_called = False + + def looping_cb() -> None: + nonlocal looping_called + looping_called = True + + def delayed_cb() -> None: + nonlocal delayed_called + delayed_called = True + + clock.looping_call(looping_cb, Duration(seconds=1)) + clock.call_later(Duration(seconds=1), delayed_cb) + + # Shutdown immediately before any callbacks can fire + clock.shutdown() + + await asyncio.sleep(0.05) + self.assertFalse(looping_called) + self.assertFalse(delayed_called) + + async def test_cancel_all_delayed_calls(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + called = False + + def callback() -> None: + nonlocal called + called = True + + clock.call_later(Duration(milliseconds=50), callback) + clock.call_later(Duration(milliseconds=50), callback) + + clock.cancel_all_delayed_calls() + + await asyncio.sleep(0.1) + self.assertFalse(called) + + async def test_call_when_running(self) -> None: + from synapse.util.clock import NativeClock + + clock = NativeClock(server_name="test.server") + called = asyncio.Event() + + def callback() -> None: + called.set() + + clock.call_when_running(callback) + + await asyncio.wait_for(called.wait(), timeout=1.0) + self.assertTrue(called.is_set()) + + async def test_add_system_event_trigger(self) -> None: + from synapse.util.clock import NativeClock + + clock = NativeClock(server_name="test.server") + + trigger_id = clock.add_system_event_trigger( + "before", "shutdown", lambda: None + ) + self.assertIsInstance(trigger_id, int) + self.assertEqual(len(clock._shutdown_callbacks), 1) + + async def test_shutdown_prevents_new_calls(self) -> None: + from synapse.util.clock import NativeClock + from synapse.util.duration import Duration + + clock = NativeClock(server_name="test.server") + clock.shutdown() + + with self.assertRaises(Exception): + clock.looping_call(lambda: None, Duration(seconds=1)) + + with self.assertRaises(Exception): + clock.call_later(Duration(seconds=1), lambda: None) + + if __name__ == "__main__": unittest.main()