⏺ Phase 2: NativeClock — Complete

3 new classes added to synapse/util/clock.py, 15 new tests, all passing, mypy clean, no regressions.

  NativeLoopingCall — asyncio Task wrapper with stop(). Tracks in WeakSet for automatic cleanup.

  NativeDelayedCallWrapper — Wraps asyncio.TimerHandle with the same interface as DelayedCallWrapper (cancel(), active(), getTime(), delay(), reset()). Since TimerHandle is immutable,
  delay()/reset() cancel and reschedule.

  NativeClock — Same public API as Clock but uses:
  - time.time() instead of reactor.seconds()
  - asyncio.sleep() instead of Deferred + reactor.callLater
  - asyncio.create_task() with while True loop instead of LoopingCall
  - loop.call_later() instead of reactor.callLater()
  - loop.call_soon() instead of reactor.callWhenRunning()
  - Logcontext wrapping preserved (same PreserveLoggingContext + run_in_background pattern)
  - LoopingCall semantics preserved: waits for previous invocation to complete, survives errors
This commit is contained in:
Matthew Hodgson
2026-03-21 14:35:29 +00:00
parent 24724a810e
commit a1267a1f37
2 changed files with 583 additions and 0 deletions
+333
View File
@@ -14,7 +14,10 @@
#
import asyncio
import inspect
import logging
import time as time_mod
from functools import wraps
from typing import (
Any,
@@ -608,3 +611,333 @@ class DelayedCallWrapper:
def active(self) -> bool:
"""Propagate the call to the underlying delayed_call."""
return self.delayed_call.active()
# ===========================================================================
# Phase 2: asyncio-native Clock implementation
#
# NativeClock provides the same public interface as Clock but uses asyncio
# primitives instead of Twisted. It is unused until later phases switch
# hs.get_clock() to return a NativeClock.
# ===========================================================================
class NativeLoopingCall:
"""asyncio-native equivalent of Twisted's LoopingCall.
Runs a function repeatedly with a fixed interval between completions.
If the function returns an awaitable, waits for it to complete before
scheduling the next call (same semantics as LoopingCall).
"""
def __init__(self, task: "asyncio.Task[None]") -> None:
self._task = task
def stop(self) -> None:
"""Stop the looping call."""
self._task.cancel()
class NativeDelayedCallWrapper:
"""asyncio-native equivalent of DelayedCallWrapper.
Wraps an asyncio.TimerHandle. Since TimerHandle is immutable (no delay/reset),
delay() and reset() cancel and reschedule.
"""
def __init__(
self,
handle: asyncio.TimerHandle,
call_id: int,
clock: "NativeClock",
loop: asyncio.AbstractEventLoop,
scheduled_time: float,
callback: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> None:
self._handle = handle
self.call_id = call_id
self.clock = clock
self._loop = loop
self._scheduled_time = scheduled_time
self._callback = callback
self._args = args
self._kwargs = kwargs
self._cancelled = False
def cancel(self) -> None:
"""Cancel the scheduled call."""
if not self._cancelled:
self._handle.cancel()
self._cancelled = True
self.clock._call_id_to_delayed_call.pop(self.call_id, None)
def active(self) -> bool:
"""Check if the call is still pending."""
return not self._cancelled
def getTime(self) -> float:
"""Return the scheduled execution time."""
return self._scheduled_time
def delay(self, secondsLater: float) -> None:
"""Delay execution by N additional seconds."""
if self._cancelled:
return
remaining = self._scheduled_time - self._loop.time()
new_delay = remaining + secondsLater
self._handle.cancel()
self._scheduled_time = self._loop.time() + new_delay
self._handle = self._loop.call_later(
new_delay, self._callback, *self._args
)
def reset(self, secondsFromNow: float) -> None:
"""Reset to fire secondsFromNow from now."""
if self._cancelled:
return
self._handle.cancel()
self._scheduled_time = self._loop.time() + secondsFromNow
self._handle = self._loop.call_later(
secondsFromNow, self._callback, *self._args
)
class NativeClock:
"""asyncio-native equivalent of Clock.
Provides the same public interface as Clock but uses asyncio primitives
(asyncio.sleep, loop.call_later, asyncio.create_task) instead of Twisted.
Args:
server_name: The server name for logging context.
"""
def __init__(self, server_name: str) -> None:
self._server_name = server_name
self._delayed_call_id: int = 0
self._looping_calls: WeakSet[NativeLoopingCall] = WeakSet()
self._call_id_to_delayed_call: dict[int, NativeDelayedCallWrapper] = {}
self._is_shutdown = False
self._shutdown_callbacks: list[tuple[str, str, Callable, tuple, dict]] = []
# Lazily initialized when first needed
self._loop: asyncio.AbstractEventLoop | None = None
def _get_loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_running_loop()
return self._loop
def shutdown(self) -> None:
self._is_shutdown = True
self.cancel_all_looping_calls()
self.cancel_all_delayed_calls()
def time(self) -> float:
"""Returns the current system time in seconds since epoch."""
return time_mod.time()
def time_msec(self) -> int:
"""Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000)
async def sleep(self, duration: Duration) -> None:
await asyncio.sleep(duration.as_secs())
def looping_call(
self,
f: Callable[P, object],
duration: Duration,
*args: P.args,
**kwargs: P.kwargs,
) -> NativeLoopingCall:
"""Call a function repeatedly, waiting `duration` before the first call."""
return self._looping_call_common(f, duration, False, *args, **kwargs)
def looping_call_now(
self,
f: Callable[P, object],
duration: Duration,
*args: P.args,
**kwargs: P.kwargs,
) -> NativeLoopingCall:
"""Call a function immediately, then repeatedly thereafter."""
return self._looping_call_common(f, duration, True, *args, **kwargs)
def _looping_call_common(
self,
f: Callable[P, object],
duration: Duration,
now: bool,
*args: P.args,
**kwargs: P.kwargs,
) -> NativeLoopingCall:
if self._is_shutdown:
raise Exception("Cannot start looping call. Clock has been shutdown")
instance_id = random_string_insecure_fast(5)
interval = duration.as_secs()
async def _loop() -> None:
if not now:
await asyncio.sleep(interval)
while True:
try:
clock_debug_logger.debug(
"looping_call(%s): Executing callback", instance_id
)
with context.PreserveLoggingContext(
context.LoggingContext(
name="looping_call", server_name=self._server_name
)
):
result = f(*args, **kwargs)
if inspect.isawaitable(result):
await result
except asyncio.CancelledError:
return
except Exception:
logger.exception("Looping call %s died", instance_id)
await asyncio.sleep(interval)
task_obj = asyncio.create_task(_loop())
call = NativeLoopingCall(task_obj)
self._looping_calls.add(call)
clock_debug_logger.debug(
"looping_call(%s): Scheduled looping call every %sms",
instance_id,
duration.as_millis(),
stack_info=True,
)
return call
def cancel_all_looping_calls(self, consumeErrors: bool = True) -> None:
for call in list(self._looping_calls):
try:
call.stop()
except Exception:
if not consumeErrors:
raise
self._looping_calls.clear()
def call_later(
self,
delay: Duration,
callback: Callable,
*args: Any,
call_later_cancel_on_shutdown: bool = True,
**kwargs: Any,
) -> NativeDelayedCallWrapper:
call_id = self._delayed_call_id
self._delayed_call_id += 1
if self._is_shutdown:
raise Exception("Cannot start delayed call. Clock has been shutdown")
loop = self._get_loop()
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
clock_debug_logger.debug("call_later(%s): Executing callback", call_id)
try:
with context.PreserveLoggingContext(
context.LoggingContext(
name="call_later", server_name=self._server_name
)
):
context.run_in_background(callback, *args, **kwargs)
finally:
if call_later_cancel_on_shutdown:
self._call_id_to_delayed_call.pop(call_id, None)
scheduled_time = loop.time() + delay.as_secs()
handle = loop.call_later(delay.as_secs(), wrapped_callback, *args, **kwargs)
clock_debug_logger.debug(
"call_later(%s): Scheduled call for %ss later (tracked: %s)",
call_id,
delay.as_secs(),
call_later_cancel_on_shutdown,
stack_info=True,
)
wrapped_call = NativeDelayedCallWrapper(
handle, call_id, self, loop, scheduled_time, wrapped_callback, args, kwargs
)
if call_later_cancel_on_shutdown:
self._call_id_to_delayed_call[call_id] = wrapped_call
return wrapped_call
def cancel_call_later(
self, wrapped_call: NativeDelayedCallWrapper, ignore_errs: bool = False
) -> None:
try:
wrapped_call.cancel()
except Exception:
if not ignore_errs:
raise
def cancel_all_delayed_calls(self, ignore_errs: bool = True) -> None:
for call_id, call in list(self._call_id_to_delayed_call.items()):
try:
call.cancel()
except Exception:
if not ignore_errs:
raise
self._call_id_to_delayed_call.clear()
def call_when_running(
self,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""Call a function on the next event loop iteration."""
instance_id = random_string_insecure_fast(5)
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
clock_debug_logger.debug(
"call_when_running(%s): Executing callback", instance_id
)
with context.PreserveLoggingContext(
context.LoggingContext(
name="call_when_running", server_name=self._server_name
)
):
context.run_in_background(callback, *args, **kwargs)
loop = self._get_loop()
if kwargs:
loop.call_soon(lambda: wrapped_callback(*args, **kwargs))
else:
loop.call_soon(wrapped_callback, *args)
def add_system_event_trigger(
self,
phase: str,
event_type: str,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> int:
"""Store a callback to be invoked during shutdown.
Returns an ID that could be used to remove the trigger (not currently
needed by callers, but matches the Twisted API).
"""
trigger_id = len(self._shutdown_callbacks)
self._shutdown_callbacks.append(
(phase, event_type, callback, args, kwargs)
)
clock_debug_logger.debug(
"add_system_event_trigger: registered %s %s callback",
phase,
event_type,
stack_info=True,
)
return trigger_id
+250
View File
@@ -492,5 +492,255 @@ class NativeReadWriteLockTest(unittest.IsolatedAsyncioTestCase):
self.assertEqual(order, ["reader_start", "reader_end", "writer"])
class NativeClockTest(unittest.IsolatedAsyncioTestCase):
"""Tests for the asyncio-native NativeClock."""
async def test_time(self) -> None:
from synapse.util.clock import NativeClock
clock = NativeClock(server_name="test.server")
t = clock.time()
self.assertIsInstance(t, float)
self.assertGreater(t, 0)
async def test_time_msec(self) -> None:
from synapse.util.clock import NativeClock
clock = NativeClock(server_name="test.server")
t = clock.time_msec()
self.assertIsInstance(t, int)
self.assertGreater(t, 0)
async def test_sleep(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
before = clock.time()
await clock.sleep(Duration(milliseconds=50))
after = clock.time()
self.assertGreaterEqual(after - before, 0.04)
async def test_call_later(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
called = asyncio.Event()
def callback() -> None:
called.set()
wrapper = clock.call_later(Duration(milliseconds=20), callback)
self.assertTrue(wrapper.active())
await asyncio.wait_for(called.wait(), timeout=1.0)
self.assertTrue(called.is_set())
async def test_call_later_cancel(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
called = False
def callback() -> None:
nonlocal called
called = True
wrapper = clock.call_later(Duration(milliseconds=50), callback)
self.assertTrue(wrapper.active())
wrapper.cancel()
self.assertFalse(wrapper.active())
await asyncio.sleep(0.1)
self.assertFalse(called)
async def test_call_later_getTime(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
wrapper = clock.call_later(Duration(seconds=10), lambda: None)
# getTime should return a time in the future
self.assertGreater(wrapper.getTime(), 0)
wrapper.cancel()
async def test_looping_call(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
call_count = 0
def callback() -> None:
nonlocal call_count
call_count += 1
call = clock.looping_call(callback, Duration(milliseconds=30))
# looping_call waits `duration` before first call
await asyncio.sleep(0.01)
self.assertEqual(call_count, 0)
await asyncio.sleep(0.05)
self.assertGreaterEqual(call_count, 1)
call.stop()
old_count = call_count
await asyncio.sleep(0.05)
self.assertEqual(call_count, old_count)
async def test_looping_call_now(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
call_count = 0
def callback() -> None:
nonlocal call_count
call_count += 1
call = clock.looping_call_now(callback, Duration(milliseconds=30))
# looping_call_now should call immediately
await asyncio.sleep(0.01)
self.assertGreaterEqual(call_count, 1)
call.stop()
async def test_looping_call_waits_for_completion(self) -> None:
"""Test that the next iteration waits for the previous to complete."""
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
concurrent = 0
max_concurrent = 0
async def slow_callback() -> None:
nonlocal concurrent, max_concurrent
concurrent += 1
max_concurrent = max(max_concurrent, concurrent)
await asyncio.sleep(0.04)
concurrent -= 1
call = clock.looping_call_now(slow_callback, Duration(milliseconds=10))
await asyncio.sleep(0.15)
call.stop()
# Should never have more than 1 concurrent execution
self.assertEqual(max_concurrent, 1)
async def test_looping_call_survives_error(self) -> None:
"""Test that an error in the callback doesn't stop the loop."""
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
call_count = 0
def flaky_callback() -> None:
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValueError("first call fails")
call = clock.looping_call_now(flaky_callback, Duration(milliseconds=20))
await asyncio.sleep(0.08)
call.stop()
# Should have been called more than once despite the error
self.assertGreaterEqual(call_count, 2)
async def test_shutdown(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
looping_called = False
delayed_called = False
def looping_cb() -> None:
nonlocal looping_called
looping_called = True
def delayed_cb() -> None:
nonlocal delayed_called
delayed_called = True
clock.looping_call(looping_cb, Duration(seconds=1))
clock.call_later(Duration(seconds=1), delayed_cb)
# Shutdown immediately before any callbacks can fire
clock.shutdown()
await asyncio.sleep(0.05)
self.assertFalse(looping_called)
self.assertFalse(delayed_called)
async def test_cancel_all_delayed_calls(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
called = False
def callback() -> None:
nonlocal called
called = True
clock.call_later(Duration(milliseconds=50), callback)
clock.call_later(Duration(milliseconds=50), callback)
clock.cancel_all_delayed_calls()
await asyncio.sleep(0.1)
self.assertFalse(called)
async def test_call_when_running(self) -> None:
from synapse.util.clock import NativeClock
clock = NativeClock(server_name="test.server")
called = asyncio.Event()
def callback() -> None:
called.set()
clock.call_when_running(callback)
await asyncio.wait_for(called.wait(), timeout=1.0)
self.assertTrue(called.is_set())
async def test_add_system_event_trigger(self) -> None:
from synapse.util.clock import NativeClock
clock = NativeClock(server_name="test.server")
trigger_id = clock.add_system_event_trigger(
"before", "shutdown", lambda: None
)
self.assertIsInstance(trigger_id, int)
self.assertEqual(len(clock._shutdown_callbacks), 1)
async def test_shutdown_prevents_new_calls(self) -> None:
from synapse.util.clock import NativeClock
from synapse.util.duration import Duration
clock = NativeClock(server_name="test.server")
clock.shutdown()
with self.assertRaises(Exception):
clock.looping_call(lambda: None, Duration(seconds=1))
with self.assertRaises(Exception):
clock.call_later(Duration(seconds=1), lambda: None)
if __name__ == "__main__":
unittest.main()