mirror of
https://github.com/element-hq/synapse.git
synced 2026-05-18 01:05:34 +00:00
⏺ Phase 2: NativeClock — Complete
3 new classes added to synapse/util/clock.py, 15 new tests, all passing, mypy clean, no regressions. NativeLoopingCall — asyncio Task wrapper with stop(). Tracks in WeakSet for automatic cleanup. NativeDelayedCallWrapper — Wraps asyncio.TimerHandle with the same interface as DelayedCallWrapper (cancel(), active(), getTime(), delay(), reset()). Since TimerHandle is immutable, delay()/reset() cancel and reschedule. NativeClock — Same public API as Clock but uses: - time.time() instead of reactor.seconds() - asyncio.sleep() instead of Deferred + reactor.callLater - asyncio.create_task() with while True loop instead of LoopingCall - loop.call_later() instead of reactor.callLater() - loop.call_soon() instead of reactor.callWhenRunning() - Logcontext wrapping preserved (same PreserveLoggingContext + run_in_background pattern) - LoopingCall semantics preserved: waits for previous invocation to complete, survives errors
This commit is contained in:
@@ -14,7 +14,10 @@
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import time as time_mod
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -608,3 +611,333 @@ class DelayedCallWrapper:
|
||||
def active(self) -> bool:
|
||||
"""Propagate the call to the underlying delayed_call."""
|
||||
return self.delayed_call.active()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Phase 2: asyncio-native Clock implementation
|
||||
#
|
||||
# NativeClock provides the same public interface as Clock but uses asyncio
|
||||
# primitives instead of Twisted. It is unused until later phases switch
|
||||
# hs.get_clock() to return a NativeClock.
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class NativeLoopingCall:
|
||||
"""asyncio-native equivalent of Twisted's LoopingCall.
|
||||
|
||||
Runs a function repeatedly with a fixed interval between completions.
|
||||
If the function returns an awaitable, waits for it to complete before
|
||||
scheduling the next call (same semantics as LoopingCall).
|
||||
"""
|
||||
|
||||
def __init__(self, task: "asyncio.Task[None]") -> None:
|
||||
self._task = task
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the looping call."""
|
||||
self._task.cancel()
|
||||
|
||||
|
||||
class NativeDelayedCallWrapper:
|
||||
"""asyncio-native equivalent of DelayedCallWrapper.
|
||||
|
||||
Wraps an asyncio.TimerHandle. Since TimerHandle is immutable (no delay/reset),
|
||||
delay() and reset() cancel and reschedule.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handle: asyncio.TimerHandle,
|
||||
call_id: int,
|
||||
clock: "NativeClock",
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
scheduled_time: float,
|
||||
callback: Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
self._handle = handle
|
||||
self.call_id = call_id
|
||||
self.clock = clock
|
||||
self._loop = loop
|
||||
self._scheduled_time = scheduled_time
|
||||
self._callback = callback
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
self._cancelled = False
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel the scheduled call."""
|
||||
if not self._cancelled:
|
||||
self._handle.cancel()
|
||||
self._cancelled = True
|
||||
self.clock._call_id_to_delayed_call.pop(self.call_id, None)
|
||||
|
||||
def active(self) -> bool:
|
||||
"""Check if the call is still pending."""
|
||||
return not self._cancelled
|
||||
|
||||
def getTime(self) -> float:
|
||||
"""Return the scheduled execution time."""
|
||||
return self._scheduled_time
|
||||
|
||||
def delay(self, secondsLater: float) -> None:
|
||||
"""Delay execution by N additional seconds."""
|
||||
if self._cancelled:
|
||||
return
|
||||
remaining = self._scheduled_time - self._loop.time()
|
||||
new_delay = remaining + secondsLater
|
||||
self._handle.cancel()
|
||||
self._scheduled_time = self._loop.time() + new_delay
|
||||
self._handle = self._loop.call_later(
|
||||
new_delay, self._callback, *self._args
|
||||
)
|
||||
|
||||
def reset(self, secondsFromNow: float) -> None:
|
||||
"""Reset to fire secondsFromNow from now."""
|
||||
if self._cancelled:
|
||||
return
|
||||
self._handle.cancel()
|
||||
self._scheduled_time = self._loop.time() + secondsFromNow
|
||||
self._handle = self._loop.call_later(
|
||||
secondsFromNow, self._callback, *self._args
|
||||
)
|
||||
|
||||
|
||||
class NativeClock:
|
||||
"""asyncio-native equivalent of Clock.
|
||||
|
||||
Provides the same public interface as Clock but uses asyncio primitives
|
||||
(asyncio.sleep, loop.call_later, asyncio.create_task) instead of Twisted.
|
||||
|
||||
Args:
|
||||
server_name: The server name for logging context.
|
||||
"""
|
||||
|
||||
def __init__(self, server_name: str) -> None:
|
||||
self._server_name = server_name
|
||||
self._delayed_call_id: int = 0
|
||||
self._looping_calls: WeakSet[NativeLoopingCall] = WeakSet()
|
||||
self._call_id_to_delayed_call: dict[int, NativeDelayedCallWrapper] = {}
|
||||
self._is_shutdown = False
|
||||
self._shutdown_callbacks: list[tuple[str, str, Callable, tuple, dict]] = []
|
||||
# Lazily initialized when first needed
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def _get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
return self._loop
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._is_shutdown = True
|
||||
self.cancel_all_looping_calls()
|
||||
self.cancel_all_delayed_calls()
|
||||
|
||||
def time(self) -> float:
|
||||
"""Returns the current system time in seconds since epoch."""
|
||||
return time_mod.time()
|
||||
|
||||
def time_msec(self) -> int:
|
||||
"""Returns the current system time in milliseconds since epoch."""
|
||||
return int(self.time() * 1000)
|
||||
|
||||
async def sleep(self, duration: Duration) -> None:
|
||||
await asyncio.sleep(duration.as_secs())
|
||||
|
||||
def looping_call(
|
||||
self,
|
||||
f: Callable[P, object],
|
||||
duration: Duration,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> NativeLoopingCall:
|
||||
"""Call a function repeatedly, waiting `duration` before the first call."""
|
||||
return self._looping_call_common(f, duration, False, *args, **kwargs)
|
||||
|
||||
def looping_call_now(
|
||||
self,
|
||||
f: Callable[P, object],
|
||||
duration: Duration,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> NativeLoopingCall:
|
||||
"""Call a function immediately, then repeatedly thereafter."""
|
||||
return self._looping_call_common(f, duration, True, *args, **kwargs)
|
||||
|
||||
def _looping_call_common(
|
||||
self,
|
||||
f: Callable[P, object],
|
||||
duration: Duration,
|
||||
now: bool,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> NativeLoopingCall:
|
||||
if self._is_shutdown:
|
||||
raise Exception("Cannot start looping call. Clock has been shutdown")
|
||||
|
||||
instance_id = random_string_insecure_fast(5)
|
||||
interval = duration.as_secs()
|
||||
|
||||
async def _loop() -> None:
|
||||
if not now:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
while True:
|
||||
try:
|
||||
clock_debug_logger.debug(
|
||||
"looping_call(%s): Executing callback", instance_id
|
||||
)
|
||||
with context.PreserveLoggingContext(
|
||||
context.LoggingContext(
|
||||
name="looping_call", server_name=self._server_name
|
||||
)
|
||||
):
|
||||
result = f(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception("Looping call %s died", instance_id)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
task_obj = asyncio.create_task(_loop())
|
||||
call = NativeLoopingCall(task_obj)
|
||||
self._looping_calls.add(call)
|
||||
|
||||
clock_debug_logger.debug(
|
||||
"looping_call(%s): Scheduled looping call every %sms",
|
||||
instance_id,
|
||||
duration.as_millis(),
|
||||
stack_info=True,
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def cancel_all_looping_calls(self, consumeErrors: bool = True) -> None:
|
||||
for call in list(self._looping_calls):
|
||||
try:
|
||||
call.stop()
|
||||
except Exception:
|
||||
if not consumeErrors:
|
||||
raise
|
||||
self._looping_calls.clear()
|
||||
|
||||
def call_later(
|
||||
self,
|
||||
delay: Duration,
|
||||
callback: Callable,
|
||||
*args: Any,
|
||||
call_later_cancel_on_shutdown: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> NativeDelayedCallWrapper:
|
||||
call_id = self._delayed_call_id
|
||||
self._delayed_call_id += 1
|
||||
|
||||
if self._is_shutdown:
|
||||
raise Exception("Cannot start delayed call. Clock has been shutdown")
|
||||
|
||||
loop = self._get_loop()
|
||||
|
||||
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
|
||||
clock_debug_logger.debug("call_later(%s): Executing callback", call_id)
|
||||
try:
|
||||
with context.PreserveLoggingContext(
|
||||
context.LoggingContext(
|
||||
name="call_later", server_name=self._server_name
|
||||
)
|
||||
):
|
||||
context.run_in_background(callback, *args, **kwargs)
|
||||
finally:
|
||||
if call_later_cancel_on_shutdown:
|
||||
self._call_id_to_delayed_call.pop(call_id, None)
|
||||
|
||||
scheduled_time = loop.time() + delay.as_secs()
|
||||
handle = loop.call_later(delay.as_secs(), wrapped_callback, *args, **kwargs)
|
||||
|
||||
clock_debug_logger.debug(
|
||||
"call_later(%s): Scheduled call for %ss later (tracked: %s)",
|
||||
call_id,
|
||||
delay.as_secs(),
|
||||
call_later_cancel_on_shutdown,
|
||||
stack_info=True,
|
||||
)
|
||||
|
||||
wrapped_call = NativeDelayedCallWrapper(
|
||||
handle, call_id, self, loop, scheduled_time, wrapped_callback, args, kwargs
|
||||
)
|
||||
if call_later_cancel_on_shutdown:
|
||||
self._call_id_to_delayed_call[call_id] = wrapped_call
|
||||
|
||||
return wrapped_call
|
||||
|
||||
def cancel_call_later(
|
||||
self, wrapped_call: NativeDelayedCallWrapper, ignore_errs: bool = False
|
||||
) -> None:
|
||||
try:
|
||||
wrapped_call.cancel()
|
||||
except Exception:
|
||||
if not ignore_errs:
|
||||
raise
|
||||
|
||||
def cancel_all_delayed_calls(self, ignore_errs: bool = True) -> None:
|
||||
for call_id, call in list(self._call_id_to_delayed_call.items()):
|
||||
try:
|
||||
call.cancel()
|
||||
except Exception:
|
||||
if not ignore_errs:
|
||||
raise
|
||||
self._call_id_to_delayed_call.clear()
|
||||
|
||||
def call_when_running(
|
||||
self,
|
||||
callback: Callable[P, object],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
"""Call a function on the next event loop iteration."""
|
||||
instance_id = random_string_insecure_fast(5)
|
||||
|
||||
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
|
||||
clock_debug_logger.debug(
|
||||
"call_when_running(%s): Executing callback", instance_id
|
||||
)
|
||||
with context.PreserveLoggingContext(
|
||||
context.LoggingContext(
|
||||
name="call_when_running", server_name=self._server_name
|
||||
)
|
||||
):
|
||||
context.run_in_background(callback, *args, **kwargs)
|
||||
|
||||
loop = self._get_loop()
|
||||
if kwargs:
|
||||
loop.call_soon(lambda: wrapped_callback(*args, **kwargs))
|
||||
else:
|
||||
loop.call_soon(wrapped_callback, *args)
|
||||
|
||||
def add_system_event_trigger(
|
||||
self,
|
||||
phase: str,
|
||||
event_type: str,
|
||||
callback: Callable[P, object],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> int:
|
||||
"""Store a callback to be invoked during shutdown.
|
||||
|
||||
Returns an ID that could be used to remove the trigger (not currently
|
||||
needed by callers, but matches the Twisted API).
|
||||
"""
|
||||
trigger_id = len(self._shutdown_callbacks)
|
||||
self._shutdown_callbacks.append(
|
||||
(phase, event_type, callback, args, kwargs)
|
||||
)
|
||||
clock_debug_logger.debug(
|
||||
"add_system_event_trigger: registered %s %s callback",
|
||||
phase,
|
||||
event_type,
|
||||
stack_info=True,
|
||||
)
|
||||
return trigger_id
|
||||
|
||||
@@ -492,5 +492,255 @@ class NativeReadWriteLockTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(order, ["reader_start", "reader_end", "writer"])
|
||||
|
||||
|
||||
class NativeClockTest(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for the asyncio-native NativeClock."""
|
||||
|
||||
async def test_time(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
t = clock.time()
|
||||
self.assertIsInstance(t, float)
|
||||
self.assertGreater(t, 0)
|
||||
|
||||
async def test_time_msec(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
t = clock.time_msec()
|
||||
self.assertIsInstance(t, int)
|
||||
self.assertGreater(t, 0)
|
||||
|
||||
async def test_sleep(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
before = clock.time()
|
||||
await clock.sleep(Duration(milliseconds=50))
|
||||
after = clock.time()
|
||||
self.assertGreaterEqual(after - before, 0.04)
|
||||
|
||||
async def test_call_later(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
called = asyncio.Event()
|
||||
|
||||
def callback() -> None:
|
||||
called.set()
|
||||
|
||||
wrapper = clock.call_later(Duration(milliseconds=20), callback)
|
||||
self.assertTrue(wrapper.active())
|
||||
|
||||
await asyncio.wait_for(called.wait(), timeout=1.0)
|
||||
self.assertTrue(called.is_set())
|
||||
|
||||
async def test_call_later_cancel(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
called = False
|
||||
|
||||
def callback() -> None:
|
||||
nonlocal called
|
||||
called = True
|
||||
|
||||
wrapper = clock.call_later(Duration(milliseconds=50), callback)
|
||||
self.assertTrue(wrapper.active())
|
||||
|
||||
wrapper.cancel()
|
||||
self.assertFalse(wrapper.active())
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
self.assertFalse(called)
|
||||
|
||||
async def test_call_later_getTime(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
wrapper = clock.call_later(Duration(seconds=10), lambda: None)
|
||||
# getTime should return a time in the future
|
||||
self.assertGreater(wrapper.getTime(), 0)
|
||||
wrapper.cancel()
|
||||
|
||||
async def test_looping_call(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
call_count = 0
|
||||
|
||||
def callback() -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
call = clock.looping_call(callback, Duration(milliseconds=30))
|
||||
|
||||
# looping_call waits `duration` before first call
|
||||
await asyncio.sleep(0.01)
|
||||
self.assertEqual(call_count, 0)
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
self.assertGreaterEqual(call_count, 1)
|
||||
|
||||
call.stop()
|
||||
|
||||
old_count = call_count
|
||||
await asyncio.sleep(0.05)
|
||||
self.assertEqual(call_count, old_count)
|
||||
|
||||
async def test_looping_call_now(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
call_count = 0
|
||||
|
||||
def callback() -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
call = clock.looping_call_now(callback, Duration(milliseconds=30))
|
||||
|
||||
# looping_call_now should call immediately
|
||||
await asyncio.sleep(0.01)
|
||||
self.assertGreaterEqual(call_count, 1)
|
||||
|
||||
call.stop()
|
||||
|
||||
async def test_looping_call_waits_for_completion(self) -> None:
|
||||
"""Test that the next iteration waits for the previous to complete."""
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
concurrent = 0
|
||||
max_concurrent = 0
|
||||
|
||||
async def slow_callback() -> None:
|
||||
nonlocal concurrent, max_concurrent
|
||||
concurrent += 1
|
||||
max_concurrent = max(max_concurrent, concurrent)
|
||||
await asyncio.sleep(0.04)
|
||||
concurrent -= 1
|
||||
|
||||
call = clock.looping_call_now(slow_callback, Duration(milliseconds=10))
|
||||
|
||||
await asyncio.sleep(0.15)
|
||||
call.stop()
|
||||
|
||||
# Should never have more than 1 concurrent execution
|
||||
self.assertEqual(max_concurrent, 1)
|
||||
|
||||
async def test_looping_call_survives_error(self) -> None:
|
||||
"""Test that an error in the callback doesn't stop the loop."""
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
call_count = 0
|
||||
|
||||
def flaky_callback() -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise ValueError("first call fails")
|
||||
|
||||
call = clock.looping_call_now(flaky_callback, Duration(milliseconds=20))
|
||||
|
||||
await asyncio.sleep(0.08)
|
||||
call.stop()
|
||||
|
||||
# Should have been called more than once despite the error
|
||||
self.assertGreaterEqual(call_count, 2)
|
||||
|
||||
async def test_shutdown(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
looping_called = False
|
||||
delayed_called = False
|
||||
|
||||
def looping_cb() -> None:
|
||||
nonlocal looping_called
|
||||
looping_called = True
|
||||
|
||||
def delayed_cb() -> None:
|
||||
nonlocal delayed_called
|
||||
delayed_called = True
|
||||
|
||||
clock.looping_call(looping_cb, Duration(seconds=1))
|
||||
clock.call_later(Duration(seconds=1), delayed_cb)
|
||||
|
||||
# Shutdown immediately before any callbacks can fire
|
||||
clock.shutdown()
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
self.assertFalse(looping_called)
|
||||
self.assertFalse(delayed_called)
|
||||
|
||||
async def test_cancel_all_delayed_calls(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
called = False
|
||||
|
||||
def callback() -> None:
|
||||
nonlocal called
|
||||
called = True
|
||||
|
||||
clock.call_later(Duration(milliseconds=50), callback)
|
||||
clock.call_later(Duration(milliseconds=50), callback)
|
||||
|
||||
clock.cancel_all_delayed_calls()
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
self.assertFalse(called)
|
||||
|
||||
async def test_call_when_running(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
called = asyncio.Event()
|
||||
|
||||
def callback() -> None:
|
||||
called.set()
|
||||
|
||||
clock.call_when_running(callback)
|
||||
|
||||
await asyncio.wait_for(called.wait(), timeout=1.0)
|
||||
self.assertTrue(called.is_set())
|
||||
|
||||
async def test_add_system_event_trigger(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
|
||||
trigger_id = clock.add_system_event_trigger(
|
||||
"before", "shutdown", lambda: None
|
||||
)
|
||||
self.assertIsInstance(trigger_id, int)
|
||||
self.assertEqual(len(clock._shutdown_callbacks), 1)
|
||||
|
||||
async def test_shutdown_prevents_new_calls(self) -> None:
|
||||
from synapse.util.clock import NativeClock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
clock = NativeClock(server_name="test.server")
|
||||
clock.shutdown()
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
clock.looping_call(lambda: None, Duration(seconds=1))
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
clock.call_later(Duration(seconds=1), lambda: None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user